mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
[ext] 喷阀结果合并
添加了喷阀结果合并功能,并在utils.py的main当中添加了测试,测试已经通过。
This commit is contained in:
parent
621b21651f
commit
4495d7f706
@ -32,6 +32,7 @@ class Config:
|
|||||||
|
|
||||||
# mask parameter
|
# mask parameter
|
||||||
target_size = (1024, 1024) # (Width, Height) of mask
|
target_size = (1024, 1024) # (Width, Height) of mask
|
||||||
|
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
||||||
|
|
||||||
# save part
|
# save part
|
||||||
offset_vertical = 0
|
offset_vertical = 0
|
||||||
|
|||||||
7
main.py
7
main.py
@ -3,6 +3,7 @@ import cv2
|
|||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
import utils
|
||||||
from config import Config
|
from config import Config
|
||||||
from models import RgbDetector, SpecDetector
|
from models import RgbDetector, SpecDetector
|
||||||
|
|
||||||
@ -66,8 +67,10 @@ def main(only_spec=False, only_color=False):
|
|||||||
mask_spec = spec_detector.predict(img_data)
|
mask_spec = spec_detector.predict(img_data)
|
||||||
mask_rgb = rgb_detector.predict(rgb_data)
|
mask_rgb = rgb_detector.predict(rgb_data)
|
||||||
|
|
||||||
# control the size of the output masks
|
# 进行喷阀的合并
|
||||||
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_spec, mask_rgb]]
|
masks = [utils.valve_merge(mask, merge_size=Config.valve_merge_size) for mask in [mask_spec, mask_rgb]]
|
||||||
|
# control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的
|
||||||
|
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
|
||||||
# 写出
|
# 写出
|
||||||
output_fifos = [mask_fifo_path, rgb_mask_fifo_path]
|
output_fifos = [mask_fifo_path, rgb_mask_fifo_path]
|
||||||
for fifo, mask in zip(output_fifos, masks):
|
for fifo, mask in zip(output_fifos, masks):
|
||||||
|
|||||||
14
utils.py
14
utils.py
@ -150,6 +150,16 @@ def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np.
|
|||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
def valve_merge(img: np.ndarray, merge_size: int = 2) -> np.ndarray:
|
||||||
|
assert img.shape[1] % merge_size == 0 # 列数必须能够被整除
|
||||||
|
img_shape = (img.shape[1], img.shape[0])
|
||||||
|
img = img.reshape((img.shape[0], img.shape[1]//merge_size, merge_size))
|
||||||
|
img = np.sum(img, axis=2)
|
||||||
|
img[img > 0] = 1
|
||||||
|
img = cv2.resize(img.astype(np.uint8), dsize=img_shape)
|
||||||
|
return img
|
||||||
|
|
||||||
|
|
||||||
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
|
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
|
||||||
"""
|
"""
|
||||||
Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII...
|
Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII...
|
||||||
@ -210,3 +220,7 @@ if __name__ == '__main__':
|
|||||||
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
|
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
|
||||||
dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
|
dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
|
||||||
lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
|
lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
|
||||||
|
# a = np.array([[1, 1, 0, 0, 1, 0, 0, 1], [0, 0, 1, 0, 0, 1, 1, 1]]).astype(np.uint8)
|
||||||
|
# a.repeat(3, axis=0)
|
||||||
|
# b = valve_merge(a, 2)
|
||||||
|
# print(b)
|
||||||
Loading…
Reference in New Issue
Block a user