diff --git a/main.py b/main.py index 175ed84..aabfd67 100755 --- a/main.py +++ b/main.py @@ -67,17 +67,17 @@ def main(only_spec=False, only_color=False): logging.error(f'毁灭性错误!收到的rgb数据长度为{len(rgb_data)}无法转化成指定形状 {e}') if only_spec: # 光谱识别 - mask_spec = spec_detector.predict(img_data) - mask_rgb = rgb_detector.predict(rgb_data) + mask_spec = spec_detector.predict(img_data).astype(np.uint8) + _ = rgb_detector.predict(rgb_data) mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8) elif only_color: # rgb识别 - mask_spec = spec_detector.predict(img_data) - mask_rgb = rgb_detector.predict(rgb_data) + _ = spec_detector.predict(img_data) + mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8) mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8) else: - mask_spec = spec_detector.predict(img_data) - mask_rgb = rgb_detector.predict(rgb_data) + mask_spec = spec_detector.predict(img_data).astype(np.uint8) + mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8) # 进行喷阀的合并 masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]] # control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的 diff --git a/utils.py b/utils.py index 57183cb..805a757 100755 --- a/utils.py +++ b/utils.py @@ -162,7 +162,8 @@ def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray: def valve_expend(img: np.ndarray) -> np.ndarray: kernel = np.ones((1, 3), np.uint8) - return cv2.dilate(img, kernel) + img = cv2.dilate(img, kernel, iterations=1) + return img def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None): @@ -221,11 +222,13 @@ def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None): if __name__ == '__main__': - color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", - (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} - dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False) - lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False) - # a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8) + # color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", + # (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} + # dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False) + # lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False) + a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8) # a.repeat(3, axis=0) # b = valve_merge(a, 2) - # print(b) \ No newline at end of file + # print(b) + c = valve_expend(a) + print(c)