diff --git a/main.py b/main.py index a491067..b9a25db 100755 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ def main(only_spec=False, only_color=False): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, background_model_path=Config.rgb_background_model_path) - _, _ = spec_detector.predict(np.ones(Config.nRows, Config.nCols, Config.nBands, dtype=float)*0.4),\ + _, _ = spec_detector.predict(np.ones((Config.nRows, Config.nCols, Config.nBands), dtype=float)*0.4),\ rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8)*40) total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 diff --git a/utils.py b/utils.py index 298d869..b067976 100755 --- a/utils.py +++ b/utils.py @@ -132,11 +132,13 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac plt.show() -def size_threshold(img, blk_size, threshold): +def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray: mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \ reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 + if last_end is not None: + mask_up = np.concatenate((last_end, img[:-(blk_size//2), :])) return mask