From 37d8dc43b08895466e21f428df6b306cecb814fd Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Fri, 5 Aug 2022 11:53:33 +0800
Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E5=A4=A7=E5=B0=8F=E6=BB=A4?=
=?UTF-8?q?=E6=B3=A2?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
config.py | 2 +-
main.py | 1 +
models.py | 7 +++++--
utils.py | 12 ++++++++++--
4 files changed, 17 insertions(+), 5 deletions(-)
diff --git a/config.py b/config.py
index 1bd2e56..0b2782e 100644
--- a/config.py
+++ b/config.py
@@ -18,7 +18,7 @@ class Config:
green_bands = [i for i in range(1, 21)]
# 光谱模型参数
- blk_size = 4
+ blk_size = 4 # 必须是2的倍数,不然会出错
pixel_model_path = r"./models/pixel_2022-08-02_15-22.model"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
spec_size_threshold = 3
diff --git a/main.py b/main.py
index b9a25db..7130fb8 100755
--- a/main.py
+++ b/main.py
@@ -23,6 +23,7 @@ def main(only_spec=False, only_color=False):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_mask_fifo_path, os.F_OK):
os.mkfifo(rgb_mask_fifo_path, 0o777)
+
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
diff --git a/models.py b/models.py
index 9af462f..da48a54 100755
--- a/models.py
+++ b/models.py
@@ -339,16 +339,19 @@ class SpecDetector(Detector):
self.blk_model = None
self.pixel_model_ml = None
self.load(blk_model_path, pixel_model_path)
+ self.spare_part = np.zeros((Config.blk_size//2, Config.nCols))
def load(self, blk_model_path, pixel_model_path):
self.pixel_model_ml = PixelModelML(pixel_model_path)
self.blk_model = BlkModel(blk_model_path)
- def predict(self, img_data):
+ def predict(self, img_data: np.ndarray, save_part: bool = False) -> np.ndarray:
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = self.blk_predict(data=img_data)
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
- mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold)
+ if save_part:
+ self.spare_part = mask[-(Config.blk_size//2):, :]
+ mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
return mask
def save(self, *args, **kwargs):
diff --git a/utils.py b/utils.py
index b067976..a19fa22 100755
--- a/utils.py
+++ b/utils.py
@@ -132,13 +132,21 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
plt.show()
-def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray:
+def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np.ndarray:
mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
if last_end is not None:
- mask_up = np.concatenate((last_end, img[:-(blk_size//2), :]))
+ half_blk_size = blk_size // 2
+ assert (last_end.shape[0] == half_blk_size) and (last_end.shape[1] == img.shape[1])
+ mask_up = np.concatenate((last_end, img[:-half_blk_size, :]), axis=0)
+ mask_up_right = np.concatenate((mask_up[:, half_blk_size:],
+ np.zeros((img.shape[0], half_blk_size), dtype=np.uint8)), axis=1)
+ mask_up = size_threshold(mask_up, blk_size, threshold)
+ mask_up_right = size_threshold(mask_up_right, blk_size, threshold)
+ mask[:-1, :] = mask_up[1:, :]
+ mask[:-1, 1:] = mask_up_right[1:, :-1]
return mask