From 34d9e9a67617194c4f786226be7a2e5930d66efc Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Fri, 5 Aug 2022 10:26:20 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A4=A7=E5=B0=8F=E6=BB=A4=E6=B3=A2=E4=BF=AE?=
=?UTF-8?q?=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
main.py | 2 +-
utils.py | 4 +++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/main.py b/main.py
index a491067..b9a25db 100755
--- a/main.py
+++ b/main.py
@@ -11,7 +11,7 @@ def main(only_spec=False, only_color=False):
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
- _, _ = spec_detector.predict(np.ones(Config.nRows, Config.nCols, Config.nBands, dtype=float)*0.4),\
+ _, _ = spec_detector.predict(np.ones((Config.nRows, Config.nCols, Config.nBands), dtype=float)*0.4),\
rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8)*40)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
diff --git a/utils.py b/utils.py
index 298d869..b067976 100755
--- a/utils.py
+++ b/utils.py
@@ -132,11 +132,13 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
plt.show()
-def size_threshold(img, blk_size, threshold):
+def size_threshold(img, blk_size, threshold, last_end: np.ndarray=None) -> np.ndarray:
mask = img.reshape(img.shape[0], img.shape[1] // blk_size, blk_size).sum(axis=2). \
reshape(img.shape[0] // blk_size, blk_size, img.shape[1] // blk_size).sum(axis=1)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
+ if last_end is not None:
+ mask_up = np.concatenate((last_end, img[:-(blk_size//2), :]))
return mask