mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 22:33:54 +00:00
Slow MultiThread Version
This commit is contained in:
parent
a0d14940e8
commit
d697855abc
@ -15,7 +15,7 @@ class EfficientUI(object):
|
||||
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||
# 创建队列用于链接各个线程
|
||||
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
||||
detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
|
||||
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
|
||||
mask_queue = Queue()
|
||||
# 两个接收者,接收光谱和rgb图像
|
||||
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||
@ -31,11 +31,11 @@ class EfficientUI(object):
|
||||
# 发送
|
||||
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
||||
# 启动所有线程
|
||||
spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
|
||||
rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
|
||||
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
|
||||
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
|
||||
cmd_img_controller.start(name='control_thread')
|
||||
detector.start(name='detector_thread')
|
||||
sender.start(name='sender_thread')
|
||||
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
|
||||
|
||||
def start(self):
|
||||
# 启动图形化
|
||||
|
||||
@ -20,7 +20,7 @@ from config import Config
|
||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||
|
||||
|
||||
deploy = False
|
||||
deploy = True
|
||||
if not deploy:
|
||||
print("Training env")
|
||||
from tqdm import tqdm
|
||||
|
||||
46
transmit.py
46
transmit.py
@ -1,10 +1,15 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
|
||||
from utils import ImgQueue as Queue
|
||||
import numpy as np
|
||||
from config import Config
|
||||
from models import SpecDetector, RgbDetector
|
||||
import typing
|
||||
import logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||
level=logging.DEBUG)
|
||||
|
||||
|
||||
class Transmitter(object):
|
||||
@ -47,26 +52,33 @@ class Transmitter(object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class PostProcessMethods:
|
||||
class BeforeAfterMethods:
|
||||
@classmethod
|
||||
def mask_preprocess(cls, mask: np.ndarray):
|
||||
logging.info(f"Send mask with size {mask.shape}")
|
||||
return mask.tobytes()
|
||||
|
||||
@classmethod
|
||||
def spec_data_post_process(cls, data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get Spec threshold: ", threshold)
|
||||
logging.info(f"Get Spec threshold: {threshold}")
|
||||
return threshold
|
||||
else:
|
||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||
logging.info(f"Get SPEC image with size {spec_img.shape}")
|
||||
return spec_img
|
||||
|
||||
@classmethod
|
||||
def rgb_data_post_process(cls, data):
|
||||
if len(data) < 3:
|
||||
threshold = int(float(data))
|
||||
print("[INFO] Get RGB threshold: ", threshold)
|
||||
logging.info(f"Get RGB threshold: {threshold}")
|
||||
return threshold
|
||||
else:
|
||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||
logging.info(f"Get RGB img with size {rgb_img.shape}")
|
||||
return rgb_img
|
||||
|
||||
|
||||
@ -112,9 +124,7 @@ class FifoReceiver(Transmitter):
|
||||
data = os.read(input_fifo, self._max_len)
|
||||
if post_process_func is not None:
|
||||
data = post_process_func(data)
|
||||
if not self._output_queue.safe_put(data):
|
||||
if self._msg_queue is not None:
|
||||
self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
|
||||
self._output_queue.safe_put(data)
|
||||
os.close(input_fifo)
|
||||
self._need_stop.clear()
|
||||
|
||||
@ -154,20 +164,16 @@ class FifoSender(Transmitter):
|
||||
:return:
|
||||
"""
|
||||
while not self._need_stop.is_set():
|
||||
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||
if self._input_source.empty():
|
||||
continue
|
||||
data = self._input_source.get()
|
||||
if pre_process is not None:
|
||||
data = pre_process(data)
|
||||
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||
os.write(output_fifo, data)
|
||||
os.close(output_fifo)
|
||||
self._need_stop.clear()
|
||||
|
||||
@staticmethod
|
||||
def mask_preprocess(mask: np.ndarray):
|
||||
return mask.tobytes()
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
if self._running_thread is not None:
|
||||
@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter):
|
||||
self._subscribers = output
|
||||
|
||||
def start(self, name='CMD_thread'):
|
||||
self._server_thread = threading.Thread(target=self._cmd_control_service)
|
||||
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
|
||||
self._server_thread.start()
|
||||
|
||||
def stop(self):
|
||||
@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter):
|
||||
continue
|
||||
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||
# 如果是图片,交给预测的人
|
||||
for _, subscriber in self._subscribers.items():
|
||||
subscriber.put((spec_data, rgb_data))
|
||||
for name, subscriber in self._subscribers.items():
|
||||
item = (spec_data, rgb_data)
|
||||
subscriber.safe_put(item)
|
||||
else:
|
||||
# 否则程序出现毁灭性问题,立刻崩
|
||||
raise Exception("两个相机传回的数据没有对上")
|
||||
@ -260,12 +267,21 @@ class ThreadDetector(Transmitter):
|
||||
self._predict_thread.start()
|
||||
|
||||
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||
t1 = time.time()
|
||||
mask = self._spec_detector.predict(spec)
|
||||
t2 = time.time()
|
||||
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||
# rgb识别
|
||||
mask_rgb = self._rgb_detector.predict(rgb)
|
||||
t3 = time.time()
|
||||
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||
# 结果合并
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
t4 = time.time()
|
||||
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||
return mask_result
|
||||
|
||||
def _predict_server(self):
|
||||
@ -273,5 +289,5 @@ class ThreadDetector(Transmitter):
|
||||
if not self._input_queue.empty():
|
||||
spec, rgb = self._input_queue.get()
|
||||
mask = self.predict(spec, rgb)
|
||||
self._output_queue.put(mask)
|
||||
self._output_queue.safe_put(mask)
|
||||
self._thread_exit.clear()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user