diff --git a/config.py b/config.py index cfc6424..ff0b31e 100644 --- a/config.py +++ b/config.py @@ -26,6 +26,9 @@ class Config: # blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径 spec_size_threshold = 3 + s_threshold_a = 125 # s_a的最高允许值 + s_threshold_b = 125 # s_b的最高允许值 + # rgb模型参数 rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径 # rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径 diff --git a/models/__init__.py b/models/__init__.py index b52b2e5..06c125f 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -337,7 +337,7 @@ class RgbDetector(Detector): lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b lab_predict_result = lab_a | lab_b mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold) - mask_rgb = mask_rgb | mask_lab + mask_rgb = mask_rgb.astype(np.uint8) | mask_lab.astype(np.uint8) # # 测试时间 # end = time.time() # print("lab time: ", end - start) @@ -370,9 +370,19 @@ class SpecDetector(Detector): pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1) blk_predict_result = self.blk_predict(data=img_data) mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) + spec_cv = np.clip(img_data[..., [21, 3, 0]], a_min=0, a_max=1) * 255 + spec_cv = spec_cv.astype(np.uint8) + # spec转lab 提取a通道,识别绿色杂质 + lab_a = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 1] < Config.s_threshold_a + # spec转lab 提取b通道,识别蓝色杂质 + lab_b = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 2] < Config.s_threshold_b + lab_predict_result = lab_a | lab_b + mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold) + if save_part: self.spare_part = mask[-(Config.blk_size//2):, :] mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part) + mask = mask.astype(np.uint8) | mask_lab.astype(np.uint8) return mask def save(self, *args, **kwargs):