[ext] Added multi-threaded receiving

li.zhenye 2022-08-18 09:19:44 +08:00
parent 631b46d99a
commit 15c8bcb69c
4 changed files with 168 additions and 51 deletions

View File

@ -388,3 +388,48 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
![image-20220808123044267](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/image-20220808123044267.png)
# Multi-threaded Reading and Multi-process Prediction
## Overall Structure of the Threads and Processes
![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png)
## Multi-threaded Reading
To avoid blocking on IO, we use threads, which are relatively cheap to create, to implement multi-threaded data reading.
Since we often need to read images from disk for testing, we wrote a FileReceiver class. Its usage looks roughly like this test case:
```python
def test_file_receiver(self):
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    image_queue = ImgQueue()
    file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
                                 speed=0.5, name_pattern=None)
    virtual_data = np.zeros((1024, 4096, 3), dtype=np.uint8)
    file_receiver.start(need_time=True, virtual_data=virtual_data)
    for i in range(10):
        data = image_queue.get()
        time_record = data[0]
        logging.info(f'Spent {(time.time() - time_record) * 1000:.2f}ms to get image with shape {data[-1].shape}')
        self.assertEqual(data[-1].shape, (1024, 4096, 3))
    file_receiver.stop()
```
The test output is shown below:
> 2022-08-17 23:46:09,742 - root - INFO - rgb img receive thread start.
> 2022-08-17 23:46:09,754 - root - INFO - Spent 0.04ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:10,259 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:10,276 - root - INFO - Spent 0.92ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:10,780 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:10,789 - root - INFO - Spent 0.79ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:11,293 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:11,301 - root - INFO - Spent 0.81ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:11,802 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:11,810 - root - INFO - Spent 0.77ms to get image with shape (1024, 4096, 3)
> 2022-08-17 23:46:12,314 - root - INFO - sleep 0.5s ...
> 2022-08-17 23:46:12,315 - root - INFO - rgb img receive thread stop.
Here the hand-off of a full frame between threads takes only about 1 ms, which is quite fast. Next, we turn this into a multi-process version.
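As a rough sketch of that next step, the snippet below mirrors the `run_process` switch that `Transmitter.start` uses to choose between a thread and a process. The helper names (`receive_job`, `start_receiver`) are illustrative only and are not part of transmit.py:
```python
# A minimal sketch, not the actual transmit.py implementation: swapping the
# thread for a process only changes which handler wraps the same job function.
import threading
import time
from multiprocessing import Process, Queue


def receive_job(out_queue, n=5, speed=0.5):
    # Simulated receive job: put a timestamp into the queue, then wait.
    for _ in range(n):
        out_queue.put(time.time())
        time.sleep(speed)


def start_receiver(out_queue, run_process=False):
    # Same idea as Transmitter.start: a Process when run_process is True, a Thread otherwise.
    if run_process:
        handler = Process(target=receive_job, args=(out_queue,), daemon=True)
    else:
        handler = threading.Thread(target=receive_job, args=(out_queue,), daemon=True)
    handler.start()
    return handler


if __name__ == '__main__':
    q = Queue()  # multiprocessing.Queue works with both threads and processes
    start_receiver(q, run_process=True)
    for _ in range(5):
        sent_at = q.get()
        print(f'latency: {(time.time() - sent_at) * 1000:.2f} ms')
```
Because the queue is the only shared state, the producer can move between a thread and a process without touching the consumer side.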

tests/test_transmit.py Normal file
View File

@ -0,0 +1,54 @@
import logging
import time
import unittest

import numpy as np

import transmit
from config import Config
from transmit import FileReceiver, FifoReceiver, FifoSender
from utils import ImgQueue


class TransmitterTest(unittest.TestCase):
    def test_file_receiver(self):
        logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        logging.info('Testing the file receiver')
        image_queue = ImgQueue()
        file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
                                     speed=0.5, name_pattern=None)
        virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
        file_receiver.start(need_time=True, virtual_data=virtual_data)
        for i in range(5):
            time_record, read_data, virtual_data_rec = image_queue.get()
            current_time = time.time()
            logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
            is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
            self.assertTrue(is_equal)
            self.assertEqual(virtual_data.shape, (1024, 4096, 3))
        file_receiver.stop()

    @unittest.skip('skip')
    def test_fifo_receiver_sender(self):
        total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1  # int-type variable
        image_queue, input_queue = ImgQueue(), ImgQueue()
        fifo_receiver = FifoReceiver(job_name='fifo img receive', fifo_path='/tmp/dkimg.fifo', output=image_queue,
                                     read_max_num=total_rgb)
        fifo_sender = FifoSender(fifo_path='/tmp/dkimg.fifo', source=input_queue, job_name='fifo img send')
        virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
        fifo_sender.start(preprocess=transmit.BeforeAfterMethods.mask_preprocess)
        fifo_receiver.start()
        logging.debug('Start to send virtual data')
        for i in range(5):
            logging.debug('put data to input queue')
            input_queue.put(virtual_data)
            logging.debug('put data to input queue done')
            virtual_data = image_queue.get()
            # logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
            self.assertEqual(virtual_data.shape, (1024, 4096, 3))


if __name__ == '__main__':
    unittest.main()

View File

@ -1,6 +1,7 @@
 import multiprocessing
 import os
 import threading
+import typing
 from multiprocessing import Process, Queue
 import time
 from multiprocessing.synchronize import Lock
@ -16,14 +17,12 @@ from config import Config
 from models import SpecDetector, RgbDetector
 from typing import Any, Union
 import logging
-logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
-                    level=logging.WARNING)
 class Transmitter(object):
+    _io_lock: Union[Lock, Lock]
-    def __init__(self, job_name:str, run_process:bool = False):
+    def __init__(self, job_name: str, run_process: bool = False):
         self.output = None
         self.job_name = job_name
         self.run_process = run_process  # If true, run process when started else run thread.
@ -56,11 +55,11 @@ class Transmitter(object):
         :param kwargs:
         :return:
         """
-        name = kwargs.get('name', default='base thread')
+        name = kwargs.get('name', 'base thread')
         if not self.run_process:
-            self._running_handler = threading.Thread(target=self.job_func, name=name, args=args)
+            self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
         else:
-            self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True)
+            self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
         self._running_handler.start()

     def stop(self, *args, **kwargs):
@ -81,13 +80,14 @@ class Transmitter(object):
     @staticmethod
     def job_decorator(func):
-        functools.wraps(func)
+        @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
             logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
             while not self._thread_stop.is_set():
-                self.job_func(*args, **kwargs)
+                func(self, *args, **kwargs)
             logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
-            self._need_stop.clear()
+            self._thread_stop.clear()
         return wrapper

     def job_func(self, *args, **kwargs):
@ -107,7 +107,7 @@ class BeforeAfterMethods:
logging.info(f"Get Spec threshold: {threshold}") logging.info(f"Get Spec threshold: {threshold}")
return threshold return threshold
else: else:
spec_img = np.frombuffer(data, dtype=np.float32).\ spec_img = np.frombuffer(data, dtype=np.float32). \
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}") logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img return spec_img
@ -125,8 +125,9 @@ class BeforeAfterMethods:
 class FileReceiver(Transmitter):
-    def __init__(self, job_name:str, input_dir: str, output_queue:ImgQueue, speed: int=3, name_pattern=None):
-        super(FileReceiver, self).__init__(job_name=job_name, run_process=False)
+    def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None,
+                 job_name: str = 'file_receiver', ):
+        super(FileReceiver, self).__init__(job_name=job_name)
         self.input_dir = input_dir
         self.send_speed = speed
         self.file_names = None
@ -137,7 +138,7 @@ class FileReceiver(Transmitter):
         self.set_source(input_dir, name_pattern)
         self.set_output(output_queue)

-    def set_source(self, input_dir:str, name_pattern:str=None, preprocess_method:callable=None):
+    def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
         self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
         file_names = os.listdir(input_dir)
         if len(file_names) == 0:
@ -152,11 +153,18 @@ class FileReceiver(Transmitter):
         self.file_idx = 0

     def set_output(self, output: ImgQueue):
-        with self._io_lock:
-            self.output_queue = output
+        self.stop()
+        self.output_queue = output

     @Transmitter.job_decorator
-    def job_func(self, *args, **kwargs):
+    def job_func(self, need_time=False, *args, **kwargs):
+        """
+        Send one file.
+        :param need_time: whether to prepend a timestamp to the data
+        :param kwargs: virtual_data: virtual data used for testing
+        :return:
+        """
         with self._io_lock:
             self.file_idx += 1
             if self.file_idx >= len(self.file_names):
@ -167,28 +175,36 @@ class FileReceiver(Transmitter):
             data = f.read()
         if self.preprocess_method is not None:
             data = self.preprocess_method(data)
+        if need_time:
+            data = (time.time(), data)
+        if 'virtual_data' in kwargs:
+            data = (*data, kwargs['virtual_data'])
         self.output_queue.put(data)
+        time.sleep(self.send_speed)
+        logging.info(f'sleep {self.send_speed}s ...')


 class FifoReceiver(Transmitter):
-    def __init__(self, job_name:str, fifo_path: str, output: ImgQueue,
-                 read_max_num: int, msg_queue=None):
+    def __init__(self, fifo_path: str, output: ImgQueue,
+                 read_max_num: int, job_name: str = 'fifo_receiver'):
         super().__init__(job_name=job_name)
         self._input_fifo_path = None
         self._output_queue = None
-        self._msg_queue = msg_queue
         self._max_len = read_max_num
         self.set_source(fifo_path)
         self.set_output(output)

     def set_source(self, fifo_path: str):
+        self.stop()
+        with self._io_lock:
             if not os.access(fifo_path, os.F_OK):
                 os.mkfifo(fifo_path, 0o777)
             self._input_fifo_path = fifo_path

     def set_output(self, output: ImgQueue):
+        self.stop()
+        with self._io_lock:
             self._output_queue = output

     @Transmitter.job_decorator
@ -203,34 +219,33 @@ class FifoReceiver(Transmitter):
         data = os.read(input_fifo, self._max_len)
         if post_process_func is not None:
             data = post_process_func(data)
-        self._output_queue.safe_put(data)
+        self._output_queue.put(data)
         os.close(input_fifo)


 class FifoSender(Transmitter):
-    def __init__(self, output_fifo_path: str, source: ImgQueue):
-        super().__init__()
+    def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
+        super().__init__(job_name=job_name)
         self._input_source = None
         self._output_fifo_path = None
         self.set_source(source)
-        self.set_output(output_fifo_path)
-        self._need_stop = threading.Event()
-        self._need_stop.clear()
-        self._running_thread = None
+        self.set_output(fifo_path)

     def set_source(self, source: ImgQueue):
+        self.stop()
+        with self._io_lock:
             self._input_source = source

     def set_output(self, output_fifo_path: str):
+        self.stop()
+        with self._io_lock:
             if not os.access(output_fifo_path, os.F_OK):
                 os.mkfifo(output_fifo_path, 0o777)
             self._output_fifo_path = output_fifo_path

-    def job_func(self, pre_process, *args, **kwargs):
+    def job_func(self, pre_process=None, *args, **kwargs):
         """
-        Receiving thread
+        Sending thread
         :param pre_process:
         :return:
@ -240,21 +255,18 @@ class FifoSender(Transmitter):
         data = self._input_source.get()
         if pre_process is not None:
             data = pre_process(data)
+        logging.debug(f'put data to fifo {self._output_fifo_path}')
         output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
         os.write(output_fifo, data)
         os.close(output_fifo)
+        logging.debug(f'put data to fifo {self._output_fifo_path} done')

-    def __del__(self):
-        self.stop()
-        if self._running_thread is not None:
-            self._running_thread.join()


 class CmdImgSplitMidware(Transmitter):
     """
     Middleware that dispatches control commands and images
     """
     def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
         super().__init__()
         self._rgb_queue = None
@ -308,6 +320,7 @@ class ImageSaver(Transmitter):
""" """
进行图片存储的中间件 进行图片存储的中间件
""" """
def set_source(self, *args, **kwargs): def set_source(self, *args, **kwargs):
pass pass
@ -354,7 +367,7 @@ class ThreadDetector(Transmitter):
         mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
         t4 = time.time()
         logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
-        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+        logging.info(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
         return mask_result

     def _predict_server(self):
@ -408,7 +421,7 @@ class ProcessDetector(Transmitter):
         masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
         t4 = time.time()
         logging.debug(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
-        logging.debug(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+        logging.debug(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
         return masks

     def _predict_server(self):
@ -422,10 +435,14 @@ class ProcessDetector(Transmitter):
 class SplitMidware(Transmitter):
     def set_source(self, mask_source: ImgQueue):
+        pass

     def start(self, *args, **kwargs):
         pass

     def stop(self, *args, **kwargs):
         pass


+if __name__ == '__main__':
+    pass

View File

@ -186,7 +186,8 @@ def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
     row_valve_count = np.sum(mask, axis=1)
     if np.any(row_valve_count > max_valve_num):
         over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
-        logging.warning(f'Found per-row valve count {len(over_rows_idx)} over the limit, capped to the maximum allowed value {max_valve_num}')
+        logging.warning(f'Found {len(over_rows_idx)} rows whose valve count exceeds the limit; original counts were {row_valve_count[over_rows_idx]}, '
+                        f'all capped to the maximum allowed value {max_valve_num}')
         over_rows = mask[over_rows_idx, :]
         # a simple function to get lucky valves when too many valves appear in the same line