Slow MultiThread Version

This commit is contained in:
li.zhenye 2022-07-28 13:47:32 +08:00
parent a0d14940e8
commit d697855abc
4 changed files with 38 additions and 22 deletions

View File

@ -15,7 +15,7 @@ class EfficientUI(object):
rgb_fifo_path = "/tmp/dkrgb.fifo"
# 创建队列用于链接各个线程
rgb_img_queue, spec_img_queue = Queue(), Queue()
detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
mask_queue = Queue()
# 两个接收者,接收光谱和rgb图像
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
@ -31,11 +31,11 @@ class EfficientUI(object):
# 发送
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
# 启动所有线程
spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
cmd_img_controller.start(name='control_thread')
detector.start(name='detector_thread')
sender.start(name='sender_thread')
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
def start(self):
# 启动图形化

View File

@ -20,7 +20,7 @@ from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold
deploy = False
deploy = True
if not deploy:
print("Training env")
from tqdm import tqdm

View File

@ -1,10 +1,15 @@
import os
import threading
import time
from utils import ImgQueue as Queue
import numpy as np
from config import Config
from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG)
class Transmitter(object):
@ -47,26 +52,33 @@ class Transmitter(object):
raise NotImplementedError
class PostProcessMethods:
class BeforeAfterMethods:
@classmethod
def mask_preprocess(cls, mask: np.ndarray):
logging.info(f"Send mask with size {mask.shape}")
return mask.tobytes()
@classmethod
def spec_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get Spec threshold: ", threshold)
logging.info(f"Get Spec threshold: {threshold}")
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img
@classmethod
def rgb_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get RGB threshold: ", threshold)
logging.info(f"Get RGB threshold: {threshold}")
return threshold
else:
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
logging.info(f"Get RGB img with size {rgb_img.shape}")
return rgb_img
@ -112,9 +124,7 @@ class FifoReceiver(Transmitter):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
if not self._output_queue.safe_put(data):
if self._msg_queue is not None:
self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
self._output_queue.safe_put(data)
os.close(input_fifo)
self._need_stop.clear()
@ -154,20 +164,16 @@ class FifoSender(Transmitter):
:return:
"""
while not self._need_stop.is_set():
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
if self._input_source.empty():
continue
data = self._input_source.get()
if pre_process is not None:
data = pre_process(data)
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data)
os.close(output_fifo)
self._need_stop.clear()
@staticmethod
def mask_preprocess(mask: np.ndarray):
return mask.tobytes()
def __del__(self):
self.stop()
if self._running_thread is not None:
@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter):
self._subscribers = output
def start(self, name='CMD_thread'):
self._server_thread = threading.Thread(target=self._cmd_control_service)
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
self._server_thread.start()
def stop(self):
@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter):
continue
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
# 如果是图片,交给预测的人
for _, subscriber in self._subscribers.items():
subscriber.put((spec_data, rgb_data))
for name, subscriber in self._subscribers.items():
item = (spec_data, rgb_data)
subscriber.safe_put(item)
else:
# 否则程序出现毁灭性问题,立刻崩
raise Exception("两个相机传回的数据没有对上")
@ -260,12 +267,21 @@ class ThreadDetector(Transmitter):
self._predict_thread.start()
def predict(self, spec: np.ndarray, rgb: np.ndarray):
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
t1 = time.time()
mask = self._spec_detector.predict(spec)
t2 = time.time()
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
# rgb识别
mask_rgb = self._rgb_detector.predict(rgb)
t3 = time.time()
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
return mask_result
def _predict_server(self):
@ -273,5 +289,5 @@ class ThreadDetector(Transmitter):
if not self._input_queue.empty():
spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb)
self._output_queue.put(mask)
self._output_queue.safe_put(mask)
self._thread_exit.clear()

View File

@ -47,11 +47,11 @@ class ImgQueue(Queue):
self.queue.clear()
self.not_full.notify_all()
def safe_put(self, *args, **kwargs):
def safe_put(self, item):
if self.full():
_ = self.get()
return False
self.put(*args, **kwargs)
self.put(item)
return True