修改了float类型报错

This commit is contained in:
FEIJINTI 2022-08-06 14:57:13 +08:00
parent 30041c1aed
commit 49b2ea4be7
2 changed files with 16 additions and 13 deletions

12
main.py
View File

@ -67,17 +67,17 @@ def main(only_spec=False, only_color=False):
logging.error(f'毁灭性错误!收到的rgb数据长度为{len(rgb_data)}无法转化成指定形状 {e}')
if only_spec:
# 光谱识别
mask_spec = spec_detector.predict(img_data)
mask_rgb = rgb_detector.predict(rgb_data)
mask_spec = spec_detector.predict(img_data).astype(np.uint8)
_ = rgb_detector.predict(rgb_data)
mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8)
elif only_color:
# rgb识别
mask_spec = spec_detector.predict(img_data)
mask_rgb = rgb_detector.predict(rgb_data)
_ = spec_detector.predict(img_data)
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8)
else:
mask_spec = spec_detector.predict(img_data)
mask_rgb = rgb_detector.predict(rgb_data)
mask_spec = spec_detector.predict(img_data).astype(np.uint8)
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
# 进行喷阀的合并
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
# control the size of the output masks, 在resize前图像的宽度是和喷阀对应的

View File

@ -162,7 +162,8 @@ def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray:
def valve_expend(img: np.ndarray) -> np.ndarray:
kernel = np.ones((1, 3), np.uint8)
return cv2.dilate(img, kernel)
img = cv2.dilate(img, kernel, iterations=1)
return img
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
@ -221,11 +222,13 @@ def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
# a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8)
# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
# (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
# dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
# lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8)
# a.repeat(3, axis=0)
# b = valve_merge(a, 2)
# print(b)
# print(b)
c = valve_expend(a)
print(c)