模型整合2

This commit is contained in:
FEIJINTI 2022-07-25 10:15:11 +08:00
parent 874fe2c69c
commit 75ab881bc5
2 changed files with 16 additions and 13 deletions

14
main.py
View File

@ -9,7 +9,7 @@ from models import ManualTree, AnonymousColorDetector
def main(): def main():
threshold = Config.threshold threshold = Config.threshold
rgb_threshold = Config.rgb_threshold rgb_threshold = Config.rgb_threshold
manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
@ -29,7 +29,7 @@ def main():
# 读取(开启一个管道) # 读取(开启一个管道)
if len(data) < 3: if len(data) < 3:
threshold = int(float(data)) threshold = int(float(data))
print(threshold) print("[INFO] Get threshold: ", threshold)
continue continue
else: else:
data_total = data data_total = data
@ -46,12 +46,11 @@ def main():
# 识别 # 识别
t1 = time.time() t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0, img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands,
2, -1)).transpose(0, 2, 1)
1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8)
pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1) pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manualTree.blk_predict(data=img_data) blk_predict_result = manual_tree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data) rgb_data = tobacco_detector.pretreatment(rgb_data)
rgb_predict_result = 1 - ( rgb_predict_result = 1 - (
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data))) background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data)))
@ -65,7 +64,6 @@ def main():
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1) .sum(axis=1)
# print(threshold)
mask[mask <= threshold] = 0 mask[mask <= threshold] = 0
mask[mask > threshold] = 1 mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8) mask_result = (mask | mask_rgb).astype(np.uint8)

View File

@ -3,6 +3,7 @@
# @Auther : zhouchao # @Auther : zhouchao
# @File: main_test.py # @File: main_test.py
# @Software:PyCharm # @Software:PyCharm
import os
import time import time
import cv2 import cv2
@ -13,7 +14,7 @@ from models import Detector, AnonymousColorDetector
from utils import read_labeled_img from utils import read_labeled_img
def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_model=False): def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=None, test_model=False):
""" """
虚拟读图测试程序 虚拟读图测试程序
@ -45,6 +46,10 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
axs[0].set_title( axs[0].set_title(
f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms') f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
plt.show() plt.show()
if test_img_dir is not None:
image_names = os.listdir(test_img_dir)
for image_name in image_names:
img = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
if test_model: if test_model:
data_dir = "data/dataset" data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"} color_dict = {(0, 0, 255): "yangeng"}
@ -55,13 +60,13 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
if __name__ == '__main__': if __name__ == '__main__':
detector = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model') model = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model')
virtual_main(detector, virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True) test_model=True)
virtual_main(detector, virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True) test_model=True)
virtual_main(detector, virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp', test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True) test_model=True)