diff --git a/README.md b/README.md index 18445a1..7d1ce7a 100644 --- a/README.md +++ b/README.md @@ -388,3 +388,48 @@ python main_test.py /path/to/convert -convert_dir /output/dir -s ![image-20220808123044267](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/image-20220808123044267.png) +# 多线程读取与多进程预测 + +## 多线程与进程总体结构图 + +![MultiThread](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/MultiThread.png) + +## 多线程读取 + +为了能够避免IO的等待,我们使用了开销相对较小的线程来实现多线程的数据读取。 + +因为很多时候我们需要读图测试,所以我们写了一个FileReceiver类,用法大概就像测试文件里这样: + +```python + def test_file_receiver(self): + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + image_queue = ImgQueue() + file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue, + speed=0.5, name_pattern=None) + virtual_data = np.zeros((1024, 4096, 3), dtype=np.uint8) + file_receiver.start(need_time=True, virtual_data=virtual_data) + for i in range(10): + data = image_queue.get() + time_record = data[0] + logging.info(f'Spent {(time.time() - time_record) * 1000:.2f}ms to get image with shape {data[-1].shape}') + self.assertEqual(data[-1].shape, (1024, 4096, 3)) + file_receiver.stop() +``` + +测试结果如下所示: + +> 2022-08-17 23:46:09,742 - root - INFO - rgb img receive thread start. +> 2022-08-17 23:46:09,754 - root - INFO - Spent 0.04ms to get image with shape (1024, 4096, 3) +> 2022-08-17 23:46:10,259 - root - INFO - sleep 0.5s ... +> 2022-08-17 23:46:10,276 - root - INFO - Spent 0.92ms to get image with shape (1024, 4096, 3) +> 2022-08-17 23:46:10,780 - root - INFO - sleep 0.5s ... +> 2022-08-17 23:46:10,789 - root - INFO - Spent 0.79ms to get image with shape (1024, 4096, 3) +> 2022-08-17 23:46:11,293 - root - INFO - sleep 0.5s ... +> 2022-08-17 23:46:11,301 - root - INFO - Spent 0.81ms to get image with shape (1024, 4096, 3) +> 2022-08-17 23:46:11,802 - root - INFO - sleep 0.5s ... +> 2022-08-17 23:46:11,810 - root - INFO - Spent 0.77ms to get image with shape (1024, 4096, 3) +> 2022-08-17 23:46:12,314 - root - INFO - sleep 0.5s ... +> 2022-08-17 23:46:12,315 - root - INFO - rgb img receive thread stop. + +这里我们得到了一个惊人的数据传递速度,只花了将近1ms,这速度看着很不错哦。接下来我们把这东西变成多进程。 + diff --git a/tests/test_transmit.py b/tests/test_transmit.py new file mode 100644 index 0000000..0d56f4d --- /dev/null +++ b/tests/test_transmit.py @@ -0,0 +1,54 @@ +import logging +import time +import unittest + +import numpy as np + +import transmit +from config import Config +from transmit import FileReceiver, FifoReceiver, FifoSender +from utils import ImgQueue + + +class TransmitterTest(unittest.TestCase): + def test_file_receiver(self): + logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') + logging.info('测试文件接收器') + image_queue = ImgQueue() + file_receiver = FileReceiver(job_name='rgb img receive', input_dir='../data', output_queue=image_queue, + speed=0.5, name_pattern=None) + virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8) + file_receiver.start(need_time=True, virtual_data=virtual_data) + for i in range(5): + time_record, read_data, virtual_data_rec = image_queue.get() + current_time = time.time() + logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}') + is_equal = np.all(virtual_data_rec == virtual_data, axis=(0, 1, 2)) + self.assertTrue(is_equal) + self.assertEqual(virtual_data.shape, (1024, 4096, 3)) + file_receiver.stop() + + @unittest.skip('skip') + def test_fifo_receiver_sender(self): + total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 + image_queue, input_queue = ImgQueue(), ImgQueue() + fifo_receiver = FifoReceiver(job_name='fifo img receive', fifo_path='/tmp/dkimg.fifo', output=image_queue, + read_max_num=total_rgb) + fifo_sender = FifoSender(fifo_path='/tmp/dkimg.fifo', source=input_queue, job_name='fifo img send') + virtual_data = np.random.randint(0, 255, (1024, 4096, 3), dtype=np.uint8) + fifo_sender.start(preprocess=transmit.BeforeAfterMethods.mask_preprocess) + fifo_receiver.start() + logging.debug('Start to send virtual data') + for i in range(5): + logging.debug('put data to input queue') + + input_queue.put(virtual_data) + logging.debug('put data to input queue done') + virtual_data = image_queue.get() + + # logging.info(f'Spent {(current_time - time_record) * 1000:.2f}ms to get image with shape {virtual_data.shape}') + self.assertEqual(virtual_data.shape, (1024, 4096, 3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/transmit.py b/transmit.py index d7b7e89..38011e6 100644 --- a/transmit.py +++ b/transmit.py @@ -1,6 +1,7 @@ import multiprocessing import os import threading +import typing from multiprocessing import Process, Queue import time from multiprocessing.synchronize import Lock @@ -16,17 +17,15 @@ from config import Config from models import SpecDetector, RgbDetector from typing import Any, Union import logging -logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', - level=logging.WARNING) class Transmitter(object): _io_lock: Union[Lock, Lock] - def __init__(self, job_name:str, run_process:bool = False): + def __init__(self, job_name: str, run_process: bool = False): self.output = None self.job_name = job_name - self.run_process = run_process # If true, run process when started else run thread. + self.run_process = run_process # If true, run process when started else run thread. self._thread_stop = threading.Event() self._thread_stop.clear() self._running_handler = None @@ -56,11 +55,11 @@ class Transmitter(object): :param kwargs: :return: """ - name = kwargs.get('name', default='base thread') + name = kwargs.get('name', 'base thread') if not self.run_process: - self._running_handler = threading.Thread(target=self.job_func, name=name, args=args) + self._running_handler = threading.Thread(target=self.job_func, name=name, args=args, kwargs=kwargs) else: - self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True) + self._running_handler = Process(target=self.job_func, name=name, daemon=True, args=args, kwargs=kwargs) self._running_handler.start() def stop(self, *args, **kwargs): @@ -81,13 +80,14 @@ class Transmitter(object): @staticmethod def job_decorator(func): - functools.wraps(func) + @functools.wraps(func) def wrapper(self, *args, **kwargs): logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.') while not self._thread_stop.is_set(): - self.job_func(*args, **kwargs) + func(self, *args, **kwargs) logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.') - self._need_stop.clear() + self._thread_stop.clear() + return wrapper def job_func(self, *args, **kwargs): @@ -107,7 +107,7 @@ class BeforeAfterMethods: logging.info(f"Get Spec threshold: {threshold}") return threshold else: - spec_img = np.frombuffer(data, dtype=np.float32).\ + spec_img = np.frombuffer(data, dtype=np.float32). \ reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) logging.info(f"Get SPEC image with size {spec_img.shape}") return spec_img @@ -125,8 +125,9 @@ class BeforeAfterMethods: class FileReceiver(Transmitter): - def __init__(self, job_name:str, input_dir: str, output_queue:ImgQueue, speed: int=3, name_pattern=None): - super(FileReceiver, self).__init__(job_name=job_name, run_process=False) + def __init__(self, input_dir: str, output_queue: ImgQueue, speed: float = 3.0, name_pattern: str=None, + job_name: str = 'file_receiver', ): + super(FileReceiver, self).__init__(job_name=job_name) self.input_dir = input_dir self.send_speed = speed self.file_names = None @@ -137,7 +138,7 @@ class FileReceiver(Transmitter): self.set_source(input_dir, name_pattern) self.set_output(output_queue) - def set_source(self, input_dir:str, name_pattern:str=None, preprocess_method:callable=None): + def set_source(self, input_dir: str, name_pattern: str = None, preprocess_method: callable = None): self.name_pattern = name_pattern if name_pattern is not None else self.name_pattern file_names = os.listdir(input_dir) if len(file_names) == 0: @@ -152,11 +153,18 @@ class FileReceiver(Transmitter): self.file_idx = 0 def set_output(self, output: ImgQueue): - with self._io_lock: - self.output_queue = output + self.stop() + self.output_queue = output @Transmitter.job_decorator - def job_func(self, *args, **kwargs): + def job_func(self, need_time=False, *args, **kwargs): + """ + 发送文件 + + :param need_time: 是否需要发送时间戳 + :param kwargs: virtual_data: 虚拟的数据,用于测试 + :return: + """ with self._io_lock: self.file_idx += 1 if self.file_idx >= len(self.file_names): @@ -167,29 +175,37 @@ class FileReceiver(Transmitter): data = f.read() if self.preprocess_method is not None: data = self.preprocess_method(data) + if need_time: + data = (time.time(), data) + if 'virtual_data' in kwargs: + data = (*data, kwargs['virtual_data']) self.output_queue.put(data) - + time.sleep(self.send_speed) + logging.info(f'sleep {self.send_speed}s ...') class FifoReceiver(Transmitter): - def __init__(self, job_name:str, fifo_path: str, output: ImgQueue, - read_max_num: int, msg_queue=None): + def __init__(self, fifo_path: str, output: ImgQueue, + read_max_num: int, job_name: str = 'fifo_receiver'): super().__init__(job_name=job_name) self._input_fifo_path = None self._output_queue = None - self._msg_queue = msg_queue self._max_len = read_max_num self.set_source(fifo_path) self.set_output(output) def set_source(self, fifo_path: str): - if not os.access(fifo_path, os.F_OK): - os.mkfifo(fifo_path, 0o777) - self._input_fifo_path = fifo_path + self.stop() + with self._io_lock: + if not os.access(fifo_path, os.F_OK): + os.mkfifo(fifo_path, 0o777) + self._input_fifo_path = fifo_path def set_output(self, output: ImgQueue): - self._output_queue = output + self.stop() + with self._io_lock: + self._output_queue = output @Transmitter.job_decorator def job_func(self, post_process_func=None): @@ -203,34 +219,33 @@ class FifoReceiver(Transmitter): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - self._output_queue.safe_put(data) + self._output_queue.put(data) os.close(input_fifo) - class FifoSender(Transmitter): - def __init__(self, output_fifo_path: str, source: ImgQueue): - super().__init__() + def __init__(self, fifo_path: str, source: ImgQueue, job_name: str = 'fifo_sender'): + super().__init__(job_name=job_name) self._input_source = None self._output_fifo_path = None self.set_source(source) - self.set_output(output_fifo_path) - self._need_stop = threading.Event() - self._need_stop.clear() - self._running_thread = None + self.set_output(fifo_path) def set_source(self, source: ImgQueue): - self._input_source = source + self.stop() + with self._io_lock: + self._input_source = source def set_output(self, output_fifo_path: str): - if not os.access(output_fifo_path, os.F_OK): - os.mkfifo(output_fifo_path, 0o777) - self._output_fifo_path = output_fifo_path + self.stop() + with self._io_lock: + if not os.access(output_fifo_path, os.F_OK): + os.mkfifo(output_fifo_path, 0o777) + self._output_fifo_path = output_fifo_path - - def job_func(self, pre_process, *args, **kwargs): + def job_func(self, pre_process=None, *args, **kwargs): """ - 接收线程 + 发送线程 :param pre_process: :return: @@ -240,21 +255,18 @@ class FifoSender(Transmitter): data = self._input_source.get() if pre_process is not None: data = pre_process(data) + logging.debug(f'put data to fifo {self._output_fifo_path}') output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) os.write(output_fifo, data) os.close(output_fifo) - - - def __del__(self): - self.stop() - if self._running_thread is not None: - self._running_thread.join() + logging.debug(f'put data to fifo {self._output_fifo_path} done') class CmdImgSplitMidware(Transmitter): """ 用于控制命令和图像的中间件 """ + def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue): super().__init__() self._rgb_queue = None @@ -308,6 +320,7 @@ class ImageSaver(Transmitter): """ 进行图片存储的中间件 """ + def set_source(self, *args, **kwargs): pass @@ -354,7 +367,7 @@ class ThreadDetector(Transmitter): mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t4 = time.time() logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') - logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') + logging.info(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms') return mask_result def _predict_server(self): @@ -408,7 +421,7 @@ class ProcessDetector(Transmitter): masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks] t4 = time.time() logging.debug(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') - logging.debug(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') + logging.debug(f'Detector finish predict within {(time.time() - t1) * 1000:.2f}ms') return masks def _predict_server(self): @@ -422,10 +435,14 @@ class ProcessDetector(Transmitter): class SplitMidware(Transmitter): def set_source(self, mask_source: ImgQueue): - + pass def start(self, *args, **kwargs): pass def stop(self, *args, **kwargs): - pass \ No newline at end of file + pass + + +if __name__ == '__main__': + pass diff --git a/utils.py b/utils.py index 6bd1bcb..ea0ef2f 100755 --- a/utils.py +++ b/utils.py @@ -186,7 +186,8 @@ def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray: row_valve_count = np.sum(mask, axis=1) if np.any(row_valve_count > max_valve_num): over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel() - logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}') + logging.warning(f'发现有{len(over_rows_idx)}行的单行喷阀数量超过限制,原喷阀数量为{row_valve_count[over_rows_idx]},' + f'已全部限制到最大许可值{max_valve_num}') over_rows = mask[over_rows_idx, :] # a simple function to get lucky valves when too many valves appear in the same line