[ext] 添加了多线程接收功能

This commit is contained in:
li.zhenye 2022-08-18 09:19:44 +08:00
parent 631b46d99a
commit 15c8bcb69c
4 changed files with 168 additions and 51 deletions

View File

@ -388,3 +388,48 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
![image-20220808123044267](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/image-20220808123044267.png)
# 多线程读取与多进程预测
## 多线程与进程总体结构图
![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png)
## 多线程读取
为了避免 IO 等待,我们使用开销相对较小的线程来实现多线程的数据读取。
因为很多时候我们需要读图测试,所以我们写了一个 `FileReceiver` 类,用法大致如测试文件中所示:
```python
def test_file_receiver(self):
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
image_queue = ImgQueue()
file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
speed=0.5, name_pattern=None)
virtual_data = np.zeros((1024, 4096, 3), dtype=np.uint8)
file_receiver.start(need_time=True, virtual_data=virtual_data)
for i in range(10):
data = image_queue.get()
time_record = data[0]
logging.info(f'Spent {(time.time() - time_record) * 1000:.2f}ms to get image with shape {data[-1].shape}')
self.assertEqual(data[-1].shape, (1024, 4096, 3))
file_receiver.stop()
```
测试结果如下所示:
> 2022-08-17 23:46:09,742 - root - INFO - rgb img receive thread start.
> 2022-08-17 23:46:09,754 - root - INFO - Spent 0.04ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:10,259 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:10,276 - root - INFO - Spent 0.92ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:10,780 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:10,789 - root - INFO - Spent 0.79ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:11,293 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:11,301 - root - INFO - Spent 0.81ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:11,802 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:11,810 - root - INFO - Spent 0.77ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:12,314 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:12,315 - root - INFO - rgb img receive thread stop.
这里我们得到了一个惊人的数据传递速度:只花了将近 1ms,这速度看着很不错。接下来我们把它改造成多进程。

54
tests/test_transmit.py Normal file
View File

@ -0,0 +1,54 @@
import logging
import time
import unittest
import numpy as np
import transmit
from config import Config
from transmit import FileReceiver, FifoReceiver, FifoSender
from utils import ImgQueue
class TransmitterTest(unittest.TestCase):
    """Tests for the receiver/sender workers defined in ``transmit``."""

    def test_file_receiver(self):
        """FileReceiver should enqueue (timestamp, file data, virtual data) tuples.

        The receiver is started with ``need_time=True`` and a ``virtual_data``
        payload, so every queue item is expected to unpack into three fields.
        """
        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        logging.info('测试文件接收器')
        image_queue = ImgQueue()
        file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
                                     speed=0.5, name_pattern=None)
        virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
        file_receiver.start(need_time=True, virtual_data=virtual_data)
        try:
            for _ in range(5):
                time_record, _read_data, virtual_data_rec = image_queue.get()
                elapsed_ms = (time.time() - time_record) * 1000
                # Report and check the array that actually came out of the queue,
                # not the local array that was handed to the receiver.
                logging.info(f'Spent {elapsed_ms:.2f}ms to get image with shape {virtual_data_rec.shape}')
                self.assertEqual(virtual_data_rec.shape, (1024, 4096, 3))
                self.assertTrue(np.array_equal(virtual_data_rec, virtual_data))
        finally:
            # Always stop the worker, even when an assertion above fails.
            file_receiver.stop()

    @unittest.skip('skip')
    def test_fifo_receiver_sender(self):
        """Round-trip data through a FIFO using FifoSender and FifoReceiver."""
        total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1  # int variable
        image_queue, input_queue = ImgQueue(), ImgQueue()
        fifo_receiver = FifoReceiver(job_name='fifo img receive', fifo_path='/tmp/dkimg.fifo', output=image_queue,
                                     read_max_num=total_rgb)
        fifo_sender = FifoSender(fifo_path='/tmp/dkimg.fifo', source=input_queue, job_name='fifo img send')
        virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
        # FifoSender.job_func names this parameter ``pre_process``; any other
        # keyword would be swallowed by **kwargs and the hook never applied.
        fifo_sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess)
        fifo_receiver.start()
        logging.debug('Start to send virtual data')
        try:
            for _ in range(5):
                logging.debug('put data to input queue')
                input_queue.put(virtual_data)
                logging.debug('put data to input queue done')
                # Keep the payload intact: bind the received item to its own name.
                received = image_queue.get()
                self.assertEqual(received.shape, (1024, 4096, 3))
        finally:
            # NOTE(review): assumes stop() returns even if a worker is blocked
            # on its queue or FIFO -- confirm against Transmitter.stop().
            fifo_sender.stop()
            fifo_receiver.stop()
