mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
修改大小滤波
This commit is contained in:
parent
34d9e9a676
commit
37d8dc43b0
@ -18,7 +18,7 @@ class Config:
|
||||
green_bands = [i for i in range(1, 21)]
|
||||
|
||||
# 光谱模型参数
|
||||
blk_size = 4
|
||||
blk_size = 4 # 必须是2的倍数,不然会出错
|
||||
pixel_model_path = r"./models/pixel_2022-08-02_15-22.model"
|
||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
|
||||
spec_size_threshold = 3
|
||||
|
||||
1
main.py
1
main.py
@ -23,6 +23,7 @@ def main(only_spec=False, only_color=False):
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
if not os.access(rgb_mask_fifo_path, os.F_OK):
|
||||
os.mkfifo(rgb_mask_fifo_path, 0o777)
|
||||
|
||||
while True:
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
|
||||
@ -339,16 +339,19 @@ class SpecDetector(Detector):
|
||||
self.blk_model = None
|
||||
self.pixel_model_ml = None
|
||||
self.load(blk_model_path, pixel_model_path)
|
||||
self.spare_part = np.zeros((Config.blk_size//2, Config.nCols))
|
||||
|
||||
def load(self, blk_model_path, pixel_model_path):
|
||||
self.pixel_model_ml = PixelModelML(pixel_model_path)
|
||||
self.blk_model = BlkModel(blk_model_path)
|
||||
|
||||
def predict(self, img_data):
|
||||
def predict(self, img_data: np.ndarray, save_part: bool = False) -> np.ndarray:
|
||||
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = self.blk_predict(data=img_data)
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold)
|
||||
if save_part:
|
||||
self.spare_part = mask[-(Config.blk_size//2):, :]
|
||||
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
|
||||
return mask
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
|
||||
12
utils.py
12
utils.py
@ -132,13 +132,21 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
plt.show()
|
||||
|
||||
|
||||
def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray:
|
||||
def size_threshold(img, blk_size, threshold, last_end: np.ndarray = None) -> np.ndarray:
|
||||
mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
|
||||
reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
|
||||
mask[mask <= threshold] = 0
|
||||
mask[mask > threshold] = 1
|
||||
if last_end is not None:
|
||||
mask_up = np.concatenate((last_end, img[:-(blk_size//2), :]))
|
||||
half_blk_size = blk_size // 2
|
||||
assert (last_end.shape[0] == half_blk_size) and (last_end.shape[1] == img.shape[1])
|
||||
mask_up = np.concatenate((last_end, img[:-half_blk_size, :]), axis=0)
|
||||
mask_up_right = np.concatenate((mask_up[:, half_blk_size:],
|
||||
np.zeros((img.shape[0], half_blk_size), dtype=np.uint8)), axis=1)
|
||||
mask_up = size_threshold(mask_up, blk_size, threshold)
|
||||
mask_up_right = size_threshold(mask_up_right, blk_size, threshold)
|
||||
mask[:-1, :] = mask_up[1:, :]
|
||||
mask[:-1, 1:] = mask_up_right[1:, :-1]
|
||||
return mask
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user