[ext] 多进程功能初级版本测试完毕

This commit is contained in:
li.zhenye 2022-08-20 00:24:20 +08:00
parent 15c8bcb69c
commit bbdfe9b679
4 changed files with 102 additions and 35 deletions

View File

@ -390,7 +390,7 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
# 多线程读取与多进程预测 # 多线程读取与多进程预测
## 多线程与进程总体结构图 ## 总体结构图
![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png) ![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png)
@ -433,3 +433,31 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
这里我们得到了一个惊人的数据传递速度只花了将近1ms这速度看着很不错哦。接下来我们把这东西变成多进程。 这里我们得到了一个惊人的数据传递速度只花了将近1ms这速度看着很不错哦。接下来我们把这东西变成多进程。
## 多进程读取
为了能够实现多进程的读取,我们遇到了一个很奇怪的问题如下:
> File "/Users/lizhenye/miniforge3/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
> ForkingPickler(file, protocol).dump(obj)
> TypeError: cannot pickle '_thread.lock' object
经过Google我们发现问题在于使用了类方法,但经过更近一步的查找原因我们发现问题不出在类上而是出在类里头包含了一些不可以被pickle的带有状态的文件比如Queue和Fifo等等这些被Linux管理的底层资源。
所以我们把找到了pickle会调用的`__getstate__`和`__setstate__`这两个python自带的类内方法这两个方法分别在类被调用的时候将类内的变量交给pickle序列化和返序列化。所以我们对于所有类别的基础类别做出了修订
```python
def __getstate__(self):
self.stop()
state = self.__dict__.copy()
state['_stop_event'] = None
state['_stateful_things'] = {}
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._stop_event = threading.Event()
```
就是在序列化之前把不能序列化的stateful_things收起来把`_stop_event`这个线程间同步用的东西也给收起来,然后就可以复制类内所有的变量新的独立的进程运行了,新的进程会拥有自己进程内独立的`_stop_event`这就导致我们其实已经失去了对于这个新开辟的子进程的控制除非它自己调用自己的self.stop。
好了,我觉得这个地方有点蠢,之后再改。

View File

