模型整合2

This commit is contained in:
FEIJINTI 2022-07-25 10:15:11 +08:00
parent 874fe2c69c
commit 75ab881bc5
2 changed files with 16 additions and 13 deletions

14
main.py
View File

@ -9,7 +9,7 @@ from models import ManualTree, AnonymousColorDetector
def main():
threshold = Config.threshold
rgb_threshold = Config.rgb_threshold
manualTree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
@ -29,7 +29,7 @@ def main():
# 读取(开启一个管道)
if len(data) < 3:
threshold = int(float(data))
print(threshold)
print("[INFO] Get threshold: ", threshold)
continue
else:
data_total = data
@ -46,12 +46,11 @@ def main():
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)).transpose(0,
2,
1)
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands,
-1)).transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8)
pixel_predict_result = manualTree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manualTree.blk_predict(data=img_data)
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manual_tree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data)
rgb_predict_result = 1 - (
background_detector.predict(rgb_data) | tobacco_detector.swell(tobacco_detector.predict(rgb_data)))
@ -65,7 +64,6 @@ def main():
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
# print(threshold)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8)

View File

@ -3,6 +3,7 @@
# @Auther : zhouchao
# @File: main_test.py
# @Software:PyCharm
import os
import time
import cv2
@ -13,7 +14,7 @@ from models import Detector, AnonymousColorDetector
from utils import read_labeled_img
def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_model=False):
def virtual_main(detector: AnonymousColorDetector, test_img=None, test_img_dir=None, test_model=False):
"""
虚拟读图测试程序
@ -45,6 +46,10 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
axs[0].set_title(
f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
plt.show()
if test_img_dir is not None:
image_names = os.listdir(test_img_dir)
for image_name in image_names:
img = cv2.imread(os.path.join(test_img_dir, image_name))[..., ::-1]
if test_model:
data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"}
@ -55,13 +60,13 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_mode
if __name__ == '__main__':
detector = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model')
virtual_main(detector,
model = AnonymousColorDetector(file_path='dt_2022-07-20_14-40.model')
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)
virtual_main(detector,
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)
virtual_main(detector,
virtual_main(model,
test_img=r'C:\Users\FEIJINTI\Desktop\720\binning1\tobacco\Image_2022_0720_1354_46_472-003051.bmp',
test_model=True)