From 37d8dc43b08895466e21f428df6b306cecb814fd Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Fri, 5 Aug 2022 11:53:33 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=A4=A7=E5=B0=8F=E6=BB=A4?= =?UTF-8?q?=E6=B3=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 2 +- main.py | 1 + models.py | 7 +++++-- utils.py | 12 ++++++++++-- 4 files changed, 17 insertions(+), 5 deletions(-) diff --git a/config.py b/config.py index 1bd2e56..0b2782e 100644 --- a/config.py +++ b/config.py @@ -18,7 +18,7 @@ class Config: green_bands = [i for i in range(1, 21)] # 光谱模型参数 - blk_size = 4 + blk_size = 4 # 必须是2的倍数,不然会出错 pixel_model_path = r"./models/pixel_2022-08-02_15-22.model" blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model" spec_size_threshold = 3 diff --git a/main.py b/main.py index b9a25db..7130fb8 100755 --- a/main.py +++ b/main.py @@ -23,6 +23,7 @@ def main(only_spec=False, only_color=False): os.mkfifo(mask_fifo_path, 0o777) if not os.access(rgb_mask_fifo_path, os.F_OK): os.mkfifo(rgb_mask_fifo_path, 0o777) + while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) diff --git a/models.py b/models.py index 9af462f..da48a54 100755 --- a/models.py +++ b/models.py @@ -339,16 +339,19 @@ class SpecDetector(Detector): self.blk_model = None self.pixel_model_ml = None self.load(blk_model_path, pixel_model_path) + self.spare_part = np.zeros((Config.blk_size//2, Config.nCols)) def load(self, blk_model_path, pixel_model_path): self.pixel_model_ml = PixelModelML(pixel_model_path) self.blk_model = BlkModel(blk_model_path) - def predict(self, img_data): + def predict(self, img_data: np.ndarray, save_part: bool = False) -> np.ndarray: pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1) blk_predict_result = self.blk_predict(data=img_data) mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) - mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold) + if save_part: + self.spare_part = mask[-(Config.blk_size//2):, :] + mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part) return mask def save(self, *args, **kwargs): diff --git a/utils.py b/utils.py index b067976..a19fa22 100755 --- a/utils.py +++ b/utils.py @@ -132,13 +132,21 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac plt.show() -def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray: +def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np.ndarray: mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \ reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 if last_end is not None: - mask_up = np.concatenate((last_end, img[:-(blk_size//2), :])) + half_blk_size = blk_size // 2 + assert (last_end.shape[0] == half_blk_size) and (last_end.shape[1] == img.shape[1]) + mask_up = np.concatenate((last_end, img[:-half_blk_size, :]), axis=0) + mask_up_right = np.concatenate((mask_up[:, half_blk_size:], + np.zeros((img.shape[0], half_blk_size), dtype=np.uint8)), axis=1) + mask_up = size_threshold(mask_up, blk_size, threshold) + mask_up_right = size_threshold(mask_up_right, blk_size, threshold) + mask[:-1, :] = mask_up[1:, :] + mask[:-1, 1:] = mask_up_right[1:, :-1] return mask