From 3d720d5c2ba0b05550d143d3f6a84a0e0c87f714 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Wed, 27 Jul 2022 13:08:19 +0800 Subject: [PATCH] 022.7.27 complete version --- config.py | 8 +++++--- main.py | 48 +++++++++++++++++++++++++++++++++++------------- models.py | 4 ++-- utils.py | 0 4 files changed, 42 insertions(+), 18 deletions(-) mode change 100644 => 100755 config.py mode change 100644 => 100755 main.py mode change 100644 => 100755 models.py mode change 100644 => 100755 utils.py diff --git a/config.py b/config.py old mode 100644 new mode 100755 index 6e8d4e1..483c629 --- a/config.py +++ b/config.py @@ -4,8 +4,8 @@ import numpy as np class Config: # 文件相关参数 - nRows, nCols, nBands, spec_size_threshold = 256, 1024, 22, 3 - nRgbRows, nRgbCols, nRgbBands, rgb_size_threshold = 1024, 4096, 3, 4 + nRows, nCols, nBands = 256, 1024, 22 + nRgbRows, nRgbCols, nRgbBands = 1024, 4096, 3 # 需要设置的谱段等参数 selected_bands = [127, 201, 202, 294] @@ -21,9 +21,11 @@ class Config: blk_size = 4 pixel_model_path = r"./models/dt.p" blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" + spec_size_threshold = 3 # rgb模型参数 rgb_tobacco_model_path = r"models/tobacco_dt_2022-07-26_15-57.model" rgb_background_model_path = r"models/background_dt_2022-07-27_08-11.model" - threshold_low, threshold_high = 5, 255 + threshold_low, threshold_high = 10, 230 threshold_s = 175 + rgb_size_threshold = 4 diff --git a/main.py b/main.py old mode 100644 new mode 100755 index 3f5a818..32602f9 --- a/main.py +++ b/main.py @@ -8,6 +8,9 @@ from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector import cv2 +SAVE_IMG, SAVE_NUM = False, 30 + + def main(): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, @@ -20,39 +23,52 @@ def main(): os.mkfifo(mask_fifo_path, 0o777) if not os.access(rgb_fifo_path, os.F_OK): os.mkfifo(rgb_fifo_path, 0o777) + if SAVE_IMG: + img_list = [] while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) + + # spec data read data = os.read(fd_img, total_len) - # 读取(开启一个管道) if len(data) < 3: threshold = int(float(data)) Config.spec_size_threshold = threshold - print("[INFO] Get threshold: ", threshold) - continue + print("[INFO] Get spec threshold: ", threshold) else: data_total = data - rgb_data = os.read(fd_rgb, total_rgb) - if len(rgb_data) < 3: - rgb_threshold = int(float(rgb_data)) - Config.rgb_size_threshold = rgb_threshold - print(rgb_threshold) - continue - else: - rgb_data_total = rgb_data os.close(fd_img) + + # rgb data read + rgb_data = os.read(fd_rgb, total_rgb) + if len(rgb_data) < 3: + rgb_threshold = int(float(rgb_data)) + Config.rgb_size_threshold = rgb_threshold + print("[INFO] Get rgb threshold", rgb_threshold) + continue + else: + rgb_data_total = rgb_data os.close(fd_rgb) + # 识别 t1 = time.time() img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ .transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + if SAVE_IMG: + SAVE_NUM -= 1 + img_list.append((rgb_data, img_data)) + if SAVE_NUM <= 0: + break # 光谱识别 mask = spec_detector.predict(img_data) # rgb识别 mask_rgb = rgb_detector.predict(rgb_data) + # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) + + # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() print(f'rgb len = {len(rgb_data)}') @@ -63,6 +79,12 @@ def main(): os.close(fd_mask) t3 = time.time() print(f'total time is:{t3 - t1}') + for i, img in enumerate(img_list): + print(f"writing img {i}...") + cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) + np.save(f'./{i}.npy', img[1]) + i += 1 + def save_main(): @@ -131,8 +153,8 @@ def save_main(): .sum(axis=1) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 - # mask_result = (mask | mask_rgb).astype(np.uint8) - mask_result = mask_rgb + mask_result = (mask | mask_rgb).astype(np.uint8) + # mask_result = mask_rgb mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() print(f'rgb len = {len(rgb_data)}') diff --git a/models.py b/models.py old mode 100644 new mode 100755 index 5716f15..95d7f4a --- a/models.py +++ b/models.py @@ -17,7 +17,7 @@ from sklearn.model_selection import train_test_split from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = False +deploy = True if not deploy: print("Training env") from tqdm import tqdm @@ -415,7 +415,7 @@ class SpecDetector(Detector): # 烟梗mask中将背景赋值为0,将烟梗赋值为2 yellow_things[yellow_things] = tobacco yellow_things = yellow_things + 0 - yellow_things = binary_dilation(yellow_things, iterations=iteration) + # yellow_things = binary_dilation(yellow_things, iterations=iteration) yellow_things = yellow_things + 0 yellow_things[yellow_things == 1] = 2 diff --git a/utils.py b/utils.py old mode 100644 new mode 100755