@ -1,9 +1,9 @@
import logging import logging
import multiprocessing
import time import time
import unittest import unittest
import numpy as np import numpy as np
import transmit import transmit
from config import Config from config import Config
from transmit import FileReceiver, FifoReceiver, FifoSender from transmit import FileReceiver, FifoReceiver, FifoSender
@ -11,6 +11,7 @@ from utils import ImgQueue
class TransmitterTest(unittest.TestCase): class TransmitterTest(unittest.TestCase):
@unittest.skip("file receiver thread test pass")
def test_file_receiver(self): def test_file_receiver(self):
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.info('测试文件接收器') logging.info('测试文件接收器')
@ -28,6 +29,25 @@ class TransmitterTest(unittest.TestCase):
self.assertEqual(virtual_data.shape, (1024, 4096, 3)) self.assertEqual(virtual_data.shape, (1024, 4096, 3))
file_receiver.stop() file_receiver.stop()
# @unittest.skip('skip')
def test_file_receiver_subprocess(self):
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logging.info('测试子进程文件接收器')
image_queue = multiprocessing.Queue()
file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
speed=0.5, name_pattern=None, run_process=True)
virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
file_receiver.start(need_time=True, virtual_data=virtual_data)
for i in range(5):
time_record, read_data, virtual_data_rec = image_queue.get()
current_time = time.time()
logging.info(
f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
self.assertTrue(is_equal)
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
file_receiver.stop()
@unittest.skip('skip') @unittest.skip('skip')
def test_fifo_receiver_sender(self): def test_fifo_receiver_sender(self):
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
@ -41,11 +61,9 @@ class TransmitterTest(unittest.TestCase):
logging.debug('Start to send virtual data') logging.debug('Start to send virtual data')
for i in range(5): for i in range(5):
logging.debug('put data to input queue') logging.debug('put data to input queue')
input_queue.put(virtual_data) input_queue.put(virtual_data)
logging.debug('put data to input queue done') logging.debug('put data to input queue done')
virtual_data = image_queue.get() virtual_data = image_queue.get()
# logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}') # logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
self.assertEqual(virtual_data.shape, (1024, 4096, 3)) self.assertEqual(virtual_data.shape, (1024, 4096, 3))

View File

@ -18,18 +18,21 @@ from models import SpecDetector, RgbDetector
from typing import Any, Union from typing import Any, Union
import logging import logging
def test_func(*args, **kwargs):
print('test_func')
print(kwargs)
return 'test_func'
class Transmitter(object): class Transmitter(object):
_io_lock: Union[Lock, Lock]
def __init__(self, job_name: str, run_process: bool = False): def __init__(self, job_name: str, run_process: bool = False):
self.output = None self.output = None
self.job_name = job_name self.job_name = job_name
self.run_process = run_process # If true, run process when started else run thread. self.run_process = run_process # If true, run process when started else run thread.
self._thread_stop = threading.Event() self._stop_event = threading.Event()
self._thread_stop.clear() self._stop_event.clear()
self._running_handler = None self._running_handler = None
self._io_lock = multiprocessing.Lock() if run_process else threading.Lock() self._stateful_things = {}
def set_source(self, *args, **kwargs): def set_source(self, *args, **kwargs):
""" """
@ -59,6 +62,7 @@ class Transmitter(object):
if not self.run_process: if not self.run_process:
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs) self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
else: else:
kwargs.update({'_stateful_things': self._stateful_things})
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs) self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
self._running_handler.start() self._running_handler.start()
@ -70,7 +74,7 @@ class Transmitter(object):
:return: :return:
""" """
if self._running_handler is not None: if self._running_handler is not None:
self._thread_stop.set() self._stop_event.set()
self._running_handler = None self._running_handler = None
def __del__(self): def __del__(self):
@ -83,16 +87,28 @@ class Transmitter(object):
@functools.wraps(func) @functools.wraps(func)
def wrapper(self, *args, **kwargs): def wrapper(self, *args, **kwargs):
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.') logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
while not self._thread_stop.is_set(): if self.run_process:
self._stateful_things = kwargs['_stateful_things']
while not self._stop_event.is_set():
func(self, *args, **kwargs) func(self, *args, **kwargs)
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.') logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
self._thread_stop.clear() self._stop_event.clear()
return wrapper return wrapper
def job_func(self, *args, **kwargs): def job_func(self, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
def __getstate__(self):
self.stop()
state = self.__dict__.copy()
state['_stop_event'] = None
state['_stateful_things'] = {}
return state
def __setstate__(self, state):
self.__dict__.update(state)
self._stop_event = threading.Event()
class BeforeAfterMethods: class BeforeAfterMethods:
@classmethod @classmethod
@ -125,9 +141,9 @@ class BeforeAfterMethods:
class FileReceiver(Transmitter): class FileReceiver(Transmitter):
def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None, def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str = None,
job_name: str = 'file_receiver', ): job_name: str = 'file_receiver', run_process: bool = False):
super(FileReceiver, self).__init__(job_name=job_name) super(FileReceiver, self).__init__(job_name=job_name, run_process=run_process)
self.input_dir = input_dir self.input_dir = input_dir
self.send_speed = speed self.send_speed = speed
self.file_names = None self.file_names = None
@ -139,6 +155,7 @@ class FileReceiver(Transmitter):
self.set_output(output_queue) self.set_output(output_queue)
def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None): def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
self.stop()
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
file_names = os.listdir(input_dir) file_names = os.listdir(input_dir)
if len(file_names) == 0: if len(file_names) == 0:
@ -148,13 +165,13 @@ class FileReceiver(Transmitter):
else: else:
file_names = file_names file_names = file_names
with self._io_lock: # with self._io_lock:
self.file_names = file_names self.file_names = file_names
self.file_idx = 0 self.file_idx = 0
def set_output(self, output: ImgQueue): def set_output(self, output: ImgQueue):
self.stop() self.stop()
self.output_queue = output self._stateful_things['output_queue'] = output
@Transmitter.job_decorator @Transmitter.job_decorator
def job_func(self, need_time=False, *args, **kwargs): def job_func(self, need_time=False, *args, **kwargs):
@ -162,14 +179,14 @@ class FileReceiver(Transmitter):
发送文件 发送文件
:param need_time: 是否需要发送时间戳 :param need_time: 是否需要发送时间戳
:param kwargs: virtual_data: 虚拟的数据用于测试 :param kwargs: output_queue: 以进程模式运行时需要, virtual_data: 虚拟的数据用于测试
:return: :return:
""" """
with self._io_lock: logging.debug(f'{self.job_name} start.')
self.file_idx += 1 self.file_idx += 1
if self.file_idx >= len(self.file_names): if self.file_idx >= len(self.file_names):
self.file_idx = 0 self.file_idx = 0
file_name = self.file_names[self.file_idx] file_name = self.file_names[self.file_idx]
file_name = os.path.join(self.input_dir, file_name) file_name = os.path.join(self.input_dir, file_name)
with open(file_name, 'rb') as f: with open(file_name, 'rb') as f:
data = f.read() data = f.read()
@ -179,7 +196,7 @@ class FileReceiver(Transmitter):
data = (time.time(), data) data = (time.time(), data)
if 'virtual_data' in kwargs: if 'virtual_data' in kwargs:
data = (*data, kwargs['virtual_data']) data = (*data, kwargs['virtual_data'])
self.output_queue.put(data) self._stateful_things['output_queue'].put(data)
time.sleep(self.send_speed) time.sleep(self.send_speed)
logging.info(f'sleep {self.send_speed}s ...') logging.info(f'sleep {self.send_speed}s ...')
@ -197,15 +214,13 @@ class FifoReceiver(Transmitter):
def set_source(self, fifo_path: str): def set_source(self, fifo_path: str):
self.stop() self.stop()
with self._io_lock: if not os.access(fifo_path, os.F_OK):
if not os.access(fifo_path, os.F_OK): os.mkfifo(fifo_path, 0o777)
os.mkfifo(fifo_path, 0o777) self._input_fifo_path = fifo_path
self._input_fifo_path = fifo_path
def set_output(self, output: ImgQueue): def set_output(self, output: ImgQueue):
self.stop() self.stop()
with self._io_lock: self._output_queue = output
self._output_queue = output
@Transmitter.job_decorator @Transmitter.job_decorator
def job_func(self, post_process_func=None): def job_func(self, post_process_func=None):
@ -222,6 +237,12 @@ class FifoReceiver(Transmitter):
self._output_queue.put(data) self._output_queue.put(data)
os.close(input_fifo) os.close(input_fifo)
def __getstate__(self):
state = self.__dict__.copy()
state['_input_fifo_path'] = None
state['_output_queue'] = None
return state
class FifoSender(Transmitter): class FifoSender(Transmitter):
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'): def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
@ -308,7 +329,7 @@ class CmdImgSplitMidware(Transmitter):
# 如果是图片,交给预测的人 # 如果是图片,交给预测的人
for name, subscriber in self._subscribers.items(): for name, subscriber in self._subscribers.items():
item = (spec_data, rgb_data) item = (spec_data, rgb_data)
subscriber.safe_put(item) subscriber.fifo_put(item)
else: else:
# 否则程序出现毁灭性问题,立刻崩 # 否则程序出现毁灭性问题,立刻崩
logging.critical('两个相机传回的数据没有对上') logging.critical('两个相机传回的数据没有对上')
@ -375,7 +396,7 @@ class ThreadDetector(Transmitter):
if not self._input_queue.empty(): if not self._input_queue.empty():
spec, rgb = self._input_queue.get() spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb) mask = self.predict(spec, rgb)
self._output_queue.safe_put(mask) self._output_queue.fifo_put(mask)
self._thread_exit.clear() self._thread_exit.clear()

View File

@ -63,7 +63,7 @@ class ImgQueue(Queue):
self.queue.clear() self.queue.clear()
self.not_full.notify_all() self.not_full.notify_all()
def safe_put(self, item): def fifo_put(self, item):
if self.full(): if self.full():
_ = self.get() _ = self.get()
return False return False