修改了float类型报错

This commit is contained in:
FEIJINTI 2022-08-06 14:57:13 +08:00
parent 30041c1aed
commit 49b2ea4be7
2 changed files with 16 additions and 13 deletions

12
main.py
View File

@ -67,17 +67,17 @@ def main(only_spec=False, only_color=False):
logging.error(f'毁灭性错误!收到的rgb数据长度为{len(rgb_data)}无法转化成指定形状 {e}') logging.error(f'毁灭性错误!收到的rgb数据长度为{len(rgb_data)}无法转化成指定形状 {e}')
if only_spec: if only_spec:
# 光谱识别 # 光谱识别
mask_spec = spec_detector.predict(img_data) mask_spec = spec_detector.predict(img_data).astype(np.uint8)
mask_rgb = rgb_detector.predict(rgb_data) _ = rgb_detector.predict(rgb_data)
mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8) mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8)
elif only_color: elif only_color:
# rgb识别 # rgb识别
mask_spec = spec_detector.predict(img_data) _ = spec_detector.predict(img_data)
mask_rgb = rgb_detector.predict(rgb_data) mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8) mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8)
else: else:
mask_spec = spec_detector.predict(img_data) mask_spec = spec_detector.predict(img_data).astype(np.uint8)
mask_rgb = rgb_detector.predict(rgb_data) mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
# 进行喷阀的合并 # 进行喷阀的合并
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]] masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
# control the size of the output masks, 在resize前图像的宽度是和喷阀对应的 # control the size of the output masks, 在resize前图像的宽度是和喷阀对应的

View File

@ -162,7 +162,8 @@ def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray:
def valve_expend(img: np.ndarray) -> np.ndarray: def valve_expend(img: np.ndarray) -> np.ndarray:
kernel = np.ones((1, 3), np.uint8) kernel = np.ones((1, 3), np.uint8)
return cv2.dilate(img, kernel) img = cv2.dilate(img, kernel, iterations=1)
return img
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None): def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
@ -221,11 +222,13 @@ def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
if __name__ == '__main__': if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", # color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} # (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False) # dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False) # lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
# a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8) a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8)
# a.repeat(3, axis=0) # a.repeat(3, axis=0)
# b = valve_merge(a, 2) # b = valve_merge(a, 2)
# print(b) # print(b)
c = valve_expend(a)
print(c)