mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Merge branch 'master' of https://github.com/FEIJINTI/tobacco_color
This commit is contained in:
commit
cdbe341591
@ -20,7 +20,7 @@ class Config:
|
||||
# 光谱模型参数
|
||||
blk_size = 4
|
||||
pixel_model_path = r"./models/dt.p"
|
||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
|
||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
|
||||
spec_size_threshold = 3
|
||||
|
||||
# rgb模型参数
|
||||
|
||||
51
efficient_ui.py
Normal file
51
efficient_ui.py
Normal file
@ -0,0 +1,51 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
import transmit
|
||||
from config import Config
|
||||
|
||||
from utils import ImgQueue as Queue
|
||||
|
||||
|
||||
class EfficientUI(object):
|
||||
def __init__(self):
|
||||
# 相关参数
|
||||
img_fifo_path = "/tmp/dkimg.fifo"
|
||||
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||
# 创建队列用于链接各个线程
|
||||
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
||||
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
|
||||
mask_queue = Queue()
|
||||
# 两个接收者,接收光谱和rgb图像
|
||||
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||
spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
|
||||
rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
|
||||
# 指令执行与图像流向控制
|
||||
subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
|
||||
cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
|
||||
subscribers=subscribers)
|
||||
# 探测器
|
||||
detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
|
||||
# 发送
|
||||
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
||||
# 启动所有线程
|
||||
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
|
||||
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
|
||||
cmd_img_controller.start(name='control_thread')
|
||||
detector.start(name='detector_thread')
|
||||
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
|
||||
|
||||
def start(self):
|
||||
# 启动图形化
|
||||
while True:
|
||||
cv2.imshow('image_show', mat=np.ones((256, 1024)))
|
||||
key_code = cv2.waitKey(30)
|
||||
if key_code == ord("s"):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app = EfficientUI()
|
||||
app.start()
|
||||
142
main.py
142
main.py
@ -1,16 +1,18 @@
|
||||
import os
|
||||
import time
|
||||
from queue import Queue
|
||||
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
import models
|
||||
import transmit
|
||||
|
||||
from config import Config
|
||||
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
|
||||
from models import RgbDetector, SpecDetector
|
||||
import cv2
|
||||
|
||||
|
||||
SAVE_IMG, SAVE_NUM = False, 30
|
||||
|
||||
|
||||
def main():
|
||||
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||
@ -23,8 +25,6 @@ def main():
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
if not os.access(rgb_fifo_path, os.F_OK):
|
||||
os.mkfifo(rgb_fifo_path, 0o777)
|
||||
if SAVE_IMG:
|
||||
img_list = []
|
||||
while True:
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
@ -55,19 +55,12 @@ def main():
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
|
||||
.transpose(0, 2, 1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
if SAVE_IMG:
|
||||
SAVE_NUM -= 1
|
||||
img_list.append((rgb_data, img_data))
|
||||
if SAVE_NUM <= 0:
|
||||
break
|
||||
# 光谱识别
|
||||
mask = spec_detector.predict(img_data)
|
||||
# rgb识别
|
||||
mask_rgb = rgb_detector.predict(rgb_data)
|
||||
|
||||
# 结果合并
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
|
||||
# mask_result = mask_rgb.astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
@ -79,100 +72,37 @@ def main():
|
||||
os.close(fd_mask)
|
||||
t3 = time.time()
|
||||
print(f'total time is:{t3 - t1}')
|
||||
for i, img in enumerate(img_list):
|
||||
print(f"writing img {i}...")
|
||||
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
|
||||
np.save(f'./{i}.npy', img[1])
|
||||
i += 1
|
||||
|
||||
|
||||
|
||||
def save_main():
|
||||
threshold = Config.spec_size_threshold
|
||||
rgb_threshold = Config.rgb_size_threshold
|
||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
||||
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||
if not os.access(img_fifo_path, os.F_OK):
|
||||
os.mkfifo(img_fifo_path, 0o777)
|
||||
if not os.access(mask_fifo_path, os.F_OK):
|
||||
os.mkfifo(mask_fifo_path, 0o777)
|
||||
if not os.access(rgb_fifo_path, os.F_OK):
|
||||
os.mkfifo(rgb_fifo_path, 0o777)
|
||||
img_list = []
|
||||
idx = 0
|
||||
while idx <= 30:
|
||||
idx += 1
|
||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||
data = os.read(fd_img, total_len)
|
||||
|
||||
# 读取(开启一个管道)
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get threshold: ", threshold)
|
||||
continue
|
||||
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
|
||||
if os.path.isdir(buffer_path):
|
||||
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
|
||||
else:
|
||||
buffer_names = [buffer_path, ]
|
||||
for buffer_name in buffer_names:
|
||||
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
|
||||
data = f.read()
|
||||
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
|
||||
.transpose(0, 2, 1)
|
||||
if selected_bands is not None:
|
||||
img = img[..., selected_bands]
|
||||
if img.shape[0] == 1:
|
||||
img = img[0, ...]
|
||||
if not no_mask:
|
||||
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
|
||||
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
|
||||
data = f.read()
|
||||
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
|
||||
else:
|
||||
data_total = data
|
||||
rgb_data = os.read(fd_rgb, total_rgb)
|
||||
if len(rgb_data) < 3:
|
||||
rgb_threshold = int(float(rgb_data))
|
||||
print(rgb_threshold)
|
||||
continue
|
||||
else:
|
||||
rgb_data_total = rgb_data
|
||||
os.close(fd_img)
|
||||
os.close(fd_rgb)
|
||||
|
||||
# 识别
|
||||
t1 = time.time()
|
||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
|
||||
transpose(0, 2, 1)
|
||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
img_list.append((rgb_data.copy(), img_data.copy()))
|
||||
|
||||
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = manual_tree.blk_predict(data=img_data)
|
||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
||||
# print(rgb_data.shape)
|
||||
rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
|
||||
threshold_high=Config.threshold_high) |
|
||||
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
|
||||
threshold_low=Config.threshold_low,
|
||||
threshold_high=Config.threshold_high)))
|
||||
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
mask_rgb[mask_rgb <= rgb_threshold] = 0
|
||||
mask_rgb[mask_rgb > rgb_threshold] = 1
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
||||
.sum(axis=1)
|
||||
mask[mask <= threshold] = 0
|
||||
mask[mask > threshold] = 1
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
# mask_result = mask_rgb
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t2 = time.time()
|
||||
print(f'rgb len = {len(rgb_data)}')
|
||||
|
||||
# 写出
|
||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||
os.write(fd_mask, mask_result.tobytes())
|
||||
os.close(fd_mask)
|
||||
t3 = time.time()
|
||||
print(f'total time is:{t3 - t1}')
|
||||
i = 0
|
||||
print("Stop Serving")
|
||||
for img in img_list:
|
||||
print(f"writing img {i}...")
|
||||
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
|
||||
np.save(f'./{i}.npy', img[1])
|
||||
i += 1
|
||||
print("save success")
|
||||
mask_name = "no mask"
|
||||
mask = np.zeros_like(img)
|
||||
# mask = cv2.resize(mask, (1024, 256))
|
||||
fig, axs = plt.subplots(2, 1)
|
||||
axs[0].matshow(img)
|
||||
axs[0].set_title(buffer_name)
|
||||
axs[1].imshow(mask)
|
||||
axs[1].set_title(mask_name)
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
@ -183,3 +113,5 @@ if __name__ == '__main__':
|
||||
|
||||
# 主函数
|
||||
main()
|
||||
# read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
|
||||
# selected_bands=[380, 300, 200])
|
||||
|
||||
@ -5,10 +5,12 @@
|
||||
# @Software:PyCharm、
|
||||
import datetime
|
||||
import pickle
|
||||
from queue import Queue
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import threading
|
||||
from scipy.ndimage import binary_dilation
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
@ -17,6 +19,7 @@ from sklearn.model_selection import train_test_split
|
||||
from config import Config
|
||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||
|
||||
|
||||
deploy = True
|
||||
if not deploy:
|
||||
print("Training env")
|
||||
@ -415,7 +418,7 @@ class SpecDetector(Detector):
|
||||
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
|
||||
yellow_things[yellow_things] = tobacco
|
||||
yellow_things = yellow_things + 0
|
||||
# yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
||||
yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
||||
yellow_things = yellow_things + 0
|
||||
yellow_things[yellow_things == 1] = 2
|
||||
|
||||
|
||||
236
transmit.py
236
transmit.py
@ -1,11 +1,18 @@
|
||||
import os
|
||||
import threading
|
||||
from queue import Queue
|
||||
import time
|
||||
|
||||
from utils import ImgQueue as Queue
|
||||
import numpy as np
|
||||
from config import Config
|
||||
from models import SpecDetector, RgbDetector
|
||||
import typing
|
||||
import logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||
level=logging.DEBUG)
|
||||
|
||||
|
||||
class Receiver(object):
|
||||
class Transmitter(object):
|
||||
def __init__(self):
|
||||
self.output = None
|
||||
|
||||
@ -45,17 +52,49 @@ class Receiver(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FifoReceiver(Receiver):
|
||||
def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
|
||||
class BeforeAfterMethods:
|
||||
@classmethod
|
||||
def mask_preprocess(cls, mask: np.ndarray):
|
||||
logging.info(f"Send mask with size {mask.shape}")
|
||||
return mask.tobytes()
|
||||
|
||||
@classmethod
|
||||
def spec_data_post_process(cls, data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
logging.info(f"Get Spec threshold: {threshold}")
|
||||
return threshold
|
||||
else:
|
||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||
logging.info(f"Get SPEC image with size {spec_img.shape}")
|
||||
return spec_img
|
||||
|
||||
@classmethod
|
||||
def rgb_data_post_process(cls, data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
logging.info(f"Get RGB threshold: {threshold}")
|
||||
return threshold
|
||||
else:
|
||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
logging.info(f"Get RGB img with size {rgb_img.shape}")
|
||||
return rgb_img
|
||||
|
||||
|
||||
class FifoReceiver(Transmitter):
|
||||
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
|
||||
super().__init__()
|
||||
self._input_fifo_path = None
|
||||
self._output_queue = None
|
||||
self._msg_queue = msg_queue
|
||||
self._max_len = read_max_num
|
||||
|
||||
self.set_source(fifo_path)
|
||||
self.set_output(output)
|
||||
self._need_stop = threading.Event()
|
||||
self._need_stop.clear()
|
||||
self._running_thread = None
|
||||
|
||||
def set_source(self, fifo_path: str):
|
||||
if not os.access(fifo_path, os.F_OK):
|
||||
@ -66,9 +105,9 @@ class FifoReceiver(Receiver):
|
||||
self._output_queue = output
|
||||
|
||||
def start(self, post_process_func=None, name='fifo_receiver'):
|
||||
x = threading.Thread(target=self._receive_thread_func,
|
||||
name=name, args=(post_process_func, ))
|
||||
x.start()
|
||||
self._running_thread = threading.Thread(target=self._receive_thread_func,
|
||||
name=name, args=(post_process_func, ))
|
||||
self._running_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._need_stop.set()
|
||||
@ -85,27 +124,170 @@ class FifoReceiver(Receiver):
|
||||
data = os.read(input_fifo, self._max_len)
|
||||
if post_process_func is not None:
|
||||
data = post_process_func(data)
|
||||
self._output_queue.put(data)
|
||||
self._output_queue.safe_put(data)
|
||||
os.close(input_fifo)
|
||||
self._need_stop.clear()
|
||||
|
||||
@staticmethod
|
||||
def spec_data_post_process(data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get Spec threshold: ", threshold)
|
||||
return threshold
|
||||
else:
|
||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||
return spec_img
|
||||
|
||||
@staticmethod
|
||||
def rgb_data_post_process(data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get RGB threshold: ", threshold)
|
||||
return threshold
|
||||
else:
|
||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
return rgb_img
|
||||
class FifoSender(Transmitter):
|
||||
def __init__(self, output_fifo_path: str, source: Queue):
|
||||
super().__init__()
|
||||
self._input_source = None
|
||||
self._output_fifo_path = None
|
||||
self.set_source(source)
|
||||
self.set_output(output_fifo_path)
|
||||
self._need_stop = threading.Event()
|
||||
self._need_stop.clear()
|
||||
self._running_thread = None
|
||||
|
||||
def set_source(self, source: Queue):
|
||||
self._input_source = source
|
||||
|
||||
def set_output(self, output_fifo_path: str):
|
||||
if not os.access(output_fifo_path, os.F_OK):
|
||||
os.mkfifo(output_fifo_path, 0o777)
|
||||
self._output_fifo_path = output_fifo_path
|
||||
|
||||
def start(self, pre_process=None, name='fifo_receiver'):
|
||||
self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
|
||||
args=(pre_process, ))
|
||||
self._running_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._need_stop.set()
|
||||
|
||||
def _send_thread_func(self, pre_process=None):
|
||||
"""
|
||||
接收线程
|
||||
|
||||
:param pre_process:
|
||||
:return:
|
||||
"""
|
||||
while not self._need_stop.is_set():
|
||||
if self._input_source.empty():
|
||||
continue
|
||||
data = self._input_source.get()
|
||||
if pre_process is not None:
|
||||
data = pre_process(data)
|
||||
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||
os.write(output_fifo, data)
|
||||
os.close(output_fifo)
|
||||
self._need_stop.clear()
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
if self._running_thread is not None:
|
||||
self._running_thread.join()
|
||||
|
||||
|
||||
class CmdImgSplitMidware(Transmitter):
|
||||
"""
|
||||
用于控制命令和图像的中间件
|
||||
"""
|
||||
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
|
||||
super().__init__()
|
||||
self._rgb_queue = None
|
||||
self._spec_queue = None
|
||||
self._subscribers = None
|
||||
self._server_thread = None
|
||||
self.set_source(rgb_queue, spec_queue)
|
||||
self.set_output(subscribers)
|
||||
self.thread_stop = threading.Event()
|
||||
|
||||
def set_source(self, rgb_queue: Queue, spec_queue: Queue):
|
||||
self._rgb_queue = rgb_queue
|
||||
self._spec_queue = spec_queue
|
||||
|
||||
def set_output(self, output: typing.Dict[str, Queue]):
|
||||
self._subscribers = output
|
||||
|
||||
def start(self, name='CMD_thread'):
|
||||
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
|
||||
self._server_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self.thread_stop.set()
|
||||
|
||||
def _cmd_control_service(self):
|
||||
while not self.thread_stop.is_set():
|
||||
# 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
|
||||
if self._rgb_queue.empty() or self._spec_queue.empty():
|
||||
continue
|
||||
rgb_data = self._rgb_queue.get()
|
||||
spec_data = self._spec_queue.get()
|
||||
if isinstance(rgb_data, int) and isinstance(spec_data, int):
|
||||
# 看是不是命令需要执行如果是命令,就执行
|
||||
Config.rgb_size_threshold = rgb_data
|
||||
Config.spec_size_threshold = spec_data
|
||||
continue
|
||||
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||
# 如果是图片,交给预测的人
|
||||
for name, subscriber in self._subscribers.items():
|
||||
item = (spec_data, rgb_data)
|
||||
subscriber.safe_put(item)
|
||||
else:
|
||||
# 否则程序出现毁灭性问题,立刻崩
|
||||
raise Exception("两个相机传回的数据没有对上")
|
||||
self.thread_stop.clear()
|
||||
|
||||
|
||||
class ImageSaver(Transmitter):
|
||||
"""
|
||||
进行图片存储的中间件
|
||||
"""
|
||||
def set_source(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
class ThreadDetector(Transmitter):
|
||||
def __init__(self, input_queue: Queue, output_queue: Queue):
|
||||
super().__init__()
|
||||
self._input_queue, self._output_queue = input_queue, output_queue
|
||||
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||||
pixel_model_path=Config.pixel_model_path)
|
||||
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||
background_model_path=Config.rgb_background_model_path)
|
||||
self._predict_thread = None
|
||||
self._thread_exit = threading.Event()
|
||||
|
||||
def set_source(self, img_queue: Queue):
|
||||
self._input_queue = img_queue
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
self._thread_exit.set()
|
||||
|
||||
def start(self, name='predict_thread'):
|
||||
self._predict_thread = threading.Thread(target=self._predict_server, name=name)
|
||||
self._predict_thread.start()
|
||||
|
||||
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||
t1 = time.time()
|
||||
mask = self._spec_detector.predict(spec)
|
||||
t2 = time.time()
|
||||
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||
# rgb识别
|
||||
mask_rgb = self._rgb_detector.predict(rgb)
|
||||
t3 = time.time()
|
||||
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||
# 结果合并
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t4 = time.time()
|
||||
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||
return mask_result
|
||||
|
||||
def _predict_server(self):
|
||||
while not self._thread_exit.is_set():
|
||||
if not self._input_queue.empty():
|
||||
spec, rgb = self._input_queue.get()
|
||||
mask = self.predict(spec, rgb)
|
||||
self._output_queue.safe_put(mask)
|
||||
self._thread_exit.clear()
|
||||
|
||||
29
utils.py
29
utils.py
@ -5,6 +5,7 @@
|
||||
# @Software:PyCharm
|
||||
import glob
|
||||
import os
|
||||
from queue import Queue
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -26,6 +27,34 @@ class MergeDict(dict):
|
||||
return self
|
||||
|
||||
|
||||
class ImgQueue(Queue):
|
||||
"""
|
||||
A custom queue subclass that provides a :meth:`clear` method.
|
||||
"""
|
||||
|
||||
def clear(self):
|
||||
"""
|
||||
Clears all items from the queue.
|
||||
"""
|
||||
|
||||
with self.mutex:
|
||||
unfinished = self.unfinished_tasks - len(self.queue)
|
||||
if unfinished <= 0:
|
||||
if unfinished < 0:
|
||||
raise ValueError('task_done() called too many times')
|
||||
self.all_tasks_done.notify_all()
|
||||
self.unfinished_tasks = unfinished
|
||||
self.queue.clear()
|
||||
self.not_full.notify_all()
|
||||
|
||||
def safe_put(self, item):
|
||||
if self.full():
|
||||
_ = self.get()
|
||||
return False
|
||||
self.put(item)
|
||||
return True
|
||||
|
||||
|
||||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
||||
"""
|
||||
根据dataset_dir下的文件创建数据集
|
||||
|
||||
Loading…
Reference in New Issue
Block a user