mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Slow MultiThread Version
This commit is contained in:
parent
a0d14940e8
commit
d697855abc
@ -15,7 +15,7 @@ class EfficientUI(object):
|
|||||||
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||||
# 创建队列用于链接各个线程
|
# 创建队列用于链接各个线程
|
||||||
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
||||||
detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
|
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
|
||||||
mask_queue = Queue()
|
mask_queue = Queue()
|
||||||
# 两个接收者,接收光谱和rgb图像
|
# 两个接收者,接收光谱和rgb图像
|
||||||
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||||
@ -31,11 +31,11 @@ class EfficientUI(object):
|
|||||||
# 发送
|
# 发送
|
||||||
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
||||||
# 启动所有线程
|
# 启动所有线程
|
||||||
spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
|
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
|
||||||
rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
|
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
|
||||||
cmd_img_controller.start(name='control_thread')
|
cmd_img_controller.start(name='control_thread')
|
||||||
detector.start(name='detector_thread')
|
detector.start(name='detector_thread')
|
||||||
sender.start(name='sender_thread')
|
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
# 启动图形化
|
# 启动图形化
|
||||||
|
|||||||
@ -20,7 +20,7 @@ from config import Config
|
|||||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||||
|
|
||||||
|
|
||||||
deploy = False
|
deploy = True
|
||||||
if not deploy:
|
if not deploy:
|
||||||
print("Training env")
|
print("Training env")
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|||||||
46
transmit.py
46
transmit.py
@ -1,10 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
|
|
||||||
from utils import ImgQueue as Queue
|
from utils import ImgQueue as Queue
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from config import Config
|
from config import Config
|
||||||
from models import SpecDetector, RgbDetector
|
from models import SpecDetector, RgbDetector
|
||||||
import typing
|
import typing
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||||
|
level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
class Transmitter(object):
|
class Transmitter(object):
|
||||||
@ -47,26 +52,33 @@ class Transmitter(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class PostProcessMethods:
|
class BeforeAfterMethods:
|
||||||
|
@classmethod
|
||||||
|
def mask_preprocess(cls, mask: np.ndarray):
|
||||||
|
logging.info(f"Send mask with size {mask.shape}")
|
||||||
|
return mask.tobytes()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def spec_data_post_process(cls, data):
|
def spec_data_post_process(cls, data):
|
||||||
if len(data) < 3:
|
if len(data) < 3:
|
||||||
threshold = int(float(data))
|
threshold = int(float(data))
|
||||||
print("[INFO] Get Spec threshold: ", threshold)
|
logging.info(f"Get Spec threshold: {threshold}")
|
||||||
return threshold
|
return threshold
|
||||||
else:
|
else:
|
||||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||||
|
logging.info(f"Get SPEC image with size {spec_img.shape}")
|
||||||
return spec_img
|
return spec_img
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rgb_data_post_process(cls, data):
|
def rgb_data_post_process(cls, data):
|
||||||
if len(data) < 3:
|
if len(data) < 3:
|
||||||
threshold = int(float(data))
|
threshold = int(float(data))
|
||||||
print("[INFO] Get RGB threshold: ", threshold)
|
logging.info(f"Get RGB threshold: {threshold}")
|
||||||
return threshold
|
return threshold
|
||||||
else:
|
else:
|
||||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||||
|
logging.info(f"Get RGB img with size {rgb_img.shape}")
|
||||||
return rgb_img
|
return rgb_img
|
||||||
|
|
||||||
|
|
||||||
@ -112,9 +124,7 @@ class FifoReceiver(Transmitter):
|
|||||||
data = os.read(input_fifo, self._max_len)
|
data = os.read(input_fifo, self._max_len)
|
||||||
if post_process_func is not None:
|
if post_process_func is not None:
|
||||||
data = post_process_func(data)
|
data = post_process_func(data)
|
||||||
if not self._output_queue.safe_put(data):
|
self._output_queue.safe_put(data)
|
||||||
if self._msg_queue is not None:
|
|
||||||
self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
|
|
||||||
os.close(input_fifo)
|
os.close(input_fifo)
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
|
||||||
@ -154,20 +164,16 @@ class FifoSender(Transmitter):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
while not self._need_stop.is_set():
|
while not self._need_stop.is_set():
|
||||||
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
|
||||||
if self._input_source.empty():
|
if self._input_source.empty():
|
||||||
continue
|
continue
|
||||||
data = self._input_source.get()
|
data = self._input_source.get()
|
||||||
if pre_process is not None:
|
if pre_process is not None:
|
||||||
data = pre_process(data)
|
data = pre_process(data)
|
||||||
|
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||||
os.write(output_fifo, data)
|
os.write(output_fifo, data)
|
||||||
os.close(output_fifo)
|
os.close(output_fifo)
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def mask_preprocess(mask: np.ndarray):
|
|
||||||
return mask.tobytes()
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
self.stop()
|
self.stop()
|
||||||
if self._running_thread is not None:
|
if self._running_thread is not None:
|
||||||
@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter):
|
|||||||
self._subscribers = output
|
self._subscribers = output
|
||||||
|
|
||||||
def start(self, name='CMD_thread'):
|
def start(self, name='CMD_thread'):
|
||||||
self._server_thread = threading.Thread(target=self._cmd_control_service)
|
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
|
||||||
self._server_thread.start()
|
self._server_thread.start()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter):
|
|||||||
continue
|
continue
|
||||||
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||||
# 如果是图片,交给预测的人
|
# 如果是图片,交给预测的人
|
||||||
for _, subscriber in self._subscribers.items():
|
for name, subscriber in self._subscribers.items():
|
||||||
subscriber.put((spec_data, rgb_data))
|
item = (spec_data, rgb_data)
|
||||||
|
subscriber.safe_put(item)
|
||||||
else:
|
else:
|
||||||
# 否则程序出现毁灭性问题,立刻崩
|
# 否则程序出现毁灭性问题,立刻崩
|
||||||
raise Exception("两个相机传回的数据没有对上")
|
raise Exception("两个相机传回的数据没有对上")
|
||||||
@ -260,12 +267,21 @@ class ThreadDetector(Transmitter):
|
|||||||
self._predict_thread.start()
|
self._predict_thread.start()
|
||||||
|
|
||||||
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||||
|
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||||
|
t1 = time.time()
|
||||||
mask = self._spec_detector.predict(spec)
|
mask = self._spec_detector.predict(spec)
|
||||||
|
t2 = time.time()
|
||||||
|
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||||
# rgb识别
|
# rgb识别
|
||||||
mask_rgb = self._rgb_detector.predict(rgb)
|
mask_rgb = self._rgb_detector.predict(rgb)
|
||||||
|
t3 = time.time()
|
||||||
|
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||||
# 结果合并
|
# 结果合并
|
||||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||||
|
t4 = time.time()
|
||||||
|
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||||
|
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||||
return mask_result
|
return mask_result
|
||||||
|
|
||||||
def _predict_server(self):
|
def _predict_server(self):
|
||||||
@ -273,5 +289,5 @@ class ThreadDetector(Transmitter):
|
|||||||
if not self._input_queue.empty():
|
if not self._input_queue.empty():
|
||||||
spec, rgb = self._input_queue.get()
|
spec, rgb = self._input_queue.get()
|
||||||
mask = self.predict(spec, rgb)
|
mask = self.predict(spec, rgb)
|
||||||
self._output_queue.put(mask)
|
self._output_queue.safe_put(mask)
|
||||||
self._thread_exit.clear()
|
self._thread_exit.clear()
|
||||||
|
|||||||
4
utils.py
4
utils.py
@ -47,11 +47,11 @@ class ImgQueue(Queue):
|
|||||||
self.queue.clear()
|
self.queue.clear()
|
||||||
self.not_full.notify_all()
|
self.not_full.notify_all()
|
||||||
|
|
||||||
def safe_put(self, *args, **kwargs):
|
def safe_put(self, item):
|
||||||
if self.full():
|
if self.full():
|
||||||
_ = self.get()
|
_ = self.get()
|
||||||
return False
|
return False
|
||||||
self.put(*args, **kwargs)
|
self.put(item)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user