大小滤波修改

This commit is contained in:
li.zhenye 2022-08-05 10:26:20 +08:00
parent 4d6fd3cd3a
commit 34d9e9a676
2 changed files with 4 additions and 2 deletions

View File

@ -11,7 +11,7 @@ def main(only_spec=False, only_color=False):
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
_, _ = spec_detector.predict(np.ones(Config.nRows, Config.nCols, Config.nBands, dtype=float)*0.4),\
_, _ = spec_detector.predict(np.ones((Config.nRows, Config.nCols, Config.nBands), dtype=float)*0.4),\
rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8)*40)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量

View File

@ -132,11 +132,13 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
plt.show()
def size_threshold(img, blk_size, threshold):
def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray:
mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
if last_end is not None:
mask_up = np.concatenate((last_end, img[:-(blk_size//2), :]))
return mask