mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
多线程版本
This commit is contained in:
parent
1bd6e7bbe5
commit
a0d14940e8
51
efficient_ui.py
Normal file
51
efficient_ui.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import transmit
|
||||||
|
from config import Config
|
||||||
|
|
||||||
|
from utils import ImgQueue as Queue
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientUI(object):
|
||||||
|
def __init__(self):
|
||||||
|
# 相关参数
|
||||||
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
|
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||||
|
# 创建队列用于链接各个线程
|
||||||
|
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
||||||
|
detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
|
||||||
|
mask_queue = Queue()
|
||||||
|
# 两个接收者,接收光谱和rgb图像
|
||||||
|
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||||
|
rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||||
|
spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
|
||||||
|
rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
|
||||||
|
# 指令执行与图像流向控制
|
||||||
|
subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
|
||||||
|
cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
|
||||||
|
subscribers=subscribers)
|
||||||
|
# 探测器
|
||||||
|
detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
|
||||||
|
# 发送
|
||||||
|
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
||||||
|
# 启动所有线程
|
||||||
|
spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
|
||||||
|
rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
|
||||||
|
cmd_img_controller.start(name='control_thread')
|
||||||
|
detector.start(name='detector_thread')
|
||||||
|
sender.start(name='sender_thread')
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
# 启动图形化
|
||||||
|
while True:
|
||||||
|
cv2.imshow('image_show', mat=np.ones((256, 1024)))
|
||||||
|
key_code = cv2.waitKey(30)
|
||||||
|
if key_code == ord("s"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app = EfficientUI()
|
||||||
|
app.start()
|
||||||
14
main.py
14
main.py
@ -1,12 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
import models
|
||||||
|
import transmit
|
||||||
|
|
||||||
from config import Config
|
from config import Config
|
||||||
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
|
from models import RgbDetector, SpecDetector
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
@ -56,20 +59,19 @@ def main():
|
|||||||
mask = spec_detector.predict(img_data)
|
mask = spec_detector.predict(img_data)
|
||||||
# rgb识别
|
# rgb识别
|
||||||
mask_rgb = rgb_detector.predict(rgb_data)
|
mask_rgb = rgb_detector.predict(rgb_data)
|
||||||
|
|
||||||
# 结果合并
|
# 结果合并
|
||||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||||
|
|
||||||
# mask_result = mask_rgb.astype(np.uint8)
|
# mask_result = mask_rgb.astype(np.uint8)
|
||||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
|
print(f'rgb len = {len(rgb_data)}')
|
||||||
|
|
||||||
# 写出
|
# 写出
|
||||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
||||||
os.write(fd_mask, mask_result.tobytes())
|
os.write(fd_mask, mask_result.tobytes())
|
||||||
os.close(fd_mask)
|
os.close(fd_mask)
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
print(f'total time is:{t3 - t1}\n')
|
print(f'total time is:{t3 - t1}')
|
||||||
|
|
||||||
|
|
||||||
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
|
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
|
||||||
|
|||||||
@ -5,10 +5,12 @@
|
|||||||
# @Software:PyCharm、
|
# @Software:PyCharm、
|
||||||
import datetime
|
import datetime
|
||||||
import pickle
|
import pickle
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io
|
import scipy.io
|
||||||
|
import threading
|
||||||
from scipy.ndimage import binary_dilation
|
from scipy.ndimage import binary_dilation
|
||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
@ -17,7 +19,8 @@ from sklearn.model_selection import train_test_split
|
|||||||
from config import Config
|
from config import Config
|
||||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||||
|
|
||||||
deploy = True
|
|
||||||
|
deploy = False
|
||||||
if not deploy:
|
if not deploy:
|
||||||
print("Training env")
|
print("Training env")
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|||||||
214
transmit.py
214
transmit.py
@ -1,11 +1,13 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from queue import Queue
|
from utils import ImgQueue as Queue
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from config import Config
|
from config import Config
|
||||||
|
from models import SpecDetector, RgbDetector
|
||||||
|
import typing
|
||||||
|
|
||||||
|
|
||||||
class Receiver(object):
|
class Transmitter(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.output = None
|
self.output = None
|
||||||
|
|
||||||
@ -45,17 +47,42 @@ class Receiver(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class FifoReceiver(Receiver):
|
class PostProcessMethods:
|
||||||
def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
|
@classmethod
|
||||||
|
def spec_data_post_process(cls, data):
|
||||||
|
if len(data) < 3:
|
||||||
|
threshold = int(float(data))
|
||||||
|
print("[INFO] Get Spec threshold: ", threshold)
|
||||||
|
return threshold
|
||||||
|
else:
|
||||||
|
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||||
|
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||||
|
return spec_img
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def rgb_data_post_process(cls, data):
|
||||||
|
if len(data) < 3:
|
||||||
|
threshold = int(float(data))
|
||||||
|
print("[INFO] Get RGB threshold: ", threshold)
|
||||||
|
return threshold
|
||||||
|
else:
|
||||||
|
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||||
|
return rgb_img
|
||||||
|
|
||||||
|
|
||||||
|
class FifoReceiver(Transmitter):
|
||||||
|
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._input_fifo_path = None
|
self._input_fifo_path = None
|
||||||
self._output_queue = None
|
self._output_queue = None
|
||||||
|
self._msg_queue = msg_queue
|
||||||
self._max_len = read_max_num
|
self._max_len = read_max_num
|
||||||
|
|
||||||
self.set_source(fifo_path)
|
self.set_source(fifo_path)
|
||||||
self.set_output(output)
|
self.set_output(output)
|
||||||
self._need_stop = threading.Event()
|
self._need_stop = threading.Event()
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
self._running_thread = None
|
||||||
|
|
||||||
def set_source(self, fifo_path: str):
|
def set_source(self, fifo_path: str):
|
||||||
if not os.access(fifo_path, os.F_OK):
|
if not os.access(fifo_path, os.F_OK):
|
||||||
@ -66,9 +93,9 @@ class FifoReceiver(Receiver):
|
|||||||
self._output_queue = output
|
self._output_queue = output
|
||||||
|
|
||||||
def start(self, post_process_func=None, name='fifo_receiver'):
|
def start(self, post_process_func=None, name='fifo_receiver'):
|
||||||
x = threading.Thread(target=self._receive_thread_func,
|
self._running_thread = threading.Thread(target=self._receive_thread_func,
|
||||||
name=name, args=(post_process_func, ))
|
name=name, args=(post_process_func, ))
|
||||||
x.start()
|
self._running_thread.start()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._need_stop.set()
|
self._need_stop.set()
|
||||||
@ -85,27 +112,166 @@ class FifoReceiver(Receiver):
|
|||||||
data = os.read(input_fifo, self._max_len)
|
data = os.read(input_fifo, self._max_len)
|
||||||
if post_process_func is not None:
|
if post_process_func is not None:
|
||||||
data = post_process_func(data)
|
data = post_process_func(data)
|
||||||
self._output_queue.put(data)
|
if not self._output_queue.safe_put(data):
|
||||||
|
if self._msg_queue is not None:
|
||||||
|
self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
|
||||||
os.close(input_fifo)
|
os.close(input_fifo)
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def spec_data_post_process(data):
|
class FifoSender(Transmitter):
|
||||||
if len(data) < 3:
|
def __init__(self, output_fifo_path: str, source: Queue):
|
||||||
threshold = int(float(data))
|
super().__init__()
|
||||||
print("[INFO] Get Spec threshold: ", threshold)
|
self._input_source = None
|
||||||
return threshold
|
self._output_fifo_path = None
|
||||||
else:
|
self.set_source(source)
|
||||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
self.set_output(output_fifo_path)
|
||||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
self._need_stop = threading.Event()
|
||||||
return spec_img
|
self._need_stop.clear()
|
||||||
|
self._running_thread = None
|
||||||
|
|
||||||
|
def set_source(self, source: Queue):
|
||||||
|
self._input_source = source
|
||||||
|
|
||||||
|
def set_output(self, output_fifo_path: str):
|
||||||
|
if not os.access(output_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(output_fifo_path, 0o777)
|
||||||
|
self._output_fifo_path = output_fifo_path
|
||||||
|
|
||||||
|
def start(self, pre_process=None, name='fifo_receiver'):
|
||||||
|
self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
|
||||||
|
args=(pre_process, ))
|
||||||
|
self._running_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._need_stop.set()
|
||||||
|
|
||||||
|
def _send_thread_func(self, pre_process=None):
|
||||||
|
"""
|
||||||
|
接收线程
|
||||||
|
|
||||||
|
:param pre_process:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
while not self._need_stop.is_set():
|
||||||
|
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||||
|
if self._input_source.empty():
|
||||||
|
continue
|
||||||
|
data = self._input_source.get()
|
||||||
|
if pre_process is not None:
|
||||||
|
data = pre_process(data)
|
||||||
|
os.write(output_fifo, data)
|
||||||
|
os.close(output_fifo)
|
||||||
|
self._need_stop.clear()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def rgb_data_post_process(data):
|
def mask_preprocess(mask: np.ndarray):
|
||||||
if len(data) < 3:
|
return mask.tobytes()
|
||||||
threshold = int(float(data))
|
|
||||||
print("[INFO] Get RGB threshold: ", threshold)
|
def __del__(self):
|
||||||
return threshold
|
self.stop()
|
||||||
|
if self._running_thread is not None:
|
||||||
|
self._running_thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
class CmdImgSplitMidware(Transmitter):
|
||||||
|
"""
|
||||||
|
用于控制命令和图像的中间件
|
||||||
|
"""
|
||||||
|
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
|
||||||
|
super().__init__()
|
||||||
|
self._rgb_queue = None
|
||||||
|
self._spec_queue = None
|
||||||
|
self._subscribers = None
|
||||||
|
self._server_thread = None
|
||||||
|
self.set_source(rgb_queue, spec_queue)
|
||||||
|
self.set_output(subscribers)
|
||||||
|
self.thread_stop = threading.Event()
|
||||||
|
|
||||||
|
def set_source(self, rgb_queue: Queue, spec_queue: Queue):
|
||||||
|
self._rgb_queue = rgb_queue
|
||||||
|
self._spec_queue = spec_queue
|
||||||
|
|
||||||
|
def set_output(self, output: typing.Dict[str, Queue]):
|
||||||
|
self._subscribers = output
|
||||||
|
|
||||||
|
def start(self, name='CMD_thread'):
|
||||||
|
self._server_thread = threading.Thread(target=self._cmd_control_service)
|
||||||
|
self._server_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.thread_stop.set()
|
||||||
|
|
||||||
|
def _cmd_control_service(self):
|
||||||
|
while not self.thread_stop.is_set():
|
||||||
|
# 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
|
||||||
|
if self._rgb_queue.empty() or self._spec_queue.empty():
|
||||||
|
continue
|
||||||
|
rgb_data = self._rgb_queue.get()
|
||||||
|
spec_data = self._spec_queue.get()
|
||||||
|
if isinstance(rgb_data, int) and isinstance(spec_data, int):
|
||||||
|
# 看是不是命令需要执行如果是命令,就执行
|
||||||
|
Config.rgb_size_threshold = rgb_data
|
||||||
|
Config.spec_size_threshold = spec_data
|
||||||
|
continue
|
||||||
|
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||||
|
# 如果是图片,交给预测的人
|
||||||
|
for _, subscriber in self._subscribers.items():
|
||||||
|
subscriber.put((spec_data, rgb_data))
|
||||||
else:
|
else:
|
||||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
# 否则程序出现毁灭性问题,立刻崩
|
||||||
return rgb_img
|
raise Exception("两个相机传回的数据没有对上")
|
||||||
|
self.thread_stop.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSaver(Transmitter):
|
||||||
|
"""
|
||||||
|
进行图片存储的中间件
|
||||||
|
"""
|
||||||
|
def set_source(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDetector(Transmitter):
|
||||||
|
def __init__(self, input_queue: Queue, output_queue: Queue):
|
||||||
|
super().__init__()
|
||||||
|
self._input_queue, self._output_queue = input_queue, output_queue
|
||||||
|
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||||||
|
pixel_model_path=Config.pixel_model_path)
|
||||||
|
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||||
|
background_model_path=Config.rgb_background_model_path)
|
||||||
|
self._predict_thread = None
|
||||||
|
self._thread_exit = threading.Event()
|
||||||
|
|
||||||
|
def set_source(self, img_queue: Queue):
|
||||||
|
self._input_queue = img_queue
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs):
|
||||||
|
self._thread_exit.set()
|
||||||
|
|
||||||
|
def start(self, name='predict_thread'):
|
||||||
|
self._predict_thread = threading.Thread(target=self._predict_server, name=name)
|
||||||
|
self._predict_thread.start()
|
||||||
|
|
||||||
|
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||||
|
mask = self._spec_detector.predict(spec)
|
||||||
|
# rgb识别
|
||||||
|
mask_rgb = self._rgb_detector.predict(rgb)
|
||||||
|
# 结果合并
|
||||||
|
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||||
|
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||||
|
return mask_result
|
||||||
|
|
||||||
|
def _predict_server(self):
|
||||||
|
while not self._thread_exit.is_set():
|
||||||
|
if not self._input_queue.empty():
|
||||||
|
spec, rgb = self._input_queue.get()
|
||||||
|
mask = self.predict(spec, rgb)
|
||||||
|
self._output_queue.put(mask)
|
||||||
|
self._thread_exit.clear()
|
||||||
|
|||||||
29
utils.py
29
utils.py
@ -5,6 +5,7 @@
|
|||||||
# @Software:PyCharm
|
# @Software:PyCharm
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,6 +27,34 @@ class MergeDict(dict):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class ImgQueue(Queue):
|
||||||
|
"""
|
||||||
|
A custom queue subclass that provides a :meth:`clear` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
Clears all items from the queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self.mutex:
|
||||||
|
unfinished = self.unfinished_tasks - len(self.queue)
|
||||||
|
if unfinished <= 0:
|
||||||
|
if unfinished < 0:
|
||||||
|
raise ValueError('task_done() called too many times')
|
||||||
|
self.all_tasks_done.notify_all()
|
||||||
|
self.unfinished_tasks = unfinished
|
||||||
|
self.queue.clear()
|
||||||
|
self.not_full.notify_all()
|
||||||
|
|
||||||
|
def safe_put(self, *args, **kwargs):
|
||||||
|
if self.full():
|
||||||
|
_ = self.get()
|
||||||
|
return False
|
||||||
|
self.put(*args, **kwargs)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
||||||
"""
|
"""
|
||||||
根据dataset_dir下的文件创建数据集
|
根据dataset_dir下的文件创建数据集
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user