if __name__ == '__main__':
unittest.main()

View File

@ -1,6 +1,7 @@
import multiprocessing
import os
import threading
import typing
from multiprocessing import Process, Queue
import time
from multiprocessing.synchronize import Lock
@ -16,17 +17,15 @@ from config import Config
from models import SpecDetector, RgbDetector
from typing import Any, Union
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.WARNING)
class Transmitter(object):
_io_lock: Union[Lock, Lock]
def __init__(self, job_name:str, run_process:bool = False):
def __init__(self, job_name: str, run_process: bool = False):
self.output = None
self.job_name = job_name
self.run_process = run_process # If true, run process when started else run thread.
self.run_process = run_process # If true, run process when started else run thread.
self._thread_stop = threading.Event()
self._thread_stop.clear()
self._running_handler = None
@ -56,11 +55,11 @@ class Transmitter(object):
:param kwargs:
:return:
"""
name = kwargs.get('name', default='base thread')
name = kwargs.get('name', 'base thread')
if not self.run_process:
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args)
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
else:
self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True)
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
self._running_handler.start()
def stop(self, *args, **kwargs):
@ -81,13 +80,14 @@ class Transmitter(object):
@staticmethod
def job_decorator(func):
functools.wraps(func)
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
while not self._thread_stop.is_set():
self.job_func(*args, **kwargs)
func(self, *args, **kwargs)
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
self._need_stop.clear()
self._thread_stop.clear()
return wrapper
def job_func(self, *args, **kwargs):
@ -107,7 +107,7 @@ class BeforeAfterMethods:
logging.info(f"Get Spec threshold: {threshold}")
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
spec_img = np.frombuffer(data, dtype=np.float32). \
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img
@ -125,8 +125,9 @@ class BeforeAfterMethods:
class FileReceiver(Transmitter):
def __init__(self, job_name:str, input_dir: str, output_queue:ImgQueue, speed: int=3, name_pattern=None):
super(FileReceiver, self).__init__(job_name=job_name, run_process=False)
def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None,
job_name: str = 'file_receiver', ):
super(FileReceiver, self).__init__(job_name=job_name)
self.input_dir = input_dir
self.send_speed = speed
self.file_names = None
@ -137,7 +138,7 @@ class FileReceiver(Transmitter):
self.set_source(input_dir, name_pattern)
self.set_output(output_queue)
def set_source(self, input_dir:str, name_pattern:str=None, preprocess_method:callable=None):
def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
file_names = os.listdir(input_dir)
if len(file_names) == 0:
@ -152,11 +153,18 @@ class FileReceiver(Transmitter):
self.file_idx = 0
def set_output(self, output: ImgQueue):
with self._io_lock:
self.output_queue = output
self.stop()
self.output_queue = output
@Transmitter.job_decorator
def job_func(self, *args, **kwargs):
def job_func(self, need_time=False, *args, **kwargs):
"""
发送文件
:param need_time: 是否需要发送时间戳
:param kwargs: virtual_data: 虚拟的数据用于测试
:return:
"""
with self._io_lock:
self.file_idx += 1
if self.file_idx >= len(self.file_names):
@ -167,29 +175,37 @@ class FileReceiver(Transmitter):
data = f.read()
if self.preprocess_method is not None:
data = self.preprocess_method(data)
if need_time:
data = (time.time(), data)
if 'virtual_data' in kwargs:
data = (*data, kwargs['virtual_data'])
self.output_queue.put(data)
time.sleep(self.send_speed)
logging.info(f'sleep {self.send_speed}s ...')
class FifoReceiver(Transmitter):
def __init__(self, job_name:str, fifo_path: str, output: ImgQueue,
read_max_num: int, msg_queue=None):
def __init__(self, fifo_path: str, output: ImgQueue,
read_max_num: int, job_name: str = 'fifo_receiver'):
super().__init__(job_name=job_name)
self._input_fifo_path = None
self._output_queue = None
self._msg_queue = msg_queue
self._max_len = read_max_num
self.set_source(fifo_path)
self.set_output(output)
def set_source(self, fifo_path: str):
if not os.access(fifo_path, os.F_OK):
os.mkfifo(fifo_path, 0o777)
self._input_fifo_path = fifo_path
self.stop()
with self._io_lock:
if not os.access(fifo_path, os.F_OK):
os.mkfifo(fifo_path, 0o777)
self._input_fifo_path = fifo_path
def set_output(self, output: ImgQueue):
self._output_queue = output
self.stop()
with self._io_lock:
self._output_queue = output
@Transmitter.job_decorator
def job_func(self, post_process_func=None):
@ -203,34 +219,33 @@ class FifoReceiver(Transmitter):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
self._output_queue.safe_put(data)
self._output_queue.put(data)
os.close(input_fifo)
class FifoSender(Transmitter):
def __init__(self, output_fifo_path: str, source: ImgQueue):
super().__init__()
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
super().__init__(job_name=job_name)
self._input_source = None
self._output_fifo_path = None
self.set_source(source)
self.set_output(output_fifo_path)
self._need_stop = threading.Event()
self._need_stop.clear()
self._running_thread = None
self.set_output(fifo_path)
def set_source(self, source: ImgQueue):
self._input_source = source
self.stop()
with self._io_lock:
self._input_source = source
def set_output(self, output_fifo_path: str):
if not os.access(output_fifo_path, os.F_OK):
os.mkfifo(output_fifo_path, 0o777)
self._output_fifo_path = output_fifo_path
self.stop()
with self._io_lock:
if not os.access(output_fifo_path, os.F_OK):
os.mkfifo(output_fifo_path, 0o777)
self._output_fifo_path = output_fifo_path
def job_func(self, pre_process, *args, **kwargs):
def job_func(self, pre_process=None, *args, **kwargs):
"""
接收线程
发送线程
:param pre_process:
:return:
@ -240,21 +255,18 @@ class FifoSender(Transmitter):
data = self._input_source.get()
if pre_process is not None:
data = pre_process(data)
logging.debug(f'put data to fifo {self._output_fifo_path}')
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data)
os.close(output_fifo)
def __del__(self):
self.stop()
if self._running_thread is not None:
self._running_thread.join()
logging.debug(f'put data to fifo {self._output_fifo_path} done')
class CmdImgSplitMidware(Transmitter):
"""
用于控制命令和图像的中间件
"""
def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
super().__init__()
self._rgb_queue = None
@ -308,6 +320,7 @@ class ImageSaver(Transmitter):
"""
进行图片存储的中间件
"""
def set_source(self, *args, **kwargs):
pass
@ -354,7 +367,7 @@ class ThreadDetector(Transmitter):
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
logging.info(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
return mask_result
def _predict_server(self):
@ -408,7 +421,7 @@ class ProcessDetector(Transmitter):
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
t4 = time.time()
logging.debug(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.debug(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
logging.debug(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
return masks
def _predict_server(self):
@ -422,10 +435,14 @@ class ProcessDetector(Transmitter):
class SplitMidware(Transmitter):
def set_source(self, mask_source: ImgQueue):
pass
def start(self, *args, **kwargs):
pass
def stop(self, *args, **kwargs):
pass
pass
if __name__ == '__main__':
pass

View File

@ -186,7 +186,8 @@ def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
row_valve_count = np.sum(mask, axis=1)
if np.any(row_valve_count > max_valve_num):
over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}')
logging.warning(f'发现有{len(over_rows_idx)}行的单行喷阀数量超过限制,原喷阀数量为{row_valve_count[over_rows_idx]}'
f'已全部限制到最大许可值{max_valve_num}')
over_rows = mask[over_rows_idx, :]
# a simple function to get lucky valves when too many valves appear in the same line