Slow MultiThread Version

This commit is contained in:
li.zhenye 2022-07-28 13:47:32 +08:00
parent a0d14940e8
commit d697855abc
4 changed files with 38 additions and 22 deletions

View File

@ -15,7 +15,7 @@ class EfficientUI(object):
rgb_fifo_path = "/tmp/dkrgb.fifo" rgb_fifo_path = "/tmp/dkrgb.fifo"
# 创建队列用于链接各个线程 # 创建队列用于链接各个线程
rgb_img_queue, spec_img_queue = Queue(), Queue() rgb_img_queue, spec_img_queue = Queue(), Queue()
detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue() detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
mask_queue = Queue() mask_queue = Queue()
# 两个接收者,接收光谱和rgb图像 # 两个接收者,接收光谱和rgb图像
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
@ -31,11 +31,11 @@ class EfficientUI(object):
# 发送 # 发送
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
# 启动所有线程 # 启动所有线程
spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread') spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread') rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
cmd_img_controller.start(name='control_thread') cmd_img_controller.start(name='control_thread')
detector.start(name='detector_thread') detector.start(name='detector_thread')
sender.start(name='sender_thread') sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
def start(self): def start(self):
# 启动图形化 # 启动图形化

View File

@ -20,7 +20,7 @@ from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold from utils import lab_scatter, read_labeled_img, size_threshold
deploy = False deploy = True
if not deploy: if not deploy:
print("Training env") print("Training env")
from tqdm import tqdm from tqdm import tqdm

View File

@ -1,10 +1,15 @@
import os import os
import threading import threading
import time
from utils import ImgQueue as Queue from utils import ImgQueue as Queue
import numpy as np import numpy as np
from config import Config from config import Config
from models import SpecDetector, RgbDetector from models import SpecDetector, RgbDetector
import typing import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG)
class Transmitter(object): class Transmitter(object):
@ -47,26 +52,33 @@ class Transmitter(object):
raise NotImplementedError raise NotImplementedError
class PostProcessMethods: class BeforeAfterMethods:
@classmethod
def mask_preprocess(cls, mask: np.ndarray):
logging.info(f"Send mask with size {mask.shape}")
return mask.tobytes()
@classmethod @classmethod
def spec_data_post_process(cls, data): def spec_data_post_process(cls, data):
if len(data) < 3: if len(data) < 3:
threshold = int(float(data)) threshold = int(float(data))
print("[INFO] Get Spec threshold: ", threshold) logging.info(f"Get Spec threshold: {threshold}")
return threshold return threshold
else: else:
spec_img = np.frombuffer(data, dtype=np.float32).\ spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img return spec_img
@classmethod @classmethod
def rgb_data_post_process(cls, data): def rgb_data_post_process(cls, data):
if len(data) < 3: if len(data) < 3:
threshold = int(float(data)) threshold = int(float(data))
print("[INFO] Get RGB threshold: ", threshold) logging.info(f"Get RGB threshold: {threshold}")
return threshold return threshold
else: else:
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
logging.info(f"Get RGB img with size {rgb_img.shape}")
return rgb_img return rgb_img
@ -112,9 +124,7 @@ class FifoReceiver(Transmitter):
data = os.read(input_fifo, self._max_len) data = os.read(input_fifo, self._max_len)
if post_process_func is not None: if post_process_func is not None:
data = post_process_func(data) data = post_process_func(data)
if not self._output_queue.safe_put(data): self._output_queue.safe_put(data)
if self._msg_queue is not None:
self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
os.close(input_fifo) os.close(input_fifo)
self._need_stop.clear() self._need_stop.clear()
@ -154,20 +164,16 @@ class FifoSender(Transmitter):
:return: :return:
""" """
while not self._need_stop.is_set(): while not self._need_stop.is_set():
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
if self._input_source.empty(): if self._input_source.empty():
continue continue
data = self._input_source.get() data = self._input_source.get()
if pre_process is not None: if pre_process is not None:
data = pre_process(data) data = pre_process(data)
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data) os.write(output_fifo, data)
os.close(output_fifo) os.close(output_fifo)
self._need_stop.clear() self._need_stop.clear()
@staticmethod
def mask_preprocess(mask: np.ndarray):
return mask.tobytes()
def __del__(self): def __del__(self):
self.stop() self.stop()
if self._running_thread is not None: if self._running_thread is not None:
@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter):
self._subscribers = output self._subscribers = output
def start(self, name='CMD_thread'): def start(self, name='CMD_thread'):
self._server_thread = threading.Thread(target=self._cmd_control_service) self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
self._server_thread.start() self._server_thread.start()
def stop(self): def stop(self):
@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter):
continue continue
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
# 如果是图片,交给预测的人 # 如果是图片,交给预测的人
for _, subscriber in self._subscribers.items(): for name, subscriber in self._subscribers.items():
subscriber.put((spec_data, rgb_data)) item = (spec_data, rgb_data)
subscriber.safe_put(item)
else: else:
# 否则程序出现毁灭性问题,立刻崩 # 否则程序出现毁灭性问题,立刻崩
raise Exception("两个相机传回的数据没有对上") raise Exception("两个相机传回的数据没有对上")
@ -260,12 +267,21 @@ class ThreadDetector(Transmitter):
self._predict_thread.start() self._predict_thread.start()
def predict(self, spec: np.ndarray, rgb: np.ndarray): def predict(self, spec: np.ndarray, rgb: np.ndarray):
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
t1 = time.time()
mask = self._spec_detector.predict(spec) mask = self._spec_detector.predict(spec)
t2 = time.time()
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
# rgb识别 # rgb识别
mask_rgb = self._rgb_detector.predict(rgb) mask_rgb = self._rgb_detector.predict(rgb)
t3 = time.time()
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
# 结果合并 # 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8) mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
return mask_result return mask_result
def _predict_server(self): def _predict_server(self):
@ -273,5 +289,5 @@ class ThreadDetector(Transmitter):
if not self._input_queue.empty(): if not self._input_queue.empty():
spec, rgb = self._input_queue.get() spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb) mask = self.predict(spec, rgb)
self._output_queue.put(mask) self._output_queue.safe_put(mask)
self._thread_exit.clear() self._thread_exit.clear()

View File

@ -47,11 +47,11 @@ class ImgQueue(Queue):
self.queue.clear() self.queue.clear()
self.not_full.notify_all() self.not_full.notify_all()
def safe_put(self, *args, **kwargs): def safe_put(self, item):
if self.full(): if self.full():
_ = self.get() _ = self.get()
return False return False
self.put(*args, **kwargs) self.put(item)
return True return True