From 15c8bcb69c402aee97b1e98b6f6f798940e71474 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Thu, 18 Aug 2022 09:19:44 +0800
Subject: [PATCH] =?UTF-8?q?[ext]=20=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=A4=9A?=
=?UTF-8?q?=E7=BA=BF=E7=A8=8B=E6=8E=A5=E6=94=B6=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
README.md | 45 ++++++++++++++++
tests/test_transmit.py | 54 +++++++++++++++++++
transmit.py | 117 +++++++++++++++++++++++------------------
utils.py | 3 +-
4 files changed, 168 insertions(+), 51 deletions(-)
create mode 100644 tests/test_transmit.py
diff --git a/README.md b/README.md
index 18445a1..7d1ce7a 100644
--- a/README.md
+++ b/README.md
@@ -388,3 +388,48 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s

+# 多线程读取与多进程预测
+
+## 多线程与进程总体结构图
+
+
+
+## 多线程读取
+
+为了能够避免IO的等待,我们使用了开销相对较小的线程来实现多线程的数据读取。
+
+由于经常需要从磁盘读取图片进行测试,我们实现了FileReceiver类,其用法可参考测试文件`tests/test_transmit.py`中的`test_file_receiver`:
+
+```python
+ def test_file_receiver(self):
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ image_queue = ImgQueue()
+ file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
+ speed=0.5, name_pattern=None)
+ virtual_data = np.zeros((1024, 4096, 3), dtype=np.uint8)
+ file_receiver.start(need_time=True, virtual_data=virtual_data)
+ for i in range(10):
+ data = image_queue.get()
+ time_record = data[0]
+ logging.info(f'Spent {(time.time() - time_record) * 1000:.2f}ms to get image with shape {data[-1].shape}')
+ self.assertEqual(data[-1].shape, (1024, 4096, 3))
+ file_receiver.stop()
+```
+
+测试结果如下所示:
+
+> 2022-08-17 23:46:09,742 - root - INFO - rgb img receive thread start.
+> 2022-08-17 23:46:09,754 - root - INFO - Spent 0.04ms to get image with shape (1024, 4096, 3)
+> 2022-08-17 23:46:10,259 - root - INFO - sleep 0.5s ...
+> 2022-08-17 23:46:10,276 - root - INFO - Spent 0.92ms to get image with shape (1024, 4096, 3)
+> 2022-08-17 23:46:10,780 - root - INFO - sleep 0.5s ...
+> 2022-08-17 23:46:10,789 - root - INFO - Spent 0.79ms to get image with shape (1024, 4096, 3)
+> 2022-08-17 23:46:11,293 - root - INFO - sleep 0.5s ...
+> 2022-08-17 23:46:11,301 - root - INFO - Spent 0.81ms to get image with shape (1024, 4096, 3)
+> 2022-08-17 23:46:11,802 - root - INFO - sleep 0.5s ...
+> 2022-08-17 23:46:11,810 - root - INFO - Spent 0.77ms to get image with shape (1024, 4096, 3)
+> 2022-08-17 23:46:12,314 - root - INFO - sleep 0.5s ...
+> 2022-08-17 23:46:12,315 - root - INFO - rgb img receive thread stop.
+
+从上面的日志可以看到,数据从读取线程放入队列到被取出的耗时不到1ms(实测0.04~0.92ms),传递速度相当不错。接下来我们把这套机制扩展为多进程。
+
diff --git a/tests/test_transmit.py b/tests/test_transmit.py
new file mode 100644
index 0000000..0d56f4d
--- /dev/null
+++ b/tests/test_transmit.py
@@ -0,0 +1,54 @@
+import logging
+import time
+import unittest
+
+import numpy as np
+
+import transmit
+from config import Config
+from transmit import FileReceiver, FifoReceiver, FifoSender
+from utils import ImgQueue
+
+
+class TransmitterTest(unittest.TestCase):
+ def test_file_receiver(self):
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logging.info('测试文件接收器')
+ image_queue = ImgQueue()
+ file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
+ speed=0.5, name_pattern=None)
+ virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
+ file_receiver.start(need_time=True, virtual_data=virtual_data)
+ for i in range(5):
+ time_record, read_data, virtual_data_rec = image_queue.get()
+ current_time = time.time()
+ logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
+ is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
+ self.assertTrue(is_equal)
+ self.assertEqual(virtual_data.shape, (1024, 4096, 3))
+ file_receiver.stop()
+
+ @unittest.skip('skip')
+ def test_fifo_receiver_sender(self):
+ total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
+ image_queue, input_queue = ImgQueue(), ImgQueue()
+ fifo_receiver = FifoReceiver(job_name='fifo img receive', fifo_path='/tmp/dkimg.fifo', output=image_queue,
+ read_max_num=total_rgb)
+ fifo_sender = FifoSender(fifo_path='/tmp/dkimg.fifo', source=input_queue, job_name='fifo img send')
+ virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
+ fifo_sender.start(preprocess=transmit.BeforeAfterMethods.mask_preprocess)
+ fifo_receiver.start()
+ logging.debug('Start to send virtual data')
+ for i in range(5):
+ logging.debug('put data to input queue')
+
+ input_queue.put(virtual_data)
+ logging.debug('put data to input queue done')
+ virtual_data = image_queue.get()
+
+ # logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
+ self.assertEqual(virtual_data.shape, (1024, 4096, 3))
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/transmit.py b/transmit.py
index d7b7e89..38011e6 100644
--- a/transmit.py
+++ b/transmit.py
@@ -1,6 +1,7 @@
import multiprocessing
import os
import threading
+import typing
from multiprocessing import Process, Queue
import time
from multiprocessing.synchronize import Lock
@@ -16,17 +17,15 @@ from config import Config
from models import SpecDetector, RgbDetector
from typing import Any, Union
import logging
-logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
- level=logging.WARNING)
class Transmitter(object):
_io_lock: Union[Lock, Lock]
- def __init__(self, job_name:str, run_process:bool = False):
+ def __init__(self, job_name: str, run_process: bool = False):
self.output = None
self.job_name = job_name
- self.run_process = run_process # If true, run process when started else run thread.
+ self.run_process = run_process # If true, run process when started else run thread.
self._thread_stop = threading.Event()
self._thread_stop.clear()
self._running_handler = None
@@ -56,11 +55,11 @@ class Transmitter(object):
:param kwargs:
:return:
"""
- name = kwargs.get('name', default='base thread')
+ name = kwargs.get('name', 'base thread')
if not self.run_process:
- self._running_handler = threading.Thread(target=self.job_func, name=name, args=args)
+ self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
else:
- self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True)
+ self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
self._running_handler.start()
def stop(self, *args, **kwargs):
@@ -81,13 +80,14 @@ class Transmitter(object):
@staticmethod
def job_decorator(func):
- functools.wraps(func)
+ @functools.wraps(func)
def wrapper(self, *args, **kwargs):
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
while not self._thread_stop.is_set():
- self.job_func(*args, **kwargs)
+ func(self, *args, **kwargs)
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
- self._need_stop.clear()
+ self._thread_stop.clear()
+
return wrapper
def job_func(self, *args, **kwargs):
@@ -107,7 +107,7 @@ class BeforeAfterMethods:
logging.info(f"Get Spec threshold: {threshold}")
return threshold
else:
- spec_img = np.frombuffer(data, dtype=np.float32).\
+ spec_img = np.frombuffer(data, dtype=np.float32). \
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img
@@ -125,8 +125,9 @@ class BeforeAfterMethods:
class FileReceiver(Transmitter):
- def __init__(self, job_name:str, input_dir: str, output_queue:ImgQueue, speed: int=3, name_pattern=None):
- super(FileReceiver, self).__init__(job_name=job_name, run_process=False)
+ def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None,
+ job_name: str = 'file_receiver', ):
+ super(FileReceiver, self).__init__(job_name=job_name)
self.input_dir = input_dir
self.send_speed = speed
self.file_names = None
@@ -137,7 +138,7 @@ class FileReceiver(Transmitter):
self.set_source(input_dir, name_pattern)
self.set_output(output_queue)
- def set_source(self, input_dir:str, name_pattern:str=None, preprocess_method:callable=None):
+ def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
file_names = os.listdir(input_dir)
if len(file_names) == 0:
@@ -152,11 +153,18 @@ class FileReceiver(Transmitter):
self.file_idx = 0
def set_output(self, output: ImgQueue):
- with self._io_lock:
- self.output_queue = output
+ self.stop()
+ self.output_queue = output
@Transmitter.job_decorator
- def job_func(self, *args, **kwargs):
+ def job_func(self, need_time=False, *args, **kwargs):
+ """
+ 发送文件
+
+ :param need_time: 是否需要发送时间戳
+ :param kwargs: virtual_data: 虚拟的数据,用于测试
+ :return:
+ """
with self._io_lock:
self.file_idx += 1
if self.file_idx >= len(self.file_names):
@@ -167,29 +175,37 @@ class FileReceiver(Transmitter):
data = f.read()
if self.preprocess_method is not None:
data = self.preprocess_method(data)
+ if need_time:
+ data = (time.time(), data)
+ if 'virtual_data' in kwargs:
+ data = (*data, kwargs['virtual_data'])
self.output_queue.put(data)
-
+ time.sleep(self.send_speed)
+ logging.info(f'sleep {self.send_speed}s ...')
class FifoReceiver(Transmitter):
- def __init__(self, job_name:str, fifo_path: str, output: ImgQueue,
- read_max_num: int, msg_queue=None):
+ def __init__(self, fifo_path: str, output: ImgQueue,
+ read_max_num: int, job_name: str = 'fifo_receiver'):
super().__init__(job_name=job_name)
self._input_fifo_path = None
self._output_queue = None
- self._msg_queue = msg_queue
self._max_len = read_max_num
self.set_source(fifo_path)
self.set_output(output)
def set_source(self, fifo_path: str):
- if not os.access(fifo_path, os.F_OK):
- os.mkfifo(fifo_path, 0o777)
- self._input_fifo_path = fifo_path
+ self.stop()
+ with self._io_lock:
+ if not os.access(fifo_path, os.F_OK):
+ os.mkfifo(fifo_path, 0o777)
+ self._input_fifo_path = fifo_path
def set_output(self, output: ImgQueue):
- self._output_queue = output
+ self.stop()
+ with self._io_lock:
+ self._output_queue = output
@Transmitter.job_decorator
def job_func(self, post_process_func=None):
@@ -203,34 +219,33 @@ class FifoReceiver(Transmitter):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
- self._output_queue.safe_put(data)
+ self._output_queue.put(data)
os.close(input_fifo)
-
class FifoSender(Transmitter):
- def __init__(self, output_fifo_path: str, source: ImgQueue):
- super().__init__()
+ def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
+ super().__init__(job_name=job_name)
self._input_source = None
self._output_fifo_path = None
self.set_source(source)
- self.set_output(output_fifo_path)
- self._need_stop = threading.Event()
- self._need_stop.clear()
- self._running_thread = None
+ self.set_output(fifo_path)
def set_source(self, source: ImgQueue):
- self._input_source = source
+ self.stop()
+ with self._io_lock:
+ self._input_source = source
def set_output(self, output_fifo_path: str):
- if not os.access(output_fifo_path, os.F_OK):
- os.mkfifo(output_fifo_path, 0o777)
- self._output_fifo_path = output_fifo_path
+ self.stop()
+ with self._io_lock:
+ if not os.access(output_fifo_path, os.F_OK):
+ os.mkfifo(output_fifo_path, 0o777)
+ self._output_fifo_path = output_fifo_path
-
- def job_func(self, pre_process, *args, **kwargs):
+ def job_func(self, pre_process=None, *args, **kwargs):
"""
- 接收线程
+ 发送线程
:param pre_process:
:return:
@@ -240,21 +255,18 @@ class FifoSender(Transmitter):
data = self._input_source.get()
if pre_process is not None:
data = pre_process(data)
+ logging.debug(f'put data to fifo {self._output_fifo_path}')
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data)
os.close(output_fifo)
-
-
- def __del__(self):
- self.stop()
- if self._running_thread is not None:
- self._running_thread.join()
+ logging.debug(f'put data to fifo {self._output_fifo_path} done')
class CmdImgSplitMidware(Transmitter):
"""
用于控制命令和图像的中间件
"""
+
def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
super().__init__()
self._rgb_queue = None
@@ -308,6 +320,7 @@ class ImageSaver(Transmitter):
"""
进行图片存储的中间件
"""
+
def set_source(self, *args, **kwargs):
pass
@@ -354,7 +367,7 @@ class ThreadDetector(Transmitter):
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
- logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+ logging.info(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
return mask_result
def _predict_server(self):
@@ -408,7 +421,7 @@ class ProcessDetector(Transmitter):
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
t4 = time.time()
logging.debug(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
- logging.debug(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+ logging.debug(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms')
return masks
def _predict_server(self):
@@ -422,10 +435,14 @@ class ProcessDetector(Transmitter):
class SplitMidware(Transmitter):
def set_source(self, mask_source: ImgQueue):
-
+ pass
def start(self, *args, **kwargs):
pass
def stop(self, *args, **kwargs):
- pass
\ No newline at end of file
+ pass
+
+
+if __name__ == '__main__':
+ pass
diff --git a/utils.py b/utils.py
index 6bd1bcb..ea0ef2f 100755
--- a/utils.py
+++ b/utils.py
@@ -186,7 +186,8 @@ def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
row_valve_count = np.sum(mask, axis=1)
if np.any(row_valve_count > max_valve_num):
over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
- logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}')
+ logging.warning(f'发现有{len(over_rows_idx)}行的单行喷阀数量超过限制,原喷阀数量为{row_valve_count[over_rows_idx]},'
+ f'已全部限制到最大许可值{max_valve_num}')
over_rows = mask[over_rows_idx, :]
# a simple function to get lucky valves when too many valves appear in the same line