From d697855abc79f7ce12f76d8655a9ba67c9576c41 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Thu, 28 Jul 2022 13:47:32 +0800 Subject: [PATCH] Slow MultiThread Version --- efficient_ui.py | 8 ++++---- models.py | 2 +- transmit.py | 46 +++++++++++++++++++++++++++++++--------------- utils.py | 4 ++-- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/efficient_ui.py b/efficient_ui.py index dd627b6..0c74260 100644 --- a/efficient_ui.py +++ b/efficient_ui.py @@ -15,7 +15,7 @@ class EfficientUI(object): rgb_fifo_path = "/tmp/dkrgb.fifo" # 创建队列用于链接各个线程 rgb_img_queue, spec_img_queue = Queue(), Queue() - detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue() + detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue() mask_queue = Queue() # 两个接收者,接收光谱和rgb图像 spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 @@ -31,11 +31,11 @@ class EfficientUI(object): # 发送 sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) # 启动所有线程 - spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread') - rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread') + spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread') + rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread') cmd_img_controller.start(name='control_thread') detector.start(name='detector_thread') - sender.start(name='sender_thread') + sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread') def start(self): # 启动图形化 diff --git a/models.py b/models.py index f503b96..abb1a74 100755 --- a/models.py +++ b/models.py @@ -20,7 +20,7 @@ from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = False +deploy = True if not deploy: print("Training env") from tqdm import tqdm diff --git a/transmit.py b/transmit.py index c0fed63..2632a12 100644 --- a/transmit.py +++ b/transmit.py @@ -1,10 +1,15 @@ import os import threading +import time + from utils import ImgQueue as Queue import numpy as np from config import Config from models import SpecDetector, RgbDetector import typing +import logging +logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', + level=logging.DEBUG) class Transmitter(object): @@ -47,26 +52,33 @@ class Transmitter(object): raise NotImplementedError -class PostProcessMethods: +class BeforeAfterMethods: + @classmethod + def mask_preprocess(cls, mask: np.ndarray): + logging.info(f"Send mask with size {mask.shape}") + return mask.tobytes() + @classmethod def spec_data_post_process(cls, data): if len(data) < 3: threshold = int(float(data)) - print("[INFO] Get Spec threshold: ", threshold) + logging.info(f"Get Spec threshold: {threshold}") return threshold else: spec_img = np.frombuffer(data, dtype=np.float32).\ reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) + logging.info(f"Get SPEC image with size {spec_img.shape}") return spec_img @classmethod def rgb_data_post_process(cls, data): if len(data) < 3: threshold = int(float(data)) - print("[INFO] Get RGB threshold: ", threshold) + logging.info(f"Get RGB threshold: {threshold}") return threshold else: rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + logging.info(f"Get RGB img with size {rgb_img.shape}") return rgb_img @@ -112,9 +124,7 @@ class FifoReceiver(Transmitter): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - if not self._output_queue.safe_put(data): - if self._msg_queue is not None: - self._msg_queue.put('Fifo Receiver的接收者未来得及接收') + self._output_queue.safe_put(data) os.close(input_fifo) self._need_stop.clear() @@ -154,20 +164,16 @@ class FifoSender(Transmitter): :return: """ while not self._need_stop.is_set(): - output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) if self._input_source.empty(): continue data = self._input_source.get() if pre_process is not None: data = pre_process(data) + output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) os.write(output_fifo, data) os.close(output_fifo) self._need_stop.clear() - @staticmethod - def mask_preprocess(mask: np.ndarray): - return mask.tobytes() - def __del__(self): self.stop() if self._running_thread is not None: @@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter): self._subscribers = output def start(self, name='CMD_thread'): - self._server_thread = threading.Thread(target=self._cmd_control_service) + self._server_thread = threading.Thread(target=self._cmd_control_service, name=name) self._server_thread.start() def stop(self): @@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter): continue elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): # 如果是图片,交给预测的人 - for _, subscriber in self._subscribers.items(): - subscriber.put((spec_data, rgb_data)) + for name, subscriber in self._subscribers.items(): + item = (spec_data, rgb_data) + subscriber.safe_put(item) else: # 否则程序出现毁灭性问题,立刻崩 raise Exception("两个相机传回的数据没有对上") @@ -260,12 +267,21 @@ class ThreadDetector(Transmitter): self._predict_thread.start() def predict(self, spec: np.ndarray, rgb: np.ndarray): + logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}') + t1 = time.time() mask = self._spec_detector.predict(spec) + t2 = time.time() + logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms') # rgb识别 mask_rgb = self._rgb_detector.predict(rgb) + t3 = time.time() + logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms') # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + t4 = time.time() + logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') + logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') return mask_result def _predict_server(self): @@ -273,5 +289,5 @@ class ThreadDetector(Transmitter): if not self._input_queue.empty(): spec, rgb = self._input_queue.get() mask = self.predict(spec, rgb) - self._output_queue.put(mask) + self._output_queue.safe_put(mask) self._thread_exit.clear() diff --git a/utils.py b/utils.py index f004242..bf466f8 100755 --- a/utils.py +++ b/utils.py @@ -47,11 +47,11 @@ class ImgQueue(Queue): self.queue.clear() self.not_full.notify_all() - def safe_put(self, *args, **kwargs): + def safe_put(self, item): if self.full(): _ = self.get() return False - self.put(*args, **kwargs) + self.put(item) return True