diff --git a/config.py b/config.py index 521dae3..483c629 100644 --- a/config.py +++ b/config.py @@ -26,6 +26,6 @@ class Config: # rgb模型参数 rgb_tobacco_model_path = r"models/tobacco_dt_2022-07-26_15-57.model" rgb_background_model_path = r"models/background_dt_2022-07-27_08-11.model" - threshold_low, threshold_high = 5, 230 + threshold_low, threshold_high = 10, 230 threshold_s = 175 rgb_size_threshold = 4 diff --git a/main.py b/main.py old mode 100644 new mode 100755 index 3f5a818..32602f9 --- a/main.py +++ b/main.py @@ -8,6 +8,9 @@ from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector import cv2 +SAVE_IMG, SAVE_NUM = False, 30 + + def main(): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, @@ -20,39 +23,52 @@ def main(): os.mkfifo(mask_fifo_path, 0o777) if not os.access(rgb_fifo_path, os.F_OK): os.mkfifo(rgb_fifo_path, 0o777) + if SAVE_IMG: + img_list = [] while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) + + # spec data read data = os.read(fd_img, total_len) - # 读取(开启一个管道) if len(data) < 3: threshold = int(float(data)) Config.spec_size_threshold = threshold - print("[INFO] Get threshold: ", threshold) - continue + print("[INFO] Get spec threshold: ", threshold) else: data_total = data - rgb_data = os.read(fd_rgb, total_rgb) - if len(rgb_data) < 3: - rgb_threshold = int(float(rgb_data)) - Config.rgb_size_threshold = rgb_threshold - print(rgb_threshold) - continue - else: - rgb_data_total = rgb_data os.close(fd_img) + + # rgb data read + rgb_data = os.read(fd_rgb, total_rgb) + if len(rgb_data) < 3: + rgb_threshold = int(float(rgb_data)) + Config.rgb_size_threshold = rgb_threshold + print("[INFO] Get rgb threshold", rgb_threshold) + continue + else: + rgb_data_total = rgb_data os.close(fd_rgb) + # 识别 t1 = time.time() img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ .transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + if SAVE_IMG: + SAVE_NUM -= 1 + img_list.append((rgb_data, img_data)) + if SAVE_NUM <= 0: + break # 光谱识别 mask = spec_detector.predict(img_data) # rgb识别 mask_rgb = rgb_detector.predict(rgb_data) + # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) + + # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() print(f'rgb len = {len(rgb_data)}') @@ -63,6 +79,12 @@ def main(): os.close(fd_mask) t3 = time.time() print(f'total time is:{t3 - t1}') + for i, img in enumerate(img_list): + print(f"writing img {i}...") + cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) + np.save(f'./{i}.npy', img[1]) + i += 1 + def save_main(): @@ -131,8 +153,8 @@ def save_main(): .sum(axis=1) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 - # mask_result = (mask | mask_rgb).astype(np.uint8) - mask_result = mask_rgb + mask_result = (mask | mask_rgb).astype(np.uint8) + # mask_result = mask_rgb mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() print(f'rgb len = {len(rgb_data)}') diff --git a/models.py b/models.py old mode 100644 new mode 100755 index 5716f15..95d7f4a --- a/models.py +++ b/models.py @@ -17,7 +17,7 @@ from sklearn.model_selection import train_test_split from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = False +deploy = True if not deploy: print("Training env") from tqdm import tqdm @@ -415,7 +415,7 @@ class SpecDetector(Detector): # 烟梗mask中将背景赋值为0,将烟梗赋值为2 yellow_things[yellow_things] = tobacco yellow_things = yellow_things + 0 - yellow_things = binary_dilation(yellow_things, iterations=iteration) + # yellow_things = binary_dilation(yellow_things, iterations=iteration) yellow_things = yellow_things + 0 yellow_things[yellow_things == 1] = 2 diff --git a/utils.py b/utils.py old mode 100644 new mode 100755