From a0d14940e850472d5229a12b2765fd97f012a009 Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Thu, 28 Jul 2022 12:39:54 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=E7=89=88=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- efficient_ui.py | 51 +++++++++++ main.py | 14 ++-- models.py | 5 +- transmit.py | 218 ++++++++++++++++++++++++++++++++++++++++++------ utils.py | 29 +++++++ 5 files changed, 284 insertions(+), 33 deletions(-) create mode 100644 efficient_ui.py diff --git a/efficient_ui.py b/efficient_ui.py new file mode 100644 index 0000000..dd627b6 --- /dev/null +++ b/efficient_ui.py @@ -0,0 +1,51 @@ +import cv2 +import numpy as np + +import transmit +from config import Config + +from utils import ImgQueue as Queue + + +class EfficientUI(object): + def __init__(self): + # 相关参数 + img_fifo_path = "/tmp/dkimg.fifo" + mask_fifo_path = "/tmp/dkmask.fifo" + rgb_fifo_path = "/tmp/dkrgb.fifo" + # 创建队列用于链接各个线程 + rgb_img_queue, spec_img_queue = Queue(), Queue() + detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue() + mask_queue = Queue() + # 两个接收者,接收光谱和rgb图像 + spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 + rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 + spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len) + rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len) + # 指令执行与图像流向控制 + subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue} + cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue, + subscribers=subscribers) + # 探测器 + detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue) + # 发送 + sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) + # 启动所有线程 + spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread') + rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread') + cmd_img_controller.start(name='control_thread') + detector.start(name='detector_thread') + sender.start(name='sender_thread') + + def start(self): + # 启动图形化 + while True: + cv2.imshow('image_show', mat=np.ones((256, 1024))) + key_code = cv2.waitKey(30) + if key_code == ord("s"): + pass + + +if __name__ == '__main__': + app = EfficientUI() + app.start() diff --git a/main.py b/main.py index ee51cf4..9f4a8f6 100755 --- a/main.py +++ b/main.py @@ -1,12 +1,15 @@ import os import time +from queue import Queue -import matplotlib.pyplot as plt import numpy as np -import scipy.io +from matplotlib import pyplot as plt + +import models +import transmit from config import Config -from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector +from models import RgbDetector, SpecDetector import cv2 @@ -56,20 +59,19 @@ def main(): mask = spec_detector.predict(img_data) # rgb识别 mask_rgb = rgb_detector.predict(rgb_data) - # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() + print(f'rgb len = {len(rgb_data)}') # 写出 fd_mask = os.open(mask_fifo_path, os.O_WRONLY) os.write(fd_mask, mask_result.tobytes()) os.close(fd_mask) t3 = time.time() - print(f'total time is:{t3 - t1}\n') + print(f'total time is:{t3 - t1}') def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): diff --git a/models.py b/models.py index 181651d..f503b96 100755 --- a/models.py +++ b/models.py @@ -5,10 +5,12 @@ # @Software:PyCharm、 import datetime import pickle +from queue import Queue import cv2 import numpy as np import scipy.io +import threading from scipy.ndimage import binary_dilation from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import classification_report @@ -17,7 +19,8 @@ from sklearn.model_selection import train_test_split from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = True + +deploy = False if not deploy: print("Training env") from tqdm import tqdm diff --git a/transmit.py b/transmit.py index b990718..c0fed63 100644 --- a/transmit.py +++ b/transmit.py @@ -1,11 +1,13 @@ import os import threading -from queue import Queue +from utils import ImgQueue as Queue import numpy as np from config import Config +from models import SpecDetector, RgbDetector +import typing -class Receiver(object): +class Transmitter(object): def __init__(self): self.output = None @@ -45,17 +47,42 @@ class Receiver(object): raise NotImplementedError -class FifoReceiver(Receiver): - def __init__(self, fifo_path: str, output: Queue, read_max_num: int): +class PostProcessMethods: + @classmethod + def spec_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + print("[INFO] Get Spec threshold: ", threshold) + return threshold + else: + spec_img = np.frombuffer(data, dtype=np.float32).\ + reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) + return spec_img + + @classmethod + def rgb_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + print("[INFO] Get RGB threshold: ", threshold) + return threshold + else: + rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + return rgb_img + + +class FifoReceiver(Transmitter): + def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None): super().__init__() self._input_fifo_path = None self._output_queue = None + self._msg_queue = msg_queue self._max_len = read_max_num self.set_source(fifo_path) self.set_output(output) self._need_stop = threading.Event() self._need_stop.clear() + self._running_thread = None def set_source(self, fifo_path: str): if not os.access(fifo_path, os.F_OK): @@ -66,9 +93,9 @@ class FifoReceiver(Receiver): self._output_queue = output def start(self, post_process_func=None, name='fifo_receiver'): - x = threading.Thread(target=self._receive_thread_func, - name=name, args=(post_process_func, )) - x.start() + self._running_thread = threading.Thread(target=self._receive_thread_func, + name=name, args=(post_process_func, )) + self._running_thread.start() def stop(self): self._need_stop.set() @@ -85,27 +112,166 @@ class FifoReceiver(Receiver): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - self._output_queue.put(data) + if not self._output_queue.safe_put(data): + if self._msg_queue is not None: + self._msg_queue.put('Fifo Receiver的接收者未来得及接收') os.close(input_fifo) self._need_stop.clear() - @staticmethod - def spec_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get Spec threshold: ", threshold) - return threshold - else: - spec_img = np.frombuffer(data, dtype=np.float32).\ - reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) - return spec_img + +class FifoSender(Transmitter): + def __init__(self, output_fifo_path: str, source: Queue): + super().__init__() + self._input_source = None + self._output_fifo_path = None + self.set_source(source) + self.set_output(output_fifo_path) + self._need_stop = threading.Event() + self._need_stop.clear() + self._running_thread = None + + def set_source(self, source: Queue): + self._input_source = source + + def set_output(self, output_fifo_path: str): + if not os.access(output_fifo_path, os.F_OK): + os.mkfifo(output_fifo_path, 0o777) + self._output_fifo_path = output_fifo_path + + def start(self, pre_process=None, name='fifo_receiver'): + self._running_thread = threading.Thread(target=self._send_thread_func, name=name, + args=(pre_process, )) + self._running_thread.start() + + def stop(self): + self._need_stop.set() + + def _send_thread_func(self, pre_process=None): + """ + 接收线程 + + :param pre_process: + :return: + """ + while not self._need_stop.is_set(): + output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) + if self._input_source.empty(): + continue + data = self._input_source.get() + if pre_process is not None: + data = pre_process(data) + os.write(output_fifo, data) + os.close(output_fifo) + self._need_stop.clear() @staticmethod - def rgb_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get RGB threshold: ", threshold) - return threshold - else: - rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - return rgb_img + def mask_preprocess(mask: np.ndarray): + return mask.tobytes() + + def __del__(self): + self.stop() + if self._running_thread is not None: + self._running_thread.join() + + +class CmdImgSplitMidware(Transmitter): + """ + 用于控制命令和图像的中间件 + """ + def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue): + super().__init__() + self._rgb_queue = None + self._spec_queue = None + self._subscribers = None + self._server_thread = None + self.set_source(rgb_queue, spec_queue) + self.set_output(subscribers) + self.thread_stop = threading.Event() + + def set_source(self, rgb_queue: Queue, spec_queue: Queue): + self._rgb_queue = rgb_queue + self._spec_queue = spec_queue + + def set_output(self, output: typing.Dict[str, Queue]): + self._subscribers = output + + def start(self, name='CMD_thread'): + self._server_thread = threading.Thread(target=self._cmd_control_service) + self._server_thread.start() + + def stop(self): + self.thread_stop.set() + + def _cmd_control_service(self): + while not self.thread_stop.is_set(): + # 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时 + if self._rgb_queue.empty() or self._spec_queue.empty(): + continue + rgb_data = self._rgb_queue.get() + spec_data = self._spec_queue.get() + if isinstance(rgb_data, int) and isinstance(spec_data, int): + # 看是不是命令需要执行如果是命令,就执行 + Config.rgb_size_threshold = rgb_data + Config.spec_size_threshold = spec_data + continue + elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): + # 如果是图片,交给预测的人 + for _, subscriber in self._subscribers.items(): + subscriber.put((spec_data, rgb_data)) + else: + # 否则程序出现毁灭性问题,立刻崩 + raise Exception("两个相机传回的数据没有对上") + self.thread_stop.clear() + + +class ImageSaver(Transmitter): + """ + 进行图片存储的中间件 + """ + def set_source(self, *args, **kwargs): + pass + + def start(self, *args, **kwargs): + pass + + def stop(self, *args, **kwargs): + pass + + +class ThreadDetector(Transmitter): + def __init__(self, input_queue: Queue, output_queue: Queue): + super().__init__() + self._input_queue, self._output_queue = input_queue, output_queue + self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, + pixel_model_path=Config.pixel_model_path) + self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, + background_model_path=Config.rgb_background_model_path) + self._predict_thread = None + self._thread_exit = threading.Event() + + def set_source(self, img_queue: Queue): + self._input_queue = img_queue + + def stop(self, *args, **kwargs): + self._thread_exit.set() + + def start(self, name='predict_thread'): + self._predict_thread = threading.Thread(target=self._predict_server, name=name) + self._predict_thread.start() + + def predict(self, spec: np.ndarray, rgb: np.ndarray): + mask = self._spec_detector.predict(spec) + # rgb识别 + mask_rgb = self._rgb_detector.predict(rgb) + # 结果合并 + mask_result = (mask | mask_rgb).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + return mask_result + + def _predict_server(self): + while not self._thread_exit.is_set(): + if not self._input_queue.empty(): + spec, rgb = self._input_queue.get() + mask = self.predict(spec, rgb) + self._output_queue.put(mask) + self._thread_exit.clear() diff --git a/utils.py b/utils.py index 25607e1..f004242 100755 --- a/utils.py +++ b/utils.py @@ -5,6 +5,7 @@ # @Software:PyCharm import glob import os +from queue import Queue import cv2 import numpy as np @@ -26,6 +27,34 @@ class MergeDict(dict): return self +class ImgQueue(Queue): + """ + A custom queue subclass that provides a :meth:`clear` method. + """ + + def clear(self): + """ + Clears all items from the queue. + """ + + with self.mutex: + unfinished = self.unfinished_tasks - len(self.queue) + if unfinished <= 0: + if unfinished < 0: + raise ValueError('task_done() called too many times') + self.all_tasks_done.notify_all() + self.unfinished_tasks = unfinished + self.queue.clear() + self.not_full.notify_all() + + def safe_put(self, *args, **kwargs): + if self.full(): + _ = self.get() + return False + self.put(*args, **kwargs) + return True + + def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict: """ 根据dataset_dir下的文件创建数据集