mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
294 lines
12 KiB
Python
294 lines
12 KiB
Python
# -*- codeing = utf-8 -*-
|
||
# Time : 2022/7/19 10:49
|
||
# @Auther : zhouchao
|
||
# @File: main_test.py
|
||
# @Software:PyCharm
|
||
import itertools
|
||
import logging
|
||
import os
|
||
import time
|
||
import socket
|
||
import typing
|
||
|
||
import cv2
|
||
import matplotlib.pyplot as plt
|
||
import numpy as np
|
||
|
||
import transmit
|
||
from config import Config
|
||
from models import Detector, AnonymousColorDetector, ManualTree, SpecDetector, RgbDetector
|
||
from utils import read_labeled_img, size_threshold, natural_sort
|
||
|
||
|
||
class TestMain:
|
||
def __init__(self):
|
||
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||
pixel_model_path=Config.pixel_model_path)
|
||
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||
background_model_path=Config.rgb_background_model_path)
|
||
|
||
def pony_run(self, test_path, test_spectra=False, test_rgb=False,
|
||
convert=False, get_delta=False):
|
||
"""
|
||
虚拟读图测试程序
|
||
|
||
:param test_path: 测试文件夹或者图片
|
||
:param test_spectra: 是否测试光谱
|
||
:param test_rgb: 是否测试rgb
|
||
:param convert: 是否进行格式转化
|
||
:return:
|
||
"""
|
||
if os.path.isdir(test_path):
|
||
rgb_file_names, spec_file_names = [[file_name for file_name in os.listdir(test_path) if
|
||
file_name.startswith(file_type)] for file_type in ['rgb', 'spec']]
|
||
rgb_file_names, spec_file_names = natural_sort(rgb_file_names), natural_sort(spec_file_names)
|
||
else:
|
||
if test_spectra:
|
||
with open(test_path, 'rb') as f:
|
||
data = f.read()
|
||
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
|
||
_ = self.test_spec(spec_img=spec_img, img_name=test_path)
|
||
elif test_rgb:
|
||
with open(test_path, 'rb') as f:
|
||
data = f.read()
|
||
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
|
||
_ = self.test_rgb(rgb_img, img_name=test_path)
|
||
return
|
||
for rgb_file_name, spec_file_name in zip(rgb_file_names, spec_file_names):
|
||
if test_spectra:
|
||
with open(os.path.join(test_path, spec_file_name), 'rb') as f:
|
||
data = f.read()
|
||
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
|
||
spec_mask = self.test_spec(spec_img, img_name=spec_file_name)
|
||
if test_rgb:
|
||
with open(os.path.join(test_path, rgb_file_name), 'rb') as f:
|
||
data = f.read()
|
||
rgb_img = transmit.BeforeAfterMethods.rgb_data_post_process(data)
|
||
rgb_mask = self.test_rgb(rgb_img, img_name=rgb_file_name)
|
||
if test_rgb and test_spectra:
|
||
if get_delta:
|
||
spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
|
||
spec_cv = spec_cv.astype(np.uint8)
|
||
delta = self.calculate_delta(rgb_img, spec_cv)
|
||
print(delta)
|
||
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
|
||
spec_img=spec_img[..., [21, 3, 0]], spec_mask=spec_mask,
|
||
rgb_file_name=rgb_file_name, spec_file_name=spec_file_name)
|
||
|
||
def test_rgb(self, rgb_img, img_name):
|
||
rgb_mask = self._rgb_detector.predict(rgb_img)
|
||
fig, axs = plt.subplots(2, 1)
|
||
axs[0].imshow(rgb_img)
|
||
axs[0].set_title(f"rgb img {img_name}")
|
||
axs[1].imshow(rgb_mask)
|
||
axs[1].set_title('rgb mask')
|
||
plt.show()
|
||
return rgb_mask
|
||
|
||
def test_spec(self, spec_img, img_name):
|
||
spec_mask = self._spec_detector.predict(spec_img)
|
||
fig, axs = plt.subplots(2, 1)
|
||
axs[0].imshow(spec_img[..., [21, 3, 0]])
|
||
axs[0].set_title(f"spec img {img_name}")
|
||
axs[1].imshow(spec_mask)
|
||
axs[1].set_title('spec mask')
|
||
plt.show()
|
||
return spec_mask
|
||
|
||
@staticmethod
|
||
def merge(rgb_img, rgb_mask, spec_img, spec_mask, rgb_file_name, spec_file_name):
|
||
mask_result = (spec_mask | rgb_mask).astype(np.uint8)
|
||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||
fig, axs = plt.subplots(3, 2)
|
||
axs[0, 0].set_title(rgb_file_name)
|
||
axs[0, 0].imshow(rgb_img)
|
||
axs[1, 0].imshow(spec_img)
|
||
axs[1, 0].set_title(spec_file_name)
|
||
axs[2, 0].imshow(mask_result)
|
||
axs[0, 1].imshow(rgb_mask)
|
||
axs[1, 1].imshow(spec_mask)
|
||
axs[2, 1].imshow(mask_result)
|
||
plt.show()
|
||
return mask_result
|
||
|
||
def calculate_delta(self, rgb_img, spec_img, search_area_size=(400, 200), eps=1):
|
||
rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
|
||
_, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||
_, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||
spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))
|
||
search_area = np.zeros(search_area_size)
|
||
for x in range(0, search_area_size[0], eps):
|
||
for y in range(0, search_area_size[1], eps):
|
||
delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2
|
||
rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
|
||
spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
|
||
response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area))
|
||
search_area[x, y] = response_altitude
|
||
delta = np.unravel_index(np.argmax(search_area), search_area.shape)
|
||
delta = (delta[0] - search_area_size[1]//2, delta[1] - search_area_size[1]//2)
|
||
delta_x, delta_y = delta
|
||
|
||
rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
|
||
spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
|
||
|
||
human_word = "SPEC is " + str(abs(delta_x)) + " pixels "
|
||
human_word += 'after' if delta_x >= 0 else ' before '
|
||
human_word += "RGB and " + str(abs(delta_y)) + " pixels "
|
||
human_word += "right " if delta_y >= 0 else "left "
|
||
human_word += "the RGB"
|
||
|
||
fig, axs = plt.subplots(3, 1)
|
||
axs[0].imshow(rgb_img)
|
||
axs[0].set_title("RGB img")
|
||
axs[1].imshow(spec_img)
|
||
axs[1].set_title("spec img")
|
||
axs[2].imshow(rgb_cross_area & spce_cross_area)
|
||
axs[2].set_title("cross part")
|
||
plt.suptitle(human_word)
|
||
plt.show()
|
||
|
||
print(human_word)
|
||
return delta
|
||
|
||
@staticmethod
|
||
def get_cross_area(img_bin, delta_x, delta_y):
|
||
if delta_x >= 0:
|
||
cross_area = img_bin[delta_x:, :]
|
||
else:
|
||
cross_area = img_bin[:delta_x, :]
|
||
if delta_y >= 0:
|
||
cross_area = cross_area[:, delta_y:]
|
||
else:
|
||
cross_area = cross_area[:, :delta_y]
|
||
return cross_area
|
||
|
||
|
||
class ValveTest:
|
||
def __init__(self):
|
||
self.reminder = """
|
||
快,给我个指令:
|
||
a. 开始命令 st e. 设置 光谱(a)相机 的延时,格式 e,500
|
||
b. 停止命令 sp f. 设置 彩色(b)相机 的延时, 格式 f,500
|
||
c. 设置光谱相机分频, 得是4的倍数而且>=8, 格式: c,8 g. 发个da和db完全重叠的mask
|
||
d. 阀板的脉冲分频系数, >=2即可 h. 发个da和db呈现出X形的mask
|
||
或者你给我个小于256的数字,我就测试对应的喷阀,按q键我就退出。\n
|
||
"""
|
||
self.s = socket.socket() # 创建 socket 对象
|
||
host = socket.gethostname() # 获取本地主机名
|
||
port = 13452 # 设置端口
|
||
self.s.bind((host, port)) # 绑定端口
|
||
self.s.listen(5) # 等待客户端连接
|
||
|
||
def run(self):
|
||
while True:
|
||
logging.info("我在等连接...")
|
||
c, addr = self.s.accept() # 建立客户端连接
|
||
print('Connection Established:', addr)
|
||
value = input(self.reminder)
|
||
if value == 'q':
|
||
break
|
||
else:
|
||
self.process_cmd(value)
|
||
c.close() # 关闭连接
|
||
|
||
def pad_cmd(self, cmd):
|
||
return b'\xAA'+cmd+b'\xFF\xFF\xBB'
|
||
|
||
def param_cmd_parser(self, cmd, default_value, checker=None):
|
||
try:
|
||
value = int(cmd.split(',')[-1])
|
||
except:
|
||
print(f'你给的值不对啊,我先给你弄个{default_value}吧')
|
||
value = default_value
|
||
if checker is not None:
|
||
if not checker(value):
|
||
return None
|
||
return value
|
||
|
||
def process_cmd(self, value):
|
||
if value == 'a':
|
||
# a.开始命令
|
||
cmd = b'\x00\x03' + 'sa'.encode('ascii') + b'\xFF'
|
||
elif value == 'b':
|
||
# b.停止命令
|
||
cmd = b'\x00\x03' + 'sb'.encode('ascii') + b'\xFF'
|
||
elif value.startswith('c'):
|
||
# c. 设置光谱相机分频,得是4的倍数而且>=8,格式:c,8
|
||
checker = lambda x: (x // 4 == 0) and (x >= 8)
|
||
value = self.param_cmd_parser(value, default_value=8, checker=checker)
|
||
if value is None:
|
||
print("值需要是4的倍数且大于8")
|
||
return
|
||
cmd = b'\x00\x0a' + 'sc'.encode('ascii') + f"{value:08d}".encode('ascii')
|
||
elif value.startswith('d'):
|
||
# d. 阀板的脉冲分频系数,>=2即可
|
||
checker = lambda x: x >= 2
|
||
value = self.param_cmd_parser(value, default_value=2, checker=checker)
|
||
if value is None:
|
||
print("你得大于等于2")
|
||
return
|
||
cmd = b'\x00\x0a' + 'sv'.encode('ascii') + f"{value:08d}".encode('ascii')
|
||
elif value.startswith('e'):
|
||
# e. 设置 光谱(a)相机 的延时,格式 e,500
|
||
checker = lambda x: (x >= 0)
|
||
value = self.param_cmd_parser(value, default_value=2, checker=checker)
|
||
if value is None:
|
||
print("你得大于等于0")
|
||
return
|
||
cmd = b'\x00\x0a' + 'sa'.encode('ascii') + f"{value:08d}".encode('ascii')
|
||
elif value.startswith('f'):
|
||
# f. 设置 RGB(b)相机 的延时,格式 e,500
|
||
checker = lambda x: (x >= 0)
|
||
value = self.param_cmd_parser(value, default_value=2, checker=checker)
|
||
if value is None:
|
||
print("你得大于等于0")
|
||
return
|
||
cmd = b'\x00\x0a' + 'sb'.encode('ascii') + f"{value:08d}".encode('ascii')
|
||
elif value == 'g':
|
||
# g.发个da和db完全重叠的mask
|
||
mask_a, mask_b = np.eye(256, dtype=np.uint8), np.eye(256, dtype=np.uint8)
|
||
len_a, data_a = self.format_data(mask_a)
|
||
len_b, data_b = self.format_data(mask_b)
|
||
cmd = len_a + 'da'.encode('ascii') + mask_a
|
||
self.send(cmd)
|
||
cmd = len_b + 'db'.encode('ascii') + mask_b
|
||
elif value == 'h':
|
||
# h.发个da和db呈现出X形的mask
|
||
mask_a, mask_b = np.eye(256, dtype=np.uint8), np.eye(256, dtype=np.uint8).T
|
||
len_a, data_a = self.format_data(mask_a)
|
||
len_b, data_b = self.format_data(mask_b)
|
||
cmd = len_a + 'da'.encode('ascii') + mask_a
|
||
self.send(cmd)
|
||
cmd = len_b + 'db'.encode('ascii') + mask_b
|
||
else:
|
||
try:
|
||
value = int(value)
|
||
except Exception as e:
|
||
print(e)
|
||
print(f"你给的指令: {value} 咋看都不对")
|
||
return
|
||
self.send(cmd)
|
||
|
||
def send(self, cmd: bytes) -> None:
|
||
print("我要send 这个了:\n")
|
||
print(cmd)
|
||
cmd = self.pad_cmd(cmd)
|
||
try:
|
||
self.s.send(cmd)
|
||
except Exception as e:
|
||
print(f"发失败了, 这是我找到的错误信息\n{e}")
|
||
return
|
||
print("发好了")
|
||
|
||
def format_data(self, array_to_send: np.ndarray) -> (bytes, bytes):
|
||
data = np.packbits(array_to_send, axis=-1)
|
||
data = data.tobytes()
|
||
data_len = (len(data)+2).to_bytes(2, 'big')
|
||
return data_len, data
|
||
|
||
|
||
if __name__ == '__main__':
|
||
testor = TestMain()
|
||
testor.pony_run(test_path=r'/home/lzy/2022.7.30/tobacco_v1_0/saved_img/',
|
||
test_rgb=True, test_spectra=True, get_delta=False)
|