mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
[ext] 添加了多线程接收功能
This commit is contained in:
parent
631b46d99a
commit
15c8bcb69c
45
README.md
45
README.md
@ -388,3 +388,48 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
|
||||
|
||||

|
||||
|
||||
# 多线程读取与多进程预测
|
||||
|
||||
## 多线程与进程总体结构图
|
||||
|
||||

|
||||
|
||||
## 多线程读取
|
||||
|
||||
为了能够避免IO的等待,我们使用了开销相对较小的线程来实现多线程的数据读取。
|
||||
|
||||
因为很多时候我们需要读图测试,所以我们写了一个FileReceiver类,用法大概就像测试文件里这样:
|
||||
|
||||
```python
|
||||
def test_file_receiver(self):
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
image_queue = ImgQueue()
|
||||
file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
|
||||
speed=0.5, name_pattern=None)
|
||||
virtual_data = np.zeros((1024, 4096, 3), dtype=np.uint8)
|
||||
file_receiver.start(need_time=True, virtual_data=virtual_data)
|
||||
for i in range(10):
|
||||
data = image_queue.get()
|
||||
time_record = data[0]
|
||||
logging.info(f'Spent {(time.time() - time_record) * 1000:.2f}ms to get image with shape {data[-1].shape}')
|
||||
self.assertEqual(data[-1].shape, (1024, 4096, 3))
|
||||
file_receiver.stop()
|
||||
```
|
||||
|
||||
测试结果如下所示:
|
||||
|
||||
> 2022-08-17 23:46:09,742 - root - INFO - rgb img receive thread start.
|
||||
> 2022-08-17 23:46:09,754 - root - INFO - Spent 0.04ms to get image with shape (1024, 4096, 3)
|
||||
> 2022-08-17 23:46:10,259 - root - INFO - sleep 0.5s ...
|
||||
> 2022-08-17 23:46:10,276 - root - INFO - Spent 0.92ms to get image with shape (1024, 4096, 3)
|
||||
> 2022-08-17 23:46:10,780 - root - INFO - sleep 0.5s ...
|
||||
> 2022-08-17 23:46:10,789 - root - INFO - Spent 0.79ms to get image with shape (1024, 4096, 3)
|
||||
> 2022-08-17 23:46:11,293 - root - INFO - sleep 0.5s ...
|
||||
> 2022-08-17 23:46:11,301 - root - INFO - Spent 0.81ms to get image with shape (1024, 4096, 3)
|
||||
> 2022-08-17 23:46:11,802 - root - INFO - sleep 0.5s ...
|
||||
> 2022-08-17 23:46:11,810 - root - INFO - Spent 0.77ms to get image with shape (1024, 4096, 3)
|
||||
> 2022-08-17 23:46:12,314 - root - INFO - sleep 0.5s ...
|
||||
> 2022-08-17 23:46:12,315 - root - INFO - rgb img receive thread stop.
|
||||
|
||||
这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。
|
||||
|
||||
|
||||
54
tests/test_transmit.py
Normal file
54
tests/test_transmit.py
Normal file
@ -0,0 +1,54 @@
|
||||
import logging
|
||||
import time
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
|
||||
import transmit
|
||||
from config import Config
|
||||
from transmit import FileReceiver, FifoReceiver, FifoSender
|
||||
from utils import ImgQueue
|
||||
|
||||
|
||||
class TransmitterTest(unittest.TestCase):
|
||||
def test_file_receiver(self):
|
||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||
logging.info('测试文件接收器')
|
||||
image_queue = ImgQueue()
|
||||
file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
|
||||
speed=0.5, name_pattern=None)
|
||||
virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
|
||||
file_receiver.start(need_time=True, virtual_data=virtual_data)
|
||||
for i in range(5):
|
||||
time_record, read_data, virtual_data_rec = image_queue.get()
|
||||
current_time = time.time()
|
||||
logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
|
||||
is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
|
||||
self.assertTrue(is_equal)
|
||||
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
||||
file_receiver.stop()
|
||||
|
||||
@unittest.skip('skip')
|
||||
def test_fifo_receiver_sender(self):
|
||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||
image_queue, input_queue = ImgQueue(), ImgQueue()
|
||||
fifo_receiver = FifoReceiver(job_name='fifo img receive', fifo_path='/tmp/dkimg.fifo', output=image_queue,
|
||||
read_max_num=total_rgb)
|
||||
fifo_sender = FifoSender(fifo_path='/tmp/dkimg.fifo', source=input_queue, job_name='fifo img send')
|
||||
virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
|
||||
fifo_sender.start(preprocess=transmit.BeforeAfterMethods.mask_preprocess)
|
||||
fifo_receiver.start()
|
||||
logging.debug('Start to send virtual data')
|
||||
for i in range(5):
|
||||
logging.debug('put data to input queue')
|
||||
|
||||
input_queue.put(virtual_data)
|
||||
logging.debug('put data to input queue done')
|
||||
virtual_data = image_queue.get()
|
||||
|
||||
# logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
|
||||
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
85
transmit.py
85
transmit.py
@ -1,6 +1,7 @@
|
||||
import multiprocessing
|
||||
import os
|
||||
import threading
|
||||
import typing
|
||||
from multiprocessing import Process, Queue
|
||||
import time
|
||||
from multiprocessing.synchronize import Lock
|
||||
@ -16,8 +17,6 @@ from config import Config
|
||||
from models import SpecDetector, RgbDetector
|
||||
from typing import Any, Union
|
||||
import logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||
level=logging.WARNING)
|
||||
|
||||
|
||||
class Transmitter(object):
|
||||
@ -56,11 +55,11 @@ class Transmitter(object):
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
name = kwargs.get('name', default='base thread')
|
||||
name = kwargs.get('name', 'base thread')
|
||||
if not self.run_process:
|
||||
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args)
|
||||
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
|
||||
else:
|
||||
self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True)
|
||||
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
|
||||
self._running_handler.start()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
@ -81,13 +80,14 @@ class Transmitter(object):
|
||||
|
||||
@staticmethod
|
||||
def job_decorator(func):
|
||||
functools.wraps(func)
|
||||
@functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
|
||||
while not self._thread_stop.is_set():
|
||||
self.job_func(*args, **kwargs)
|
||||
func(self, *args, **kwargs)
|
||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
|
||||
self._need_stop.clear()
|
||||
self._thread_stop.clear()
|
||||
|
||||
return wrapper
|
||||
|
||||
def job_func(self, *args, **kwargs):
|
||||
@ -125,8 +125,9 @@ class BeforeAfterMethods:
|
||||
|
||||
|
||||
class FileReceiver(Transmitter):
|
||||
def __init__(self, job_name:str, input_dir: str, output_queue:ImgQueue, speed: int=3, name_pattern=None):
|
||||
super(FileReceiver, self).__init__(job_name=job_name, run_process=False)
|
||||
def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None,
|
||||
job_name: str = 'file_receiver', ):
|
||||
super(FileReceiver, self).__init__(job_name=job_name)
|
||||
self.input_dir = input_dir
|
||||
self.send_speed = speed
|
||||
self.file_names = None
|
||||
@ -152,11 +153,18 @@ class FileReceiver(Transmitter):
|
||||
self.file_idx = 0
|
||||
|
||||
def set_output(self, output: ImgQueue):
|
||||
with self._io_lock:
|
||||
self.stop()
|
||||
self.output_queue = output
|
||||
|
||||
@Transmitter.job_decorator
|
||||
def job_func(self, *args, **kwargs):
|
||||
def job_func(self, need_time=False, *args, **kwargs):
|
||||
"""
|
||||
发送文件
|
||||
|
||||
:param need_time: 是否需要发送时间戳
|
||||
:param kwargs: virtual_data: 虚拟的数据,用于测试
|
||||
:return:
|
||||
"""
|
||||
with self._io_lock:
|
||||
self.file_idx += 1
|
||||
if self.file_idx >= len(self.file_names):
|
||||
@ -167,28 +175,36 @@ class FileReceiver(Transmitter):
|
||||
data = f.read()
|
||||
if self.preprocess_method is not None:
|
||||
data = self.preprocess_method(data)
|
||||
if need_time:
|
||||
data = (time.time(), data)
|
||||
if 'virtual_data' in kwargs:
|
||||
data = (*data, kwargs['virtual_data'])
|
||||
self.output_queue.put(data)
|
||||
|
||||
time.sleep(self.send_speed)
|
||||
logging.info(f'sleep {self.send_speed}s ...')
|
||||
|
||||
|
||||
class FifoReceiver(Transmitter):
|
||||
def __init__(self, job_name:str, fifo_path: str, output: ImgQueue,
|
||||
read_max_num: int, msg_queue=None):
|
||||
def __init__(self, fifo_path: str, output: ImgQueue,
|
||||
read_max_num: int, job_name: str = 'fifo_receiver'):
|
||||
super().__init__(job_name=job_name)
|
||||
self._input_fifo_path = None
|
||||
self._output_queue = None
|
||||
self._msg_queue = msg_queue
|
||||
self._max_len = read_max_num
|
||||
|
||||
self.set_source(fifo_path)
|
||||
self.set_output(output)
|
||||
|
||||
def set_source(self, fifo_path: str):
|
||||
self.stop()
|
||||
with self._io_lock:
|
||||
if not os.access(fifo_path, os.F_OK):
|
||||
os.mkfifo(fifo_path, 0o777)
|
||||
self._input_fifo_path = fifo_path
|
||||
|
||||
def set_output(self, output: ImgQueue):
|
||||
self.stop()
|
||||
with self._io_lock:
|
||||
self._output_queue = output
|
||||
|
||||
@Transmitter.job_decorator
|
||||
@ -203,34 +219,33 @@ class FifoReceiver(Transmitter):
|
||||
data = os.read(input_fifo, self._max_len)
|
||||
if post_process_func is not None:
|
||||
data = post_process_func(data)
|
||||
self._output_queue.safe_put(data)
|
||||
self._output_queue.put(data)
|
||||
os.close(input_fifo)
|
||||
|
||||
|
||||
|
||||
class FifoSender(Transmitter):
|
||||
def __init__(self, output_fifo_path: str, source: ImgQueue):
|
||||
super().__init__()
|
||||
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
|
||||
super().__init__(job_name=job_name)
|
||||
self._input_source = None
|
||||
self._output_fifo_path = None
|
||||
self.set_source(source)
|
||||
self.set_output(output_fifo_path)
|
||||
self._need_stop = threading.Event()
|
||||
self._need_stop.clear()
|
||||
self._running_thread = None
|
||||
self.set_output(fifo_path)
|
||||
|
||||
def set_source(self, source: ImgQueue):
|
||||
self.stop()
|
||||
with self._io_lock:
|
||||
self._input_source = source
|
||||
|
||||
def set_output(self, output_fifo_path: str):
|
||||
self.stop()
|
||||
with self._io_lock:
|
||||
if not os.access(output_fifo_path, os.F_OK):
|
||||
os.mkfifo(output_fifo_path, 0o777)
|
||||
self._output_fifo_path = output_fifo_path
|
||||
|
||||
|
||||
def job_func(self, pre_process, *args, **kwargs):
|
||||
def job_func(self, pre_process=None, *args, **kwargs):
|
||||
"""
|
||||
接收线程
|
||||
发送线程
|
||||
|
||||
:param pre_process:
|
||||
:return:
|
||||
@ -240,21 +255,18 @@ class FifoSender(Transmitter):
|
||||
data = self._input_source.get()
|
||||
if pre_process is not None:
|
||||
data = pre_process(data)
|
||||
logging.debug(f'put data to fifo {self._output_fifo_path}')
|
||||
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||
os.write(output_fifo, data)
|
||||
os.close(output_fifo)
|
||||
|
||||
|
||||
def __del__(self):
|
||||
self.stop()
|
||||
if self._running_thread is not None:
|
||||
self._running_thread.join()
|
||||
logging.debug(f'put data to fifo {self._output_fifo_path} done')
|
||||
|
||||
|
||||
class CmdImgSplitMidware(Transmitter):
|
||||
"""
|
||||
用于控制命令和图像的中间件
|
||||
"""
|
||||
|
||||
def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
|
||||
super().__init__()
|
||||
self._rgb_queue = None
|
||||
@ -308,6 +320,7 @@ class ImageSaver(Transmitter):
|
||||
"""
|
||||
进行图片存储的中间件
|
||||
"""
|
||||
|
||||
def set_source(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
@ -422,10 +435,14 @@ class ProcessDetector(Transmitter):
|
||||
|
||||
class SplitMidware(Transmitter):
|
||||
def set_source(self, mask_source: ImgQueue):
|
||||
|
||||
pass
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
||||
3
utils.py
3
utils.py
@ -186,7 +186,8 @@ def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
|
||||
row_valve_count = np.sum(mask, axis=1)
|
||||
if np.any(row_valve_count > max_valve_num):
|
||||
over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
|
||||
logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}')
|
||||
logging.warning(f'发现有{len(over_rows_idx)}行的单行喷阀数量超过限制,原喷阀数量为{row_valve_count[over_rows_idx]},'
|
||||
f'已全部限制到最大许可值{max_valve_num}')
|
||||
over_rows = mask[over_rows_idx, :]
|
||||
|
||||
# a simple function to get lucky valves when too many valves appear in the same line
|
||||
|
||||
Loading…
Reference in New Issue
Block a user