From bbdfe9b6792587f18280c4bcd4000494080a0fe3 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Sat, 20 Aug 2022 00:24:20 +0800
Subject: [PATCH] =?UTF-8?q?[ext]=20=E5=A4=9A=E8=BF=9B=E7=A8=8B=E5=8A=9F?=
=?UTF-8?q?=E8=83=BD=E5=88=9D=E7=BA=A7=E7=89=88=E6=9C=AC=E6=B5=8B=E8=AF=95?=
=?UTF-8?q?=E5=AE=8C=E6=AF=95?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
README.md | 30 +++++++++++++++-
tests/test_transmit.py | 24 +++++++++++--
transmit.py | 81 ++++++++++++++++++++++++++----------------
utils.py | 2 +-
4 files changed, 102 insertions(+), 35 deletions(-)
diff --git a/README.md b/README.md
index 7d1ce7a..e348716 100644
--- a/README.md
+++ b/README.md
@@ -390,7 +390,7 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
# 多线程读取与多进程预测
-## 多线程与进程总体结构图
+## 总体结构图

@@ -433,3 +433,31 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。
+## 多进程读取
+
+在实现多进程读取的过程中,我们遇到了一个很奇怪的问题,报错如下:
+
+> File "/Users/lizhenye/miniforge3/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
+> ForkingPickler(file, protocol).dump(obj)
+> TypeError: cannot pickle '_thread.lock' object
+
+经过Google,我们起初以为问题在于使用了类方法,但更进一步地查找原因之后,我们发现问题不出在类本身,而是出在类里头包含了一些不可以被pickle的带有状态的文件,比如Queue和Fifo等等这些被Linux管理的底层资源。
+
+所以我们找到了pickle会调用的`__getstate__`和`__setstate__`这两个Python自带的类内方法,这两个方法分别在对象被序列化和反序列化的时候被调用,负责把对象内的变量交给pickle以及从pickle中恢复。所以我们对所有类的基类做出了修订:
+
+```python
+def __getstate__(self):
+ self.stop()
+ state = self.__dict__.copy()
+ state['_stop_event'] = None
+ state['_stateful_things'] = {}
+ return state
+
+def __setstate__(self, state):
+ self.__dict__.update(state)
+ self._stop_event = threading.Event()
+```
+
+就是在序列化之前把不能序列化的`_stateful_things`收起来,把`_stop_event`这个线程间同步用的东西也给收起来,然后就可以把类内所有的变量复制到新的独立进程中运行了。新的进程会拥有自己进程内独立的`_stop_event`,这就导致我们其实已经失去了对于这个新开辟的子进程的控制,除非它自己调用自己的`self.stop`。
+
+好了,我觉得这个地方有点蠢,之后再改。
diff --git a/tests/test_transmit.py b/tests/test_transmit.py
index 0d56f4d..5e4c414 100644
--- a/tests/test_transmit.py
+++ b/tests/test_transmit.py
@@ -1,9 +1,9 @@
import logging
+import multiprocessing
import time
import unittest
import numpy as np
-
import transmit
from config import Config
from transmit import FileReceiver, FifoReceiver, FifoSender
@@ -11,6 +11,7 @@ from utils import ImgQueue
class TransmitterTest(unittest.TestCase):
+ @unittest.skip("file receiver thread test pass")
def test_file_receiver(self):
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.info('测试文件接收器')
@@ -28,6 +29,25 @@ class TransmitterTest(unittest.TestCase):
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
file_receiver.stop()
+ # @unittest.skip('skip')
+ def test_file_receiver_subprocess(self):
+ logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ logging.info('测试子进程文件接收器')
+ image_queue = multiprocessing.Queue()
+ file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
+ speed=0.5, name_pattern=None, run_process=True)
+ virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
+ file_receiver.start(need_time=True, virtual_data=virtual_data)
+ for i in range(5):
+ time_record, read_data, virtual_data_rec = image_queue.get()
+ current_time = time.time()
+ logging.info(
+ f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
+ is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
+ self.assertTrue(is_equal)
+ self.assertEqual(virtual_data.shape, (1024, 4096, 3))
+ file_receiver.stop()
+
@unittest.skip('skip')
def test_fifo_receiver_sender(self):
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
@@ -41,11 +61,9 @@ class TransmitterTest(unittest.TestCase):
logging.debug('Start to send virtual data')
for i in range(5):
logging.debug('put data to input queue')
-
input_queue.put(virtual_data)
logging.debug('put data to input queue done')
virtual_data = image_queue.get()
-
# logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
diff --git a/transmit.py b/transmit.py
index 38011e6..9430d8f 100644
--- a/transmit.py
+++ b/transmit.py
@@ -18,18 +18,21 @@ from models import SpecDetector, RgbDetector
from typing import Any, Union
import logging
+def test_func(*args, **kwargs):
+ print('test_func')
+ print(kwargs)
+ return 'test_func'
class Transmitter(object):
- _io_lock: Union[Lock, Lock]
def __init__(self, job_name: str, run_process: bool = False):
self.output = None
self.job_name = job_name
self.run_process = run_process # If true, run process when started else run thread.
- self._thread_stop = threading.Event()
- self._thread_stop.clear()
+ self._stop_event = threading.Event()
+ self._stop_event.clear()
self._running_handler = None
- self._io_lock = multiprocessing.Lock() if run_process else threading.Lock()
+ self._stateful_things = {}
def set_source(self, *args, **kwargs):
"""
@@ -59,6 +62,7 @@ class Transmitter(object):
if not self.run_process:
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
else:
+ kwargs.update({'_stateful_things': self._stateful_things})
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
self._running_handler.start()
@@ -70,7 +74,7 @@ class Transmitter(object):
:return:
"""
if self._running_handler is not None:
- self._thread_stop.set()
+ self._stop_event.set()
self._running_handler = None
def __del__(self):
@@ -83,16 +87,28 @@ class Transmitter(object):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
- while not self._thread_stop.is_set():
+ if self.run_process:
+ self._stateful_things = kwargs['_stateful_things']
+ while not self._stop_event.is_set():
func(self, *args, **kwargs)
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
- self._thread_stop.clear()
-
+ self._stop_event.clear()
return wrapper
def job_func(self, *args, **kwargs):
raise NotImplementedError
+ def __getstate__(self):
+ self.stop()
+ state = self.__dict__.copy()
+ state['_stop_event'] = None
+ state['_stateful_things'] = {}
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+ self._stop_event = threading.Event()
+
class BeforeAfterMethods:
@classmethod
@@ -125,9 +141,9 @@ class BeforeAfterMethods:
class FileReceiver(Transmitter):
- def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None,
- job_name: str = 'file_receiver', ):
- super(FileReceiver, self).__init__(job_name=job_name)
+ def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str = None,
+ job_name: str = 'file_receiver', run_process: bool = False):
+ super(FileReceiver, self).__init__(job_name=job_name, run_process=run_process)
self.input_dir = input_dir
self.send_speed = speed
self.file_names = None
@@ -139,6 +155,7 @@ class FileReceiver(Transmitter):
self.set_output(output_queue)
def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
+ self.stop()
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
file_names = os.listdir(input_dir)
if len(file_names) == 0:
@@ -148,13 +165,13 @@ class FileReceiver(Transmitter):
else:
file_names = file_names
- with self._io_lock:
- self.file_names = file_names
- self.file_idx = 0
+ # with self._io_lock:
+ self.file_names = file_names
+ self.file_idx = 0
def set_output(self, output: ImgQueue):
self.stop()
- self.output_queue = output
+ self._stateful_things['output_queue'] = output
@Transmitter.job_decorator
def job_func(self, need_time=False, *args, **kwargs):
@@ -162,14 +179,14 @@ class FileReceiver(Transmitter):
发送文件
:param need_time: 是否需要发送时间戳
- :param kwargs: virtual_data: 虚拟的数据,用于测试
+ :param kwargs: output_queue: 以进程模式运行时需要, virtual_data: 虚拟的数据,用于测试
:return:
"""
- with self._io_lock:
- self.file_idx += 1
- if self.file_idx >= len(self.file_names):
- self.file_idx = 0
- file_name = self.file_names[self.file_idx]
+ logging.debug(f'{self.job_name} start.')
+ self.file_idx += 1
+ if self.file_idx >= len(self.file_names):
+ self.file_idx = 0
+ file_name = self.file_names[self.file_idx]
file_name = os.path.join(self.input_dir, file_name)
with open(file_name, 'rb') as f:
data = f.read()
@@ -179,7 +196,7 @@ class FileReceiver(Transmitter):
data = (time.time(), data)
if 'virtual_data' in kwargs:
data = (*data, kwargs['virtual_data'])
- self.output_queue.put(data)
+ self._stateful_things['output_queue'].put(data)
time.sleep(self.send_speed)
logging.info(f'sleep {self.send_speed}s ...')
@@ -197,15 +214,13 @@ class FifoReceiver(Transmitter):
def set_source(self, fifo_path: str):
self.stop()
- with self._io_lock:
- if not os.access(fifo_path, os.F_OK):
- os.mkfifo(fifo_path, 0o777)
- self._input_fifo_path = fifo_path
+ if not os.access(fifo_path, os.F_OK):
+ os.mkfifo(fifo_path, 0o777)
+ self._input_fifo_path = fifo_path
def set_output(self, output: ImgQueue):
self.stop()
- with self._io_lock:
- self._output_queue = output
+ self._output_queue = output
@Transmitter.job_decorator
def job_func(self, post_process_func=None):
@@ -222,6 +237,12 @@ class FifoReceiver(Transmitter):
self._output_queue.put(data)
os.close(input_fifo)
+ def __getstate__(self):
+ state = self.__dict__.copy()
+ state['_input_fifo_path'] = None
+ state['_output_queue'] = None
+ return state
+
class FifoSender(Transmitter):
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
@@ -308,7 +329,7 @@ class CmdImgSplitMidware(Transmitter):
# 如果是图片,交给预测的人
for name, subscriber in self._subscribers.items():
item = (spec_data, rgb_data)
- subscriber.safe_put(item)
+ subscriber.fifo_put(item)
else:
# 否则程序出现毁灭性问题,立刻崩
logging.critical('两个相机传回的数据没有对上')
@@ -375,7 +396,7 @@ class ThreadDetector(Transmitter):
if not self._input_queue.empty():
spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb)
- self._output_queue.safe_put(mask)
+ self._output_queue.fifo_put(mask)
self._thread_exit.clear()
diff --git a/utils.py b/utils.py
index ea0ef2f..6d75d13 100755
--- a/utils.py
+++ b/utils.py
@@ -63,7 +63,7 @@ class ImgQueue(Queue):
self.queue.clear()
self.not_full.notify_all()
- def safe_put(self, item):
+ def fifo_put(self, item):
if self.full():
_ = self.get()
return False