diff --git a/main.py b/main.py index 0a242f5..0a1cfe6 100644 --- a/main.py +++ b/main.py @@ -9,7 +9,7 @@ from models import ManualTree, AnonymousColorDetector def main(): threshold = Config.threshold rgb_threshold = Config.rgb_threshold - manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) + manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) @@ -29,7 +29,7 @@ def main(): # 读取(开启一个管道) if len(data) < 3: threshold = int(float(data)) - print(threshold) + print("[INFO] Get threshold: ", threshold) continue else: data_total = data @@ -46,12 +46,11 @@ def main(): # 识别 t1 = time.time() - img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0, - 2, - 1) + img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, + -1)).transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8) - pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1) - blk_predict_result = manualTree.blk_predict(data=img_data) + pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) + blk_predict_result = manual_tree.blk_predict(data=img_data) rgb_data = tobacco_detector.pretreatment(rgb_data) rgb_predict_result = 1 - ( background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data))) @@ -65,7 +64,6 @@ def main(): mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ .sum(axis=1) - # print(threshold) mask[mask <= threshold] = 0 mask[mask > threshold] = 1 mask_result = (mask | mask_rgb).astype(np.uint8) diff --git a/main_test.py b/main_test.py index 1165ae5..880f8a6 100644 --- a/main_test.py +++ b/main_test.py @@ -3,6 +3,7 @@ # @Auther : zhouchao # @File: main_test.py # @Software:PyCharm +import os import time import cv2 @@ -13,7 +14,7 @@ from models import Detector, AnonymousColorDetector from utils import read_labeled_img -def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_model=False): +def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=None, test_model=False): """ 虚拟读图测试程序 @@ -45,6 +46,10 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode axs[0].set_title( f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms') plt.show() + if test_img_dir is not None: + image_names = os.listdir(test_img_dir) + for image_name in image_names: + img = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1] if test_model: data_dir = "data/dataset" color_dict = {(0, 0, 255): "yangeng"} @@ -55,13 +60,13 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode if __name__ == '__main__': - detector = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model') - virtual_main(detector, + model = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model') + virtual_main(model, test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_model=True) - virtual_main(detector, + virtual_main(model, test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_model=True) - virtual_main(detector, + virtual_main(model, test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_model=True)