mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
243 lines
11 KiB
Python
243 lines
11 KiB
Python
# -*- coding: utf-8 -*-
|
|
# Time : 2022/7/19 10:49
|
|
# @Author : zhouchao
|
|
# @File: main_test.py
|
|
# @Software:PyCharm
|
|
import datetime
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import time
|
|
import socket
|
|
import typing
|
|
from shutil import copyfile
|
|
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import numpy as np
|
|
from matplotlib.gridspec import GridSpec
|
|
import tqdm
|
|
|
|
import transmit
|
|
import utils
|
|
from config import Config
|
|
from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector
|
|
from utils import read_labeled_img, size_threshold, natural_sort
|
|
|
|
|
|
class TestMain:
    """
    Offline test harness for the spectral and RGB detectors.

    Instead of live camera data it reads buffers previously saved to disk, runs the
    detectors on them, and can optionally convert the buffers to viewable files or
    estimate the RGB/spectral misalignment.
    """

    def __init__(self):
        # Both detectors are configured from model paths held in ``Config``.
        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
                                           pixel_model_path=Config.pixel_model_path)
        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
                                         background_model_path=Config.rgb_background_model_path,
                                         ai_path=Config.ai_path)

    def pony_run(self, test_path, test_spectra=False, test_rgb=False,
                 convert_dir=None, get_delta=False, silent=False):
        """
        Virtual image-reading test: run the detectors on buffers saved to disk.

        :param test_path: a folder of captured buffers, or a single buffer file
        :param test_spectra: whether to run the spectral detector
        :param test_rgb: whether to run the RGB detector
        :param convert_dir: if not None, folder that receives converted copies of the buffers
        :param get_delta: whether to estimate the RGB/spectral offset (only used when both
            detectors are tested on a folder)
        :param silent: silent mode; in folder mode the merged-result figure is not shown and,
            when converting, a progress bar is displayed instead
        :return: None
        """
        if not os.path.isdir(test_path):
            # Single-file mode: only one detector runs, and spectra take precedence over RGB.
            if test_spectra:
                with open(test_path, 'rb') as f:
                    data = f.read()
                spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
                _ = self.test_spec(spec_img=spec_img, img_name=test_path)
            elif test_rgb:
                with open(test_path, 'rb') as f:
                    data = f.read()
                rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
                _ = self.test_rgb(rgb_img, img_name=test_path)
            return

        # Folder mode: files named "rgb<suffix>" and "spec<suffix>" are paired after natural sorting.
        rgb_file_names, spec_file_names = [
            natural_sort([file_name for file_name in os.listdir(test_path) if file_name.startswith(file_type)])
            for file_type in ['rgb', 'spec']]

        bar = None
        if silent and convert_dir is not None:
            bar = tqdm.tqdm(desc='正在转换', total=len(rgb_file_names), ncols=80)
        try:
            for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names):
                # Both files of a pair must share the same suffix, otherwise the pairing is off.
                assert rgb_file_name.replace('rgb', '') == spec_file_name.replace('spec', ''), \
                    f'{rgb_file_name} does not match {spec_file_name}'

                if test_spectra:
                    with open(os.path.join(test_path, spec_file_name), 'rb') as f:
                        data = f.read()
                    try:
                        spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
                    except Exception as e:
                        # Exception type kept for existing callers; the cause is chained for debugging.
                        raise FileExistsError(f'文件 {spec_file_name} 读取失败,长度不对. 请清理文件夹{test_path},'
                                              f'去除{spec_file_name}\n {e}') from e
                    if convert_dir is not None:
                        self._export_spec(spec_img, spec_file_name, test_path, convert_dir)
                    spec_mask = self.test_spec(spec_img, img_name=spec_file_name, show=False)

                if test_rgb:
                    with open(os.path.join(test_path, rgb_file_name), 'rb') as f:
                        data = f.read()
                    try:
                        rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
                    except Exception as e:
                        raise FileExistsError(f'文件 {rgb_file_name} 读取失败, 长度不对. 请清理文件夹{test_path},'
                                              f' 去除{rgb_file_name} \n {e}') from e
                    if convert_dir is not None:
                        # cv2.imwrite expects BGR channel order, hence the channel reversal.
                        cv2.imwrite(os.path.join(convert_dir, rgb_file_name + '.bmp'), rgb_img[..., ::-1])
                    rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name, show=False)

                if test_rgb and test_spectra:
                    if get_delta:
                        # Same 3-band preview as used for merging, scaled to uint8 for OpenCV.
                        spec_cv = (np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255).astype(np.uint8)
                        delta = self.calculate_delta(rgb_img, spec_cv)
                        print(delta)
                    self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
                               spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask,
                               rgb_file_name=rgb_file_name, spec_file_name=spec_file_name,
                               show=not silent)

                if bar is not None:
                    bar.update()
        finally:
            # Release the progress bar even if a file fails to load.
            if bar is not None:
                bar.close()

    @staticmethod
    def _export_spec(spec_img, spec_file_name, test_path, convert_dir):
        """Write a .bmp preview, a .raw copy of the buffer and a .hdr header for one spectral capture."""
        # Bands 21, 3 and 0 form a 3-channel preview; values presumably lie in [0, 1] before scaling.
        spec_img_show = np.asarray(np.clip(spec_img[..., [21, 3, 0]] * 255, a_min=0, a_max=255),
                                   dtype=np.uint8)
        # cv2.imwrite expects BGR channel order, hence the channel reversal.
        cv2.imwrite(os.path.join(convert_dir, spec_file_name + '.bmp'), spec_img_show[..., ::-1])
        hdr_string = utils.generate_hdr(f'File imported from {spec_file_name} at'
                                        f' {datetime.datetime.now().strftime("%m/%d/%Y, %H:%M")}')
        copyfile(os.path.join(test_path, spec_file_name), os.path.join(convert_dir, spec_file_name + '.raw'))
        with open(os.path.join(convert_dir, spec_file_name + '.hdr'), 'w') as f:
            f.write(hdr_string)

    def test_rgb(self, rgb_img, img_name, show=True):
        """
        Run the RGB detector on one image.

        :param rgb_img: image passed straight to the RGB detector's ``predict``
        :param img_name: name shown in the figure title
        :param show: if True, plot the image and its mask with ``plt.show()``
        :return: mask returned by the RGB detector
        """
        rgb_mask = self._rgb_detector.predict(rgb_img)
        if show:
            fig, axs = plt.subplots(2, 1)
            axs[0].imshow(rgb_img)
            axs[0].set_title(f"rgb img {img_name}")
            axs[1].imshow(rgb_mask)
            axs[1].set_title('rgb mask')
            plt.show()
        return rgb_mask

    def test_spec(self, spec_img, img_name, show=True):
        """
        Run the spectral detector on one spectral image.

        :param spec_img: spectral image passed straight to the spectral detector's ``predict``
        :param img_name: name shown in the figure title
        :param show: if True, plot bands 21/3/0 and the mask with ``plt.show()``
        :return: mask returned by the spectral detector
        """
        spec_mask = self._spec_detector.predict(spec_img)
        if show:
            fig, axs = plt.subplots(2, 1)
            axs[0].imshow(spec_img[..., [21, 3, 0]])
            axs[0].set_title(f"spec img {img_name}")
            axs[1].imshow(spec_mask)
            axs[1].set_title('spec mask')
            plt.show()
        return spec_mask

    @staticmethod
    def merge(rgb_img, rgb_mask, spec_img, spec_mask, rgb_file_name, spec_file_name, show=True):
        """
        OR the spectral and RGB masks together and upsample the result.

        Every entry of the merged mask is repeated ``Config.blk_size`` times along both
        axes, i.e. expanded into a ``blk_size`` x ``blk_size`` square.

        :param rgb_img: RGB image (display only)
        :param rgb_mask: RGB detector mask
        :param spec_img: 3-band preview of the spectral image (display only)
        :param spec_mask: spectral detector mask, combined with ``rgb_mask`` by bitwise OR
        :param rgb_file_name: title for the RGB image subplot
        :param spec_file_name: title for the spectral image subplot
        :param show: if True, draw a 5-panel summary figure with ``plt.show()``
        :return: merged, upsampled uint8 mask
        """
        mask_result = (spec_mask | rgb_mask).astype(np.uint8)
        mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
        if show:
            fig = plt.figure(constrained_layout=True)
            gs = GridSpec(3, 2, figure=fig)
            ax1 = fig.add_subplot(gs[0, 0])
            ax2 = fig.add_subplot(gs[0, 1])
            ax3 = fig.add_subplot(gs[1, 0])
            ax4 = fig.add_subplot(gs[1, 1])
            ax5 = fig.add_subplot(gs[2, :])

            fig.suptitle("Result")

            ax1.imshow(rgb_img)
            ax1.set_title(rgb_file_name)
            ax2.imshow(rgb_mask)
            ax2.set_title("rgb mask")

            ax3.imshow(np.clip(spec_img, a_min=0, a_max=1))
            ax3.set_title(spec_file_name)
            ax4.imshow(spec_mask)
            ax4.set_title('spec mask')

            ax5.imshow(mask_result)
            ax5.set_title('merge mask')
            plt.show()
        return mask_result

    def calculate_delta(self, rgb_img, spec_img, search_area_size=(400, 200), eps=1):
        """
        Estimate, plot and print the pixel offset between an RGB image and a spectral preview.

        Both images are converted to grey and Otsu-binarised, the spectral mask is resized
        to the RGB size, and the offset maximising their overlap is found with
        :meth:`search_delta`.

        :param rgb_img: 3-channel 8-bit image (required by COLOR_RGB2GRAY and Otsu thresholding)
        :param spec_img: 3-channel 8-bit preview of the spectral image
        :param search_area_size: (rows, cols) extent of the offset grid, centred on zero
        :param eps: step between tested offsets
        :return: (delta_x, delta_y) row/column offset
        """
        rgb_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)
        spec_grey = cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
        _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))

        delta = self.search_delta(rgb_bin, spec_bin, search_area_size=search_area_size, eps=eps)
        delta_x, delta_y = delta
        cross_part = (self.get_cross_area(rgb_bin, delta_x, delta_y)
                      & self.get_cross_area(spec_bin, -delta_x, -delta_y))
        human_word = self.describe_delta(delta_x, delta_y)

        fig, axs = plt.subplots(3, 1)
        axs[0].imshow(rgb_img)
        axs[0].set_title("RGB img")
        axs[1].imshow(spec_img)
        axs[1].set_title("spec img")
        axs[2].imshow(cross_part)
        axs[2].set_title("cross part")
        plt.suptitle(human_word)
        plt.show()

        print(human_word)
        return delta

    @classmethod
    def search_delta(cls, rgb_bin, spec_bin, search_area_size=(400, 200), eps=1):
        """
        Exhaustively search the (row, col) offset that maximises the overlap of two binary masks.

        Each candidate offset ``(dx, dy)`` lies on a grid of ``search_area_size`` centred on
        zero with step ``eps``. ``rgb_bin`` is cropped with ``(dx, dy)`` and ``spec_bin`` with
        ``(-dx, -dy)`` via :meth:`get_cross_area`, and the score is the sum of their bitwise AND.

        :param rgb_bin: 2-D binary mask
        :param spec_bin: 2-D binary mask with the same shape as ``rgb_bin``
        :param search_area_size: (rows, cols) size of the offset grid
        :param eps: grid step; untested grid cells keep a score of 0
        :return: (delta_x, delta_y) as ints; ties resolve to the first offset in row-major order
        """
        rows, cols = search_area_size
        half_rows, half_cols = rows // 2, cols // 2
        search_area = np.zeros(search_area_size)
        for x, y in itertools.product(range(0, rows, eps), range(0, cols, eps)):
            delta_x, delta_y = x - half_rows, y - half_cols
            overlap = cls.get_cross_area(rgb_bin, delta_x, delta_y) & cls.get_cross_area(spec_bin, -delta_x, -delta_y)
            search_area[x, y] = np.sum(overlap)
        best_x, best_y = np.unravel_index(np.argmax(search_area), search_area.shape)
        # Undo the centring with the matching axis: rows for x, columns for y.
        return int(best_x) - half_rows, int(best_y) - half_cols

    @staticmethod
    def describe_delta(delta_x, delta_y):
        """
        Return a human-readable sentence describing an offset found by :meth:`search_delta`.

        >>> TestMain.describe_delta(2, -1)
        'SPEC is 2 pixels after RGB and 1 pixels left the RGB'
        """
        vertical = 'after' if delta_x >= 0 else 'before'
        horizontal = 'right' if delta_y >= 0 else 'left'
        return (f"SPEC is {abs(delta_x)} pixels {vertical} "
                f"RGB and {abs(delta_y)} pixels {horizontal} the RGB")

    @staticmethod
    def get_cross_area(img_bin, delta_x, delta_y):
        """
        Crop ``img_bin`` for a (row, col) offset.

        A non-negative offset drops that many leading rows/columns; a negative offset drops
        that many trailing ones. Cropping one image with ``(dx, dy)`` and another of the same
        shape with ``(-dx, -dy)`` yields two equally-sized, overlapping views.
        """
        rows = slice(delta_x, None) if delta_x >= 0 else slice(None, delta_x)
        cols = slice(delta_y, None) if delta_y >= 0 else slice(None, delta_y)
        return img_bin[rows, cols]
