From 4495d7f706475b294df3d01970f61b0c1a951fd2 Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Fri, 5 Aug 2022 16:48:09 +0800 Subject: [PATCH] =?UTF-8?q?[ext]=20=E5=96=B7=E9=98=80=E7=BB=93=E6=9E=9C?= =?UTF-8?q?=E5=90=88=E5=B9=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加了喷阀结果合并功能,并在utils.py的main当中添加了测试,测试已经通过。 --- config.py | 1 + main.py | 7 +++++-- utils.py | 14 ++++++++++++++ 3 files changed, 20 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index 2147a2b..7a318a3 100644 --- a/config.py +++ b/config.py @@ -32,6 +32,7 @@ class Config: # mask parameter target_size = (1024, 1024) # (Width, Height) of mask + valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质 # save part offset_vertical = 0 diff --git a/main.py b/main.py index 7130fb8..17314d8 100755 --- a/main.py +++ b/main.py @@ -3,6 +3,7 @@ import cv2 import time import numpy as np +import utils from config import Config from models import RgbDetector, SpecDetector @@ -66,8 +67,10 @@ def main(only_spec=False, only_color=False): mask_spec = spec_detector.predict(img_data) mask_rgb = rgb_detector.predict(rgb_data) - # control the size of the output masks - masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_spec, mask_rgb]] + # 进行喷阀的合并 + masks = [utils.valve_merge(mask, merge_size=Config.valve_merge_size) for mask in [mask_spec, mask_rgb]] + # control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的 + masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks] # 写出 output_fifos = [mask_fifo_path, rgb_mask_fifo_path] for fifo, mask in zip(output_fifos, masks): diff --git a/utils.py b/utils.py index a19fa22..538082e 100755 --- a/utils.py +++ b/utils.py @@ -150,6 +150,16 @@ def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np. return mask +def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray: + assert img.shape[1] % merge_size == 0 # 列数必须能够被整除 + img_shape = (img.shape[1], img.shape[0]) + img = img.reshape((img.shape[0], img.shape[1]//merge_size, merge_size)) + img = np.sum(img, axis=2) + img[img > 0] = 1 + img = cv2.resize(img.astype(np.uint8), dsize=img_shape) + return img + + def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None): """ Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII... @@ -210,3 +220,7 @@ if __name__ == '__main__': (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False) lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False) + # a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8) + # a.repeat(3, axis=0) + # b = valve_merge(a, 2) + # print(b) \ No newline at end of file