supermachine-tobacco/main_test.py
2022-07-26 17:23:25 +08:00

83 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
# Time : 2022/7/19 10:49
# @Author : zhouchao
# @File: main_test.py
# @Software:PyCharm
import os
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
from config import Config
from models import Detector, AnonymousColorDetector, ManualTree
from utils import read_labeled_img, size_threshold
def _read_rgb(path):
    """Read an image file with OpenCV and return it in RGB channel order.

    :param path: path of the image file
    :return: the decoded image with its last axis reversed (BGR -> RGB)
    :raises OSError: if OpenCV cannot read or decode the file
    """
    bgr = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if bgr is None:
        raise OSError(f"cannot read image: {path}")
    return bgr[..., ::-1]


def _iter_test_images(test_img, test_img_dir):
    """Yield ``(path_or_None, rgb_image)``: the single image first, then the directory.

    The path is ``None`` when ``test_img`` was supplied as an in-memory array.
    """
    if isinstance(test_img, str):
        yield test_img, _read_rgb(test_img)
    elif test_img is not None:
        yield None, test_img
    if test_img_dir is not None:
        for image_name in os.listdir(test_img_dir):
            if image_name.endswith('.png'):
                image_path = os.path.join(test_img_dir, image_name)
                yield image_path, _read_rgb(image_path)


def _show_rgb_stages(rgb_data, background, tobacco, rgb_predict_result, mask_rgb):
    """Plot the intermediate results of the RGB branch; blocks until the window is closed."""
    panels = (
        # NOTE: rgb_data is the output of pretreatment(), not the untouched file content.
        (rgb_data, "rgb raw data"),
        (background, "background"),
        (tobacco, "tobacco"),
        (rgb_predict_result, "1 - (background + dilate(tobacco))"),
        (mask_rgb, "final mask"),
    )
    _, axs = plt.subplots(len(panels), 1, figsize=(12, 10), constrained_layout=True)
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
    plt.show()


def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False,
             spectra_loader=None):
    """
    Offline test driver that runs the detection pipeline on images read from disk.

    Each image goes through the enabled branches. With ``test_rgb`` the
    intermediate RGB results are shown in a blocking matplotlib figure. The masks
    of the enabled branches are OR-ed together and upscaled by ``Config.blk_size``;
    the result is not returned.

    :param test_img: a single test image, either an RGB ``np.ndarray`` or the path
        of an image file
    :param test_img_dir: directory whose ``.png`` files are processed one by one
    :param test_spectra: whether to run the spectral (``ManualTree``) branch
    :param test_rgb: whether to run the RGB (colour detector) branch
    :param spectra_loader: callable required when ``test_spectra`` is true. It
        receives the image path (``None`` for an in-memory ``test_img``) and returns
        the spectral data handed to ``ManualTree``. This function has no way of
        loading spectral data on its own.
    :return: None
    :raises TypeError: if ``test_img`` is neither ``str`` nor ``np.ndarray``
    :raises ValueError: if ``test_spectra`` is set without a ``spectra_loader``
    :raises OSError: if an image file cannot be read
    """
    if test_img is None and test_img_dir is None:
        return

    # Validate the arguments before the (potentially slow) model loading below.
    if test_img is not None and not isinstance(test_img, (str, np.ndarray)):
        raise TypeError("test img should be np.ndarray or str")
    if test_spectra and spectra_loader is None:
        raise ValueError("test_spectra=True requires spectra_loader to supply the spectral data")

    spec_threshold = Config.spec_size_threshold
    rgb_threshold = Config.rgb_size_threshold
    manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
    tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
    background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)

    for image_path, rgb_data in _iter_test_images(test_img, test_img_dir):
        mask_spec = None
        mask_rgb = None

        if test_spectra:
            # spectra part
            img_data = spectra_loader(image_path)
            pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
            blk_predict_result = manual_tree.blk_predict(data=img_data)
            mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
            mask_spec = size_threshold(mask, Config.blk_size, spec_threshold)

        if test_rgb:
            # rgb part
            rgb_data = tobacco_detector.pretreatment(rgb_data)
            background = background_detector.predict(rgb_data)
            tobacco = tobacco_detector.predict(rgb_data)
            tobacco_d = tobacco_detector.swell(tobacco)
            rgb_predict_result = 1 - (background | tobacco_d)
            mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, rgb_threshold)
            _show_rgb_stages(rgb_data, background, tobacco, rgb_predict_result, mask_rgb)

        # Combine only the masks that were actually produced, so running a single
        # branch no longer fails on an undefined name. The size-thresholded spectral
        # mask is used because the result is upscaled by blk_size afterwards
        # (presumably both masks are on the block grid -- confirm against size_threshold).
        mask_result = None
        for branch_mask in (mask_spec, mask_rgb):
            if branch_mask is not None:
                mask_result = branch_mask if mask_result is None else (mask_result | branch_mask)
        if mask_result is not None:
            # Not returned or displayed; kept so the full pipeline is exercised.
            mask_result = mask_result.astype(np.uint8)
            mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1)

        print(f'rgb len = {len(rgb_data)}')
# Script entry point: run only the RGB branch over a local sample directory.
if "__main__" == __name__:
    pony_run(
        test_rgb=True,
        test_img_dir=r'E:\zhouchao\725data',
    )