|
|
|
|
|
|
if __name__ == '__main__':
    import argparse

    # Command-line entry point for the offline test harness.
    cli = argparse.ArgumentParser(description='Run image test or image_file')
    # NOTE: `path` is a required positional without nargs='?', so argparse never applies this default.
    cli.add_argument('path', default=r'E:\zhouchao\8.4\zazhi', help='测试文件或文件夹')
    # Both -test_* switches default to True; passing the flag turns that test OFF (store_false).
    cli.add_argument('-test_rgb', default=True, action='store_false', help='是否测试RGB图')
    cli.add_argument('-test_spec', default=True, action='store_false', help='是否测试光谱图')
    cli.add_argument('-get_delta', default=False, action='store_true', help='是否进行误差计算')
    cli.add_argument('-convert_dir', default=None, help='是否将c语言采集的buffer进行转换')
    cli.add_argument('-s', '--silent', default=False, action='store_true', help='是否显示')
    opts = cli.parse_args()

    # Validate (or create) the conversion output folder before any work starts.
    out_dir = opts.convert_dir
    if out_dir is not None:
        if not os.path.exists(out_dir):
            print(f"已创建需要存放转换文件的文件夹 {out_dir}")
            os.makedirs(out_dir, mode=0o777, exist_ok=False)
        elif not os.path.isdir(out_dir):
            raise TypeError("转换文件夹需要是个文件夹")
        elif os.path.abspath(opts.path) == os.path.abspath(out_dir):
            # Allowed, but converted files will land next to the source buffers.
            print("警告!您的输出文件夹和输入文件夹用了同一个位置")

    TestMain().pony_run(test_path=opts.path, test_rgb=opts.test_rgb, test_spectra=opts.test_spec,
                        get_delta=opts.get_delta, convert_dir=out_dir, silent=opts.silent)
|