diff --git a/README.md b/README.md index 7d1ce7a..e348716 100644 --- a/README.md +++ b/README.md @@ -390,7 +390,7 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s # 多线程读取与多进程预测 -## 多线程与进程总体结构图 +## 总体结构图 ![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png) @@ -433,3 +433,31 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s 这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。 +## 多进程读取 + +为了能够实现多进程的读取,我们遇到了一个很奇怪的问题如下: + +> File "/Users/lizhenye/miniforge3/lib/python3.9/multiprocessing/reduction.py", line 60, in dump +> ForkingPickler(file, protocol).dump(obj) +> TypeError: cannot pickle '_thread.lock' object + +经过Google,我们发现问题在于使用了类方法,但经过更进一步的查找原因,我们发现,问题不出在类上,而是出在类里头包含了一些不可以被pickle的带有状态的文件,比如Queue和Fifo等等这些被Linux管理的底层资源。 + +所以我们找到了pickle会调用的`__getstate__`和`__setstate__`这两个python自带的类内方法,这两个方法分别在类被调用的时候将类内的变量交给pickle序列化和反序列化。所以我们对于所有类别的基础类别做出了修订: + +```python +def __getstate__(self): + self.stop() + state = self.__dict__.copy() + state['_stop_event'] = None + state['_stateful_things'] = {} + return state + +def __setstate__(self, state): + self.__dict__.update(state) + self._stop_event = threading.Event() +``` + +就是在序列化之前把不能序列化的stateful_things收起来,把`_stop_event`这个线程间同步用的东西也给收起来,然后就可以复制类内所有的变量到新的独立的进程运行了,新的进程会拥有自己进程内独立的`_stop_event`,这就导致我们其实已经失去了对于这个新开辟的子进程的控制,除非它自己调用自己的self.stop。 + +好了,我觉得这个地方有点蠢,之后再改。 diff --git a/tests/test_transmit.py b/tests/test_transmit.py index 0d56f4d..5e4c414 100644 --- a/tests/test_transmit.py +++ b/tests/test_transmit.py @@ -1,9 +1,9 @@ import logging +import multiprocessing import time import unittest import numpy as np - import transmit from config import Config from transmit import FileReceiver, FifoReceiver, FifoSender @@ -11,6 +11,7 @@ from utils import ImgQueue class TransmitterTest(unittest.TestCase): + @unittest.skip("file receiver thread test pass") def test_file_receiver(self): logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - 
%(levelname)s - %(message)s') logging.info('测试文件接收器') @@ -28,6 +29,25 @@ class TransmitterTest(unittest.TestCase): self.assertEqual(virtual_data.shape, (1024, 4096, 3)) file_receiver.stop() + # @unittest.skip('skip') + def test_file_receiver_subprocess(self): + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logging.info('测试子进程文件接收器') + image_queue = multiprocessing.Queue() + file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue, + speed=0.5, name_pattern=None, run_process=True) + virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8) + file_receiver.start(need_time=True, virtual_data=virtual_data) + for i in range(5): + time_record, read_data, virtual_data_rec = image_queue.get() + current_time = time.time() + logging.info( + f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}') + is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2)) + self.assertTrue(is_equal) + self.assertEqual(virtual_data.shape, (1024, 4096, 3)) + file_receiver.stop() + @unittest.skip('skip') def test_fifo_receiver_sender(self): total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 @@ -41,11 +61,9 @@ class TransmitterTest(unittest.TestCase): logging.debug('Start to send virtual data') for i in range(5): logging.debug('put data to input queue') - input_queue.put(virtual_data) logging.debug('put data to input queue done') virtual_data = image_queue.get() - # logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}') self.assertEqual(virtual_data.shape, (1024, 4096, 3)) diff --git a/transmit.py b/transmit.py index 38011e6..9430d8f 100644 --- a/transmit.py +++ b/transmit.py @@ -18,18 +18,21 @@ from models import SpecDetector, RgbDetector from typing import Any, Union import logging +def test_func(*args, **kwargs): + print('test_func') + 
print(kwargs) + return 'test_func' class Transmitter(object): - _io_lock: Union[Lock, Lock] def __init__(self, job_name: str, run_process: bool = False): self.output = None self.job_name = job_name self.run_process = run_process # If true, run process when started else run thread. - self._thread_stop = threading.Event() - self._thread_stop.clear() + self._stop_event = threading.Event() + self._stop_event.clear() self._running_handler = None - self._io_lock = multiprocessing.Lock() if run_process else threading.Lock() + self._stateful_things = {} def set_source(self, *args, **kwargs): """ @@ -59,6 +62,7 @@ class Transmitter(object): if not self.run_process: self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs) else: + kwargs.update({'_stateful_things': self._stateful_things}) self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs) self._running_handler.start() @@ -70,7 +74,7 @@ class Transmitter(object): :return: """ if self._running_handler is not None: - self._thread_stop.set() + self._stop_event.set() self._running_handler = None def __del__(self): @@ -83,16 +87,28 @@ class Transmitter(object): @functools.wraps(func) def wrapper(self, *args, **kwargs): logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.') - while not self._thread_stop.is_set(): + if self.run_process: + self._stateful_things = kwargs['_stateful_things'] + while not self._stop_event.is_set(): func(self, *args, **kwargs) logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.') - self._thread_stop.clear() - + self._stop_event.clear() return wrapper def job_func(self, *args, **kwargs): raise NotImplementedError + def __getstate__(self): + self.stop() + state = self.__dict__.copy() + state['_stop_event'] = None + state['_stateful_things'] = {} + return state + + def __setstate__(self, state): + self.__dict__.update(state) + self._stop_event = 
threading.Event() + class BeforeAfterMethods: @classmethod @@ -125,9 +141,9 @@ class BeforeAfterMethods: class FileReceiver(Transmitter): - def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None, - job_name: str = 'file_receiver', ): - super(FileReceiver, self).__init__(job_name=job_name) + def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str = None, + job_name: str = 'file_receiver', run_process: bool = False): + super(FileReceiver, self).__init__(job_name=job_name, run_process=run_process) self.input_dir = input_dir self.send_speed = speed self.file_names = None @@ -139,6 +155,7 @@ class FileReceiver(Transmitter): self.set_output(output_queue) def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None): + self.stop() self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern file_names = os.listdir(input_dir) if len(file_names) == 0: @@ -148,13 +165,13 @@ class FileReceiver(Transmitter): else: file_names = file_names - with self._io_lock: - self.file_names = file_names - self.file_idx = 0 + # with self._io_lock: + self.file_names = file_names + self.file_idx = 0 def set_output(self, output: ImgQueue): self.stop() - self.output_queue = output + self._stateful_things['output_queue'] = output @Transmitter.job_decorator def job_func(self, need_time=False, *args, **kwargs): @@ -162,14 +179,14 @@ class FileReceiver(Transmitter): 发送文件 :param need_time: 是否需要发送时间戳 - :param kwargs: virtual_data: 虚拟的数据,用于测试 + :param kwargs: output_queue: 以进程模式运行时需要, virtual_data: 虚拟的数据,用于测试 :return: """ - with self._io_lock: - self.file_idx += 1 - if self.file_idx >= len(self.file_names): - self.file_idx = 0 - file_name = self.file_names[self.file_idx] + logging.debug(f'{self.job_name} start.') + self.file_idx += 1 + if self.file_idx >= len(self.file_names): + self.file_idx = 0 + file_name = self.file_names[self.file_idx] file_name = 
os.path.join(self.input_dir, file_name) with open(file_name, 'rb') as f: data = f.read() @@ -179,7 +196,7 @@ class FileReceiver(Transmitter): data = (time.time(), data) if 'virtual_data' in kwargs: data = (*data, kwargs['virtual_data']) - self.output_queue.put(data) + self._stateful_things['output_queue'].put(data) time.sleep(self.send_speed) logging.info(f'sleep {self.send_speed}s ...') @@ -197,15 +214,13 @@ class FifoReceiver(Transmitter): def set_source(self, fifo_path: str): self.stop() - with self._io_lock: - if not os.access(fifo_path, os.F_OK): - os.mkfifo(fifo_path, 0o777) - self._input_fifo_path = fifo_path + if not os.access(fifo_path, os.F_OK): + os.mkfifo(fifo_path, 0o777) + self._input_fifo_path = fifo_path def set_output(self, output: ImgQueue): self.stop() - with self._io_lock: - self._output_queue = output + self._output_queue = output @Transmitter.job_decorator def job_func(self, post_process_func=None): @@ -222,6 +237,12 @@ class FifoReceiver(Transmitter): self._output_queue.put(data) os.close(input_fifo) + def __getstate__(self): + state = self.__dict__.copy() + state['_input_fifo_path'] = None + state['_output_queue'] = None + return state + class FifoSender(Transmitter): def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'): @@ -308,7 +329,7 @@ class CmdImgSplitMidware(Transmitter): # 如果是图片,交给预测的人 for name, subscriber in self._subscribers.items(): item = (spec_data, rgb_data) - subscriber.safe_put(item) + subscriber.fifo_put(item) else: # 否则程序出现毁灭性问题,立刻崩 logging.critical('两个相机传回的数据没有对上') @@ -375,7 +396,7 @@ class ThreadDetector(Transmitter): if not self._input_queue.empty(): spec, rgb = self._input_queue.get() mask = self.predict(spec, rgb) - self._output_queue.safe_put(mask) + self._output_queue.fifo_put(mask) self._thread_exit.clear() diff --git a/utils.py b/utils.py index ea0ef2f..6d75d13 100755 --- a/utils.py +++ b/utils.py @@ -63,7 +63,7 @@ class ImgQueue(Queue): self.queue.clear() self.not_full.notify_all() - 
def safe_put(self, item): + def fifo_put(self, item): if self.full(): _ = self.get() return False