mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
[ext] 多进程功能初级版本测试完毕
This commit is contained in:
parent
15c8bcb69c
commit
bbdfe9b679
30
README.md
30
README.md
@ -390,7 +390,7 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
|
|||||||
|
|
||||||
# 多线程读取与多进程预测
|
# 多线程读取与多进程预测
|
||||||
|
|
||||||
## 多线程与进程总体结构图
|
## 总体结构图
|
||||||
|
|
||||||

|

|
||||||
|
|
||||||
@ -433,3 +433,31 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s
|
|||||||
|
|
||||||
这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。
|
这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。
|
||||||
|
|
||||||
|
## 多进程读取
|
||||||
|
|
||||||
|
为了能够实现多进程的读取,我们遇到了一个很奇怪的问题如下:
|
||||||
|
|
||||||
|
> File "/Users/lizhenye/miniforge3/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
|
||||||
|
> ForkingPickler(file, protocol).dump(obj)
|
||||||
|
> TypeError: cannot pickle '_thread.lock' object
|
||||||
|
|
||||||
|
经过Google,我们发现问题在于使用了类方法,但经过更近一步的查找原因,我们发现,问题不出在类上,而是出在类里头包含了一些不可以被pickle的带有状态的文件,比如Queue和Fifo等等这些被Linux管理的底层资源。
|
||||||
|
|
||||||
|
所以我们把找到了pickle会调用的`__getstate__`和`__setstate__`这两个python自带的类内方法,这两个方法分别在类被调用的时候将类内的变量交给pickle序列化和返序列化。所以我们对于所有类别的基础类别做出了修订:
|
||||||
|
|
||||||
|
```python
|
||||||
|
def __getstate__(self):
|
||||||
|
self.stop()
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state['_stop_event'] = None
|
||||||
|
state['_stateful_things'] = {}
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
```
|
||||||
|
|
||||||
|
就是在序列化之前把不能序列化的stateful_things收起来,把`_stop_event`这个线程间同步用的东西也给收起来,然后就可以复制类内所有的变量新的独立的进程运行了,新的进程会拥有自己进程内独立的`_stop_event`,这就导致我们其实已经失去了对于这个新开辟的子进程的控制,除非它自己调用自己的self.stop。
|
||||||
|
|
||||||
|
好了,我觉得这个地方有点蠢,之后再改。
|
||||||
|
|||||||
@ -1,9 +1,9 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import multiprocessing
|
||||||
import time
|
import time
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
import transmit
|
import transmit
|
||||||
from config import Config
|
from config import Config
|
||||||
from transmit import FileReceiver, FifoReceiver, FifoSender
|
from transmit import FileReceiver, FifoReceiver, FifoSender
|
||||||
@ -11,6 +11,7 @@ from utils import ImgQueue
|
|||||||
|
|
||||||
|
|
||||||
class TransmitterTest(unittest.TestCase):
|
class TransmitterTest(unittest.TestCase):
|
||||||
|
@unittest.skip("file receiver thread test pass")
|
||||||
def test_file_receiver(self):
|
def test_file_receiver(self):
|
||||||
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
logging.info('测试文件接收器')
|
logging.info('测试文件接收器')
|
||||||
@ -28,6 +29,25 @@ class TransmitterTest(unittest.TestCase):
|
|||||||
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
||||||
file_receiver.stop()
|
file_receiver.stop()
|
||||||
|
|
||||||
|
# @unittest.skip('skip')
|
||||||
|
def test_file_receiver_subprocess(self):
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
|
||||||
|
logging.info('测试子进程文件接收器')
|
||||||
|
image_queue = multiprocessing.Queue()
|
||||||
|
file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue,
|
||||||
|
speed=0.5, name_pattern=None, run_process=True)
|
||||||
|
virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8)
|
||||||
|
file_receiver.start(need_time=True, virtual_data=virtual_data)
|
||||||
|
for i in range(5):
|
||||||
|
time_record, read_data, virtual_data_rec = image_queue.get()
|
||||||
|
current_time = time.time()
|
||||||
|
logging.info(
|
||||||
|
f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
|
||||||
|
is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2))
|
||||||
|
self.assertTrue(is_equal)
|
||||||
|
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
||||||
|
file_receiver.stop()
|
||||||
|
|
||||||
@unittest.skip('skip')
|
@unittest.skip('skip')
|
||||||
def test_fifo_receiver_sender(self):
|
def test_fifo_receiver_sender(self):
|
||||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||||
@ -41,11 +61,9 @@ class TransmitterTest(unittest.TestCase):
|
|||||||
logging.debug('Start to send virtual data')
|
logging.debug('Start to send virtual data')
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
logging.debug('put data to input queue')
|
logging.debug('put data to input queue')
|
||||||
|
|
||||||
input_queue.put(virtual_data)
|
input_queue.put(virtual_data)
|
||||||
logging.debug('put data to input queue done')
|
logging.debug('put data to input queue done')
|
||||||
virtual_data = image_queue.get()
|
virtual_data = image_queue.get()
|
||||||
|
|
||||||
# logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
|
# logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}')
|
||||||
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
self.assertEqual(virtual_data.shape, (1024, 4096, 3))
|
||||||
|
|
||||||
|
|||||||
59
transmit.py
59
transmit.py
@ -18,18 +18,21 @@ from models import SpecDetector, RgbDetector
|
|||||||
from typing import Any, Union
|
from typing import Any, Union
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
def test_func(*args, **kwargs):
|
||||||
|
print('test_func')
|
||||||
|
print(kwargs)
|
||||||
|
return 'test_func'
|
||||||
|
|
||||||
class Transmitter(object):
|
class Transmitter(object):
|
||||||
_io_lock: Union[Lock, Lock]
|
|
||||||
|
|
||||||
def __init__(self, job_name: str, run_process: bool = False):
|
def __init__(self, job_name: str, run_process: bool = False):
|
||||||
self.output = None
|
self.output = None
|
||||||
self.job_name = job_name
|
self.job_name = job_name
|
||||||
self.run_process = run_process # If true, run process when started else run thread.
|
self.run_process = run_process # If true, run process when started else run thread.
|
||||||
self._thread_stop = threading.Event()
|
self._stop_event = threading.Event()
|
||||||
self._thread_stop.clear()
|
self._stop_event.clear()
|
||||||
self._running_handler = None
|
self._running_handler = None
|
||||||
self._io_lock = multiprocessing.Lock() if run_process else threading.Lock()
|
self._stateful_things = {}
|
||||||
|
|
||||||
def set_source(self, *args, **kwargs):
|
def set_source(self, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
@ -59,6 +62,7 @@ class Transmitter(object):
|
|||||||
if not self.run_process:
|
if not self.run_process:
|
||||||
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
|
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs)
|
||||||
else:
|
else:
|
||||||
|
kwargs.update({'_stateful_things': self._stateful_things})
|
||||||
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
|
self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs)
|
||||||
self._running_handler.start()
|
self._running_handler.start()
|
||||||
|
|
||||||
@ -70,7 +74,7 @@ class Transmitter(object):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if self._running_handler is not None:
|
if self._running_handler is not None:
|
||||||
self._thread_stop.set()
|
self._stop_event.set()
|
||||||
self._running_handler = None
|
self._running_handler = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
@ -83,16 +87,28 @@ class Transmitter(object):
|
|||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
def wrapper(self, *args, **kwargs):
|
def wrapper(self, *args, **kwargs):
|
||||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
|
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
|
||||||
while not self._thread_stop.is_set():
|
if self.run_process:
|
||||||
|
self._stateful_things = kwargs['_stateful_things']
|
||||||
|
while not self._stop_event.is_set():
|
||||||
func(self, *args, **kwargs)
|
func(self, *args, **kwargs)
|
||||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
|
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
|
||||||
self._thread_stop.clear()
|
self._stop_event.clear()
|
||||||
|
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
def job_func(self, *args, **kwargs):
|
def job_func(self, *args, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
self.stop()
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state['_stop_event'] = None
|
||||||
|
state['_stateful_things'] = {}
|
||||||
|
return state
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
self.__dict__.update(state)
|
||||||
|
self._stop_event = threading.Event()
|
||||||
|
|
||||||
|
|
||||||
class BeforeAfterMethods:
|
class BeforeAfterMethods:
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -126,8 +142,8 @@ class BeforeAfterMethods:
|
|||||||
|
|
||||||
class FileReceiver(Transmitter):
|
class FileReceiver(Transmitter):
|
||||||
def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str = None,
|
def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str = None,
|
||||||
job_name: str = 'file_receiver', ):
|
job_name: str = 'file_receiver', run_process: bool = False):
|
||||||
super(FileReceiver, self).__init__(job_name=job_name)
|
super(FileReceiver, self).__init__(job_name=job_name, run_process=run_process)
|
||||||
self.input_dir = input_dir
|
self.input_dir = input_dir
|
||||||
self.send_speed = speed
|
self.send_speed = speed
|
||||||
self.file_names = None
|
self.file_names = None
|
||||||
@ -139,6 +155,7 @@ class FileReceiver(Transmitter):
|
|||||||
self.set_output(output_queue)
|
self.set_output(output_queue)
|
||||||
|
|
||||||
def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
|
def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None):
|
||||||
|
self.stop()
|
||||||
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
|
self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern
|
||||||
file_names = os.listdir(input_dir)
|
file_names = os.listdir(input_dir)
|
||||||
if len(file_names) == 0:
|
if len(file_names) == 0:
|
||||||
@ -148,13 +165,13 @@ class FileReceiver(Transmitter):
|
|||||||
else:
|
else:
|
||||||
file_names = file_names
|
file_names = file_names
|
||||||
|
|
||||||
with self._io_lock:
|
# with self._io_lock:
|
||||||
self.file_names = file_names
|
self.file_names = file_names
|
||||||
self.file_idx = 0
|
self.file_idx = 0
|
||||||
|
|
||||||
def set_output(self, output: ImgQueue):
|
def set_output(self, output: ImgQueue):
|
||||||
self.stop()
|
self.stop()
|
||||||
self.output_queue = output
|
self._stateful_things['output_queue'] = output
|
||||||
|
|
||||||
@Transmitter.job_decorator
|
@Transmitter.job_decorator
|
||||||
def job_func(self, need_time=False, *args, **kwargs):
|
def job_func(self, need_time=False, *args, **kwargs):
|
||||||
@ -162,10 +179,10 @@ class FileReceiver(Transmitter):
|
|||||||
发送文件
|
发送文件
|
||||||
|
|
||||||
:param need_time: 是否需要发送时间戳
|
:param need_time: 是否需要发送时间戳
|
||||||
:param kwargs: virtual_data: 虚拟的数据,用于测试
|
:param kwargs: output_queue: 以进程模式运行时需要, virtual_data: 虚拟的数据,用于测试
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
with self._io_lock:
|
logging.debug(f'{self.job_name} start.')
|
||||||
self.file_idx += 1
|
self.file_idx += 1
|
||||||
if self.file_idx >= len(self.file_names):
|
if self.file_idx >= len(self.file_names):
|
||||||
self.file_idx = 0
|
self.file_idx = 0
|
||||||
@ -179,7 +196,7 @@ class FileReceiver(Transmitter):
|
|||||||
data = (time.time(), data)
|
data = (time.time(), data)
|
||||||
if 'virtual_data' in kwargs:
|
if 'virtual_data' in kwargs:
|
||||||
data = (*data, kwargs['virtual_data'])
|
data = (*data, kwargs['virtual_data'])
|
||||||
self.output_queue.put(data)
|
self._stateful_things['output_queue'].put(data)
|
||||||
time.sleep(self.send_speed)
|
time.sleep(self.send_speed)
|
||||||
logging.info(f'sleep {self.send_speed}s ...')
|
logging.info(f'sleep {self.send_speed}s ...')
|
||||||
|
|
||||||
@ -197,14 +214,12 @@ class FifoReceiver(Transmitter):
|
|||||||
|
|
||||||
def set_source(self, fifo_path: str):
|
def set_source(self, fifo_path: str):
|
||||||
self.stop()
|
self.stop()
|
||||||
with self._io_lock:
|
|
||||||
if not os.access(fifo_path, os.F_OK):
|
if not os.access(fifo_path, os.F_OK):
|
||||||
os.mkfifo(fifo_path, 0o777)
|
os.mkfifo(fifo_path, 0o777)
|
||||||
self._input_fifo_path = fifo_path
|
self._input_fifo_path = fifo_path
|
||||||
|
|
||||||
def set_output(self, output: ImgQueue):
|
def set_output(self, output: ImgQueue):
|
||||||
self.stop()
|
self.stop()
|
||||||
with self._io_lock:
|
|
||||||
self._output_queue = output
|
self._output_queue = output
|
||||||
|
|
||||||
@Transmitter.job_decorator
|
@Transmitter.job_decorator
|
||||||
@ -222,6 +237,12 @@ class FifoReceiver(Transmitter):
|
|||||||
self._output_queue.put(data)
|
self._output_queue.put(data)
|
||||||
os.close(input_fifo)
|
os.close(input_fifo)
|
||||||
|
|
||||||
|
def __getstate__(self):
|
||||||
|
state = self.__dict__.copy()
|
||||||
|
state['_input_fifo_path'] = None
|
||||||
|
state['_output_queue'] = None
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
class FifoSender(Transmitter):
|
class FifoSender(Transmitter):
|
||||||
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
|
def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'):
|
||||||
@ -308,7 +329,7 @@ class CmdImgSplitMidware(Transmitter):
|
|||||||
# 如果是图片,交给预测的人
|
# 如果是图片,交给预测的人
|
||||||
for name, subscriber in self._subscribers.items():
|
for name, subscriber in self._subscribers.items():
|
||||||
item = (spec_data, rgb_data)
|
item = (spec_data, rgb_data)
|
||||||
subscriber.safe_put(item)
|
subscriber.fifo_put(item)
|
||||||
else:
|
else:
|
||||||
# 否则程序出现毁灭性问题,立刻崩
|
# 否则程序出现毁灭性问题,立刻崩
|
||||||
logging.critical('两个相机传回的数据没有对上')
|
logging.critical('两个相机传回的数据没有对上')
|
||||||
@ -375,7 +396,7 @@ class ThreadDetector(Transmitter):
|
|||||||
if not self._input_queue.empty():
|
if not self._input_queue.empty():
|
||||||
spec, rgb = self._input_queue.get()
|
spec, rgb = self._input_queue.get()
|
||||||
mask = self.predict(spec, rgb)
|
mask = self.predict(spec, rgb)
|
||||||
self._output_queue.safe_put(mask)
|
self._output_queue.fifo_put(mask)
|
||||||
self._thread_exit.clear()
|
self._thread_exit.clear()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user