mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
83 lines
3.5 KiB
Python
83 lines
3.5 KiB
Python
# -*- coding: utf-8 -*-
|
||
# Time : 2022/7/19 10:49
|
||
# @Author : zhouchao
|
||
# @File: main_test.py
|
||
# @Software:PyCharm
|
||
import os
|
||
import time
|
||
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
|
||
from config import Config
|
||
from models import Detector, AnonymousColorDetector, ManualTree
|
||
from utils import read_labeled_img, size_threshold
|
||
|
||
|
||
def _read_rgb(path):
    """Read the image at *path* with OpenCV and return it in RGB channel order.

    :param path: path of the image file to read
    :return: the image array with its last axis reversed (BGR -> RGB)
    :raises OSError: if OpenCV cannot read or decode the file
    """
    bgr = cv2.imread(path)
    if bgr is None:
        # cv2.imread reports an unreadable file by returning None, not by raising
        raise OSError(f"cannot read image: {path}")
    return bgr[..., ::-1]


def _show_stages(stages):
    """Draw every (title, image) pair in its own row of one figure and show it.

    :param stages: sequence of (title, image) tuples, plotted top to bottom
    """
    _, axs = plt.subplots(len(stages), 1, figsize=(12, 10), constrained_layout=True)
    for ax, (title, image) in zip(axs, stages):
        ax.imshow(image)
        ax.set_title(title)
    plt.show()


def pony_run(test_img=None, test_img_dir=None, test_spectra=False, test_rgb=False):
    """
    Offline test routine that reads images from disk.

    Every ``.png`` file in ``test_img_dir`` is read; when ``test_rgb`` is set
    the RGB detection stages are computed and displayed for each of them.

    :param test_img: test image, either an RGB ``np.ndarray`` or a path to an image.
        NOTE(review): as in the original code, this image is only validated and
        loaded; it is not fed into the detection pipeline - confirm intent.
    :param test_img_dir: directory holding the ``.png`` test images
    :param test_spectra: whether to test the spectral detector
    :param test_rgb: whether to test the RGB detector
    :return: None
    :raises TypeError: if ``test_img`` is neither ``np.ndarray`` nor ``str``
    :raises NotImplementedError: if ``test_spectra`` is requested together with
        ``test_img_dir`` (no spectral data source exists in this routine)
    :raises OSError: if an image file cannot be read
    """
    if test_img is None and test_img_dir is None:
        return

    # Validate cheap inputs before any model is loaded from disk.
    if test_img is not None and not isinstance(test_img, (str, np.ndarray)):
        raise TypeError("test img should be np.ndarray or str")
    if test_spectra and test_img_dir is not None:
        # The spectral branch used to call ManualTree.pixel_predict_ml_dilation
        # and ManualTree.blk_predict on ``img_data``, a name that was never
        # defined, so it always died with a NameError. Only RGB images are read
        # here, so fail early and clearly instead.
        raise NotImplementedError(
            "spectral testing needs spectral input data, which pony_run does not load")

    rgb_threshold = Config.rgb_size_threshold
    tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
    background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)

    if isinstance(test_img, str):
        # Kept from the original: the single image is loaded (which verifies the
        # path is readable) but not processed any further.
        _read_rgb(test_img)

    if test_img_dir is None:
        return

    image_names = [img_name for img_name in os.listdir(test_img_dir) if img_name.endswith('.png')]
    for image_name in image_names:
        rgb_data = _read_rgb(os.path.join(test_img_dir, image_name))
        if test_rgb:
            # rgb part
            rgb_data = tobacco_detector.pretreatment(rgb_data)
            background = background_detector.predict(rgb_data)
            tobacco = tobacco_detector.predict(rgb_data)
            tobacco_d = tobacco_detector.swell(tobacco)
            # Whatever is neither background nor (swollen) tobacco is flagged.
            rgb_predict_result = 1 - (background | tobacco_d)
            mask_rgb = size_threshold(rgb_predict_result, Config.blk_size, rgb_threshold)
            _show_stages([
                ("rgb raw data", rgb_data),
                ("background", background),
                ("tobacco", tobacco),
                ("1 - (background + dilate(tobacco))", rgb_predict_result),
                ("final mask", mask_rgb),
            ])
            # The original OR-ed this with ``mask``, which only existed in the
            # spectral branch, so the RGB-only run raised NameError right here.
            # Each mask cell is repeated blk_size times along both axes
            # (presumably back to pixel resolution - TODO confirm). The result
            # is not consumed yet, exactly as before.
            mask_result = mask_rgb.astype(np.uint8)
            mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1)
        print(f'rgb len = {len(rgb_data)}')
|
||
|
||
|
||
if __name__ == '__main__':
    # Script entry point: run the RGB-only test over the local sample directory.
    sample_dir = r'E:\zhouchao\725data'
    pony_run(test_img_dir=sample_dir, test_rgb=True)
|