From 34d9e9a67617194c4f786226be7a2e5930d66efc Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Fri, 5 Aug 2022 10:26:20 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=A7=E5=B0=8F=E6=BB=A4=E6=B3=A2=E4=BF=AE?= =?UTF-8?q?=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 2 +- utils.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/main.py b/main.py index a491067..b9a25db 100755 --- a/main.py +++ b/main.py @@ -11,7 +11,7 @@ def main(only_spec=False, only_color=False): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, background_model_path=Config.rgb_background_model_path) - _, _ = spec_detector.predict(np.ones(Config.nRows, Config.nCols, Config.nBands, dtype=float)*0.4),\ + _, _ = spec_detector.predict(np.ones((Config.nRows, Config.nCols, Config.nBands), dtype=float)*0.4),\ rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8)*40) total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 diff --git a/utils.py b/utils.py index 298d869..b067976 100755 --- a/utils.py +++ b/utils.py @@ -132,11 +132,13 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac plt.show() -def size_threshold(img, blk_size, threshold): +def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray: mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \ reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 + if last_end is not None: + mask_up = np.concatenate((last_end, img[:-(blk_size//2), :])) return mask