From 597284c3e1ef3f9062dee5600fd9254a713ce838 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Wed, 27 Jul 2022 13:40:35 +0800 Subject: [PATCH 1/5] remove the saving function --- main.py | 107 +------------------------------------------------------- 1 file changed, 1 insertion(+), 106 deletions(-) diff --git a/main.py b/main.py index 32602f9..7ac3b21 100755 --- a/main.py +++ b/main.py @@ -8,9 +8,6 @@ from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector import cv2 -SAVE_IMG, SAVE_NUM = False, 30 - - def main(): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, @@ -23,8 +20,6 @@ def main(): os.mkfifo(mask_fifo_path, 0o777) if not os.access(rgb_fifo_path, os.F_OK): os.mkfifo(rgb_fifo_path, 0o777) - if SAVE_IMG: - img_list = [] while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) @@ -55,11 +50,6 @@ def main(): img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ .transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - if SAVE_IMG: - SAVE_NUM -= 1 - img_list.append((rgb_data, img_data)) - if SAVE_NUM <= 0: - break # 光谱识别 mask = spec_detector.predict(img_data) # rgb识别 @@ -71,108 +61,13 @@ def main(): # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() - print(f'rgb len = {len(rgb_data)}') # 写出 fd_mask = os.open(mask_fifo_path, os.O_WRONLY) os.write(fd_mask, mask_result.tobytes()) os.close(fd_mask) t3 = time.time() - print(f'total time is:{t3 - t1}') - for i, img in enumerate(img_list): - print(f"writing img {i}...") - cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) - np.save(f'./{i}.npy', img[1]) - i += 1 - - - -def save_main(): - threshold = Config.spec_size_threshold - rgb_threshold = Config.rgb_size_threshold - manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) - tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) - background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) - total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 - total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 - if not os.access(img_fifo_path, os.F_OK): - os.mkfifo(img_fifo_path, 0o777) - if not os.access(mask_fifo_path, os.F_OK): - os.mkfifo(mask_fifo_path, 0o777) - if not os.access(rgb_fifo_path, os.F_OK): - os.mkfifo(rgb_fifo_path, 0o777) - img_list = [] - idx = 0 - while idx <= 30: - idx += 1 - fd_img = os.open(img_fifo_path, os.O_RDONLY) - fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) - data = os.read(fd_img, total_len) - - # 读取(开启一个管道) - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get threshold: ", threshold) - continue - else: - data_total = data - rgb_data = os.read(fd_rgb, total_rgb) - if len(rgb_data) < 3: - rgb_threshold = int(float(rgb_data)) - print(rgb_threshold) - continue - else: - rgb_data_total = rgb_data - os.close(fd_img) - os.close(fd_rgb) - - # 识别 - t1 = time.time() - img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \ - transpose(0, 2, 1) - rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - img_list.append((rgb_data.copy(), img_data.copy())) - - pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) - blk_predict_result = manual_tree.blk_predict(data=img_data) - rgb_data = tobacco_detector.pretreatment(rgb_data) - # print(rgb_data.shape) - rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low, - threshold_high=Config.threshold_high) | - tobacco_detector.swell(tobacco_detector.predict(rgb_data, - threshold_low=Config.threshold_low, - threshold_high=Config.threshold_high))) - mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ - .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ - .sum(axis=1) - mask_rgb[mask_rgb <= rgb_threshold] = 0 - mask_rgb[mask_rgb > rgb_threshold] = 1 - mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) - mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ - .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ - .sum(axis=1) - mask[mask <= threshold] = 0 - mask[mask > threshold] = 1 - mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = mask_rgb - mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) - t2 = time.time() - print(f'rgb len = {len(rgb_data)}') - - # 写出 - fd_mask = os.open(mask_fifo_path, os.O_WRONLY) - os.write(fd_mask, mask_result.tobytes()) - os.close(fd_mask) - t3 = time.time() - print(f'total time is:{t3 - t1}') - i = 0 - print("Stop Serving") - for img in img_list: - print(f"writing img {i}...") - cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) - np.save(f'./{i}.npy', img[1]) - i += 1 - print("save success") + print(f'total time is:{t3 - t1}\n') if __name__ == '__main__': From 32010f35f00fec1e75d8f420979426dcc8792371 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Wed, 27 Jul 2022 17:44:21 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E8=AF=BB?= =?UTF-8?q?=E5=8F=96c=E8=AF=AD=E8=A8=80buff=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config.py | 2 +- main.py | 23 +++++++++++++++++++++++ models.py | 2 +- 3 files changed, 25 insertions(+), 2 deletions(-) diff --git a/config.py b/config.py index 483c629..2e5266c 100755 --- a/config.py +++ b/config.py @@ -20,7 +20,7 @@ class Config: # 光谱模型参数 blk_size = 4 pixel_model_path = r"./models/dt.p" - blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" + blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model" spec_size_threshold = 3 # rgb模型参数 diff --git a/main.py b/main.py index 7ac3b21..e368ea7 100755 --- a/main.py +++ b/main.py @@ -1,5 +1,7 @@ import os import time + +import matplotlib.pyplot as plt import numpy as np import scipy.io @@ -70,6 +72,26 @@ def main(): print(f'total time is:{t3 - t1}\n') +def read_c_captures(buffer_path): + buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] + for buffer_name in buffer_names: + with open(os.path.join(buffer_path, buffer_name), 'rb') as f: + data = f.read() + img = np.frombuffer(data, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ + .transpose(0, 2, 1) + mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') + with open(os.path.join(buffer_path, mask_name), 'rb') as f: + data = f.read() + mask = np.frombuffer(data, dtype=np.uint8).reshape((256, 1024, -1)) + # mask = cv2.resize(mask, (1024, 256)) + fig, axs = plt.subplots(2, 1) + axs[0].imshow(img[..., [21, 3, 0]]) + axs[0].set_title(buffer_name) + axs[1].imshow(mask) + axs[1].set_title(mask_name) + plt.show() + + if __name__ == '__main__': # 相关参数 img_fifo_path = "/tmp/dkimg.fifo" @@ -78,3 +100,4 @@ if __name__ == '__main__': # 主函数 main() + # read_c_captures('/home/lzy/2022.7.20/tobacco_v1_0/') diff --git a/models.py b/models.py index 95d7f4a..181651d 100755 --- a/models.py +++ b/models.py @@ -415,7 +415,7 @@ class SpecDetector(Detector): # 烟梗mask中将背景赋值为0,将烟梗赋值为2 yellow_things[yellow_things] = tobacco yellow_things = yellow_things + 0 - # yellow_things = binary_dilation(yellow_things, iterations=iteration) + yellow_things = binary_dilation(yellow_things, iterations=iteration) yellow_things = yellow_things + 0 yellow_things[yellow_things == 1] = 2 From 1bd6e7bbe57ae7a20fa86102ae669a45637a333c Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Thu, 28 Jul 2022 11:03:21 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E8=AF=BB?= =?UTF-8?q?=E5=8F=96rawfile=20de=20function?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index e368ea7..ee51cf4 100755 --- a/main.py +++ b/main.py @@ -72,20 +72,31 @@ def main(): print(f'total time is:{t3 - t1}\n') -def read_c_captures(buffer_path): - buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] +def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): + if os.path.isdir(buffer_path): + buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] + else: + buffer_names = [buffer_path, ] for buffer_name in buffer_names: with open(os.path.join(buffer_path, buffer_name), 'rb') as f: data = f.read() - img = np.frombuffer(data, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ + img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \ .transpose(0, 2, 1) - mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') - with open(os.path.join(buffer_path, mask_name), 'rb') as f: - data = f.read() - mask = np.frombuffer(data, dtype=np.uint8).reshape((256, 1024, -1)) + if selected_bands is not None: + img = img[..., selected_bands] + if img.shape[0] == 1: + img = img[0, ...] + if not no_mask: + mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') + with open(os.path.join(buffer_path, mask_name), 'rb') as f: + data = f.read() + mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1)) + else: + mask_name = "no mask" + mask = np.zeros_like(img) # mask = cv2.resize(mask, (1024, 256)) fig, axs = plt.subplots(2, 1) - axs[0].imshow(img[..., [21, 3, 0]]) + axs[0].matshow(img) axs[0].set_title(buffer_name) axs[1].imshow(mask) axs[1].set_title(mask_name) @@ -100,4 +111,5 @@ if __name__ == '__main__': # 主函数 main() - # read_c_captures('/home/lzy/2022.7.20/tobacco_v1_0/') + # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024, + # selected_bands=[380, 300, 200]) From a0d14940e850472d5229a12b2765fd97f012a009 Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Thu, 28 Jul 2022 12:39:54 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=E7=89=88?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- efficient_ui.py | 51 +++++++++++ main.py | 14 ++-- models.py | 5 +- transmit.py | 218 ++++++++++++++++++++++++++++++++++++++++++------ utils.py | 29 +++++++ 5 files changed, 284 insertions(+), 33 deletions(-) create mode 100644 efficient_ui.py diff --git a/efficient_ui.py b/efficient_ui.py new file mode 100644 index 0000000..dd627b6 --- /dev/null +++ b/efficient_ui.py @@ -0,0 +1,51 @@ +import cv2 +import numpy as np + +import transmit +from config import Config + +from utils import ImgQueue as Queue + + +class EfficientUI(object): + def __init__(self): + # 相关参数 + img_fifo_path = "/tmp/dkimg.fifo" + mask_fifo_path = "/tmp/dkmask.fifo" + rgb_fifo_path = "/tmp/dkrgb.fifo" + # 创建队列用于链接各个线程 + rgb_img_queue, spec_img_queue = Queue(), Queue() + detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue() + mask_queue = Queue() + # 两个接收者,接收光谱和rgb图像 + spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 + rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 + spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len) + rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len) + # 指令执行与图像流向控制 + subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue} + cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue, + subscribers=subscribers) + # 探测器 + detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue) + # 发送 + sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) + # 启动所有线程 + spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread') + rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread') + cmd_img_controller.start(name='control_thread') + detector.start(name='detector_thread') + sender.start(name='sender_thread') + + def start(self): + # 启动图形化 + while True: + cv2.imshow('image_show', mat=np.ones((256, 1024))) + key_code = cv2.waitKey(30) + if key_code == ord("s"): + pass + + +if __name__ == '__main__': + app = EfficientUI() + app.start() diff --git a/main.py b/main.py index ee51cf4..9f4a8f6 100755 --- a/main.py +++ b/main.py @@ -1,12 +1,15 @@ import os import time +from queue import Queue -import matplotlib.pyplot as plt import numpy as np -import scipy.io +from matplotlib import pyplot as plt + +import models +import transmit from config import Config -from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector +from models import RgbDetector, SpecDetector import cv2 @@ -56,20 +59,19 @@ def main(): mask = spec_detector.predict(img_data) # rgb识别 mask_rgb = rgb_detector.predict(rgb_data) - # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() + print(f'rgb len = {len(rgb_data)}') # 写出 fd_mask = os.open(mask_fifo_path, os.O_WRONLY) os.write(fd_mask, mask_result.tobytes()) os.close(fd_mask) t3 = time.time() - print(f'total time is:{t3 - t1}\n') + print(f'total time is:{t3 - t1}') def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): diff --git a/models.py b/models.py index 181651d..f503b96 100755 --- a/models.py +++ b/models.py @@ -5,10 +5,12 @@ # @Software:PyCharm、 import datetime import pickle +from queue import Queue import cv2 import numpy as np import scipy.io +import threading from scipy.ndimage import binary_dilation from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import classification_report @@ -17,7 +19,8 @@ from sklearn.model_selection import train_test_split from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = True + +deploy = False if not deploy: print("Training env") from tqdm import tqdm diff --git a/transmit.py b/transmit.py index b990718..c0fed63 100644 --- a/transmit.py +++ b/transmit.py @@ -1,11 +1,13 @@ import os import threading -from queue import Queue +from utils import ImgQueue as Queue import numpy as np from config import Config +from models import SpecDetector, RgbDetector +import typing -class Receiver(object): +class Transmitter(object): def __init__(self): self.output = None @@ -45,17 +47,42 @@ class Receiver(object): raise NotImplementedError -class FifoReceiver(Receiver): - def __init__(self, fifo_path: str, output: Queue, read_max_num: int): +class PostProcessMethods: + @classmethod + def spec_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + print("[INFO] Get Spec threshold: ", threshold) + return threshold + else: + spec_img = np.frombuffer(data, dtype=np.float32).\ + reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) + return spec_img + + @classmethod + def rgb_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + print("[INFO] Get RGB threshold: ", threshold) + return threshold + else: + rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + return rgb_img + + +class FifoReceiver(Transmitter): + def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None): super().__init__() self._input_fifo_path = None self._output_queue = None + self._msg_queue = msg_queue self._max_len = read_max_num self.set_source(fifo_path) self.set_output(output) self._need_stop = threading.Event() self._need_stop.clear() + self._running_thread = None def set_source(self, fifo_path: str): if not os.access(fifo_path, os.F_OK): @@ -66,9 +93,9 @@ class FifoReceiver(Receiver): self._output_queue = output def start(self, post_process_func=None, name='fifo_receiver'): - x = threading.Thread(target=self._receive_thread_func, - name=name, args=(post_process_func, )) - x.start() + self._running_thread = threading.Thread(target=self._receive_thread_func, + name=name, args=(post_process_func, )) + self._running_thread.start() def stop(self): self._need_stop.set() @@ -85,27 +112,166 @@ class FifoReceiver(Receiver): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - self._output_queue.put(data) + if not self._output_queue.safe_put(data): + if self._msg_queue is not None: + self._msg_queue.put('Fifo Receiver的接收者未来得及接收') os.close(input_fifo) self._need_stop.clear() - @staticmethod - def spec_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get Spec threshold: ", threshold) - return threshold - else: - spec_img = np.frombuffer(data, dtype=np.float32).\ - reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) - return spec_img + +class FifoSender(Transmitter): + def __init__(self, output_fifo_path: str, source: Queue): + super().__init__() + self._input_source = None + self._output_fifo_path = None + self.set_source(source) + self.set_output(output_fifo_path) + self._need_stop = threading.Event() + self._need_stop.clear() + self._running_thread = None + + def set_source(self, source: Queue): + self._input_source = source + + def set_output(self, output_fifo_path: str): + if not os.access(output_fifo_path, os.F_OK): + os.mkfifo(output_fifo_path, 0o777) + self._output_fifo_path = output_fifo_path + + def start(self, pre_process=None, name='fifo_receiver'): + self._running_thread = threading.Thread(target=self._send_thread_func, name=name, + args=(pre_process, )) + self._running_thread.start() + + def stop(self): + self._need_stop.set() + + def _send_thread_func(self, pre_process=None): + """ + 接收线程 + + :param pre_process: + :return: + """ + while not self._need_stop.is_set(): + output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) + if self._input_source.empty(): + continue + data = self._input_source.get() + if pre_process is not None: + data = pre_process(data) + os.write(output_fifo, data) + os.close(output_fifo) + self._need_stop.clear() @staticmethod - def rgb_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get RGB threshold: ", threshold) - return threshold - else: - rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - return rgb_img + def mask_preprocess(mask: np.ndarray): + return mask.tobytes() + + def __del__(self): + self.stop() + if self._running_thread is not None: + self._running_thread.join() + + +class CmdImgSplitMidware(Transmitter): + """ + 用于控制命令和图像的中间件 + """ + def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue): + super().__init__() + self._rgb_queue = None + self._spec_queue = None + self._subscribers = None + self._server_thread = None + self.set_source(rgb_queue, spec_queue) + self.set_output(subscribers) + self.thread_stop = threading.Event() + + def set_source(self, rgb_queue: Queue, spec_queue: Queue): + self._rgb_queue = rgb_queue + self._spec_queue = spec_queue + + def set_output(self, output: typing.Dict[str, Queue]): + self._subscribers = output + + def start(self, name='CMD_thread'): + self._server_thread = threading.Thread(target=self._cmd_control_service) + self._server_thread.start() + + def stop(self): + self.thread_stop.set() + + def _cmd_control_service(self): + while not self.thread_stop.is_set(): + # 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时 + if self._rgb_queue.empty() or self._spec_queue.empty(): + continue + rgb_data = self._rgb_queue.get() + spec_data = self._spec_queue.get() + if isinstance(rgb_data, int) and isinstance(spec_data, int): + # 看是不是命令需要执行如果是命令,就执行 + Config.rgb_size_threshold = rgb_data + Config.spec_size_threshold = spec_data + continue + elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): + # 如果是图片,交给预测的人 + for _, subscriber in self._subscribers.items(): + subscriber.put((spec_data, rgb_data)) + else: + # 否则程序出现毁灭性问题,立刻崩 + raise Exception("两个相机传回的数据没有对上") + self.thread_stop.clear() + + +class ImageSaver(Transmitter): + """ + 进行图片存储的中间件 + """ + def set_source(self, *args, **kwargs): + pass + + def start(self, *args, **kwargs): + pass + + def stop(self, *args, **kwargs): + pass + + +class ThreadDetector(Transmitter): + def __init__(self, input_queue: Queue, output_queue: Queue): + super().__init__() + self._input_queue, self._output_queue = input_queue, output_queue + self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, + pixel_model_path=Config.pixel_model_path) + self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, + background_model_path=Config.rgb_background_model_path) + self._predict_thread = None + self._thread_exit = threading.Event() + + def set_source(self, img_queue: Queue): + self._input_queue = img_queue + + def stop(self, *args, **kwargs): + self._thread_exit.set() + + def start(self, name='predict_thread'): + self._predict_thread = threading.Thread(target=self._predict_server, name=name) + self._predict_thread.start() + + def predict(self, spec: np.ndarray, rgb: np.ndarray): + mask = self._spec_detector.predict(spec) + # rgb识别 + mask_rgb = self._rgb_detector.predict(rgb) + # 结果合并 + mask_result = (mask | mask_rgb).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + return mask_result + + def _predict_server(self): + while not self._thread_exit.is_set(): + if not self._input_queue.empty(): + spec, rgb = self._input_queue.get() + mask = self.predict(spec, rgb) + self._output_queue.put(mask) + self._thread_exit.clear() diff --git a/utils.py b/utils.py index 25607e1..f004242 100755 --- a/utils.py +++ b/utils.py @@ -5,6 +5,7 @@ # @Software:PyCharm import glob import os +from queue import Queue import cv2 import numpy as np @@ -26,6 +27,34 @@ class MergeDict(dict): return self +class ImgQueue(Queue): + """ + A custom queue subclass that provides a :meth:`clear` method. + """ + + def clear(self): + """ + Clears all items from the queue. + """ + + with self.mutex: + unfinished = self.unfinished_tasks - len(self.queue) + if unfinished <= 0: + if unfinished < 0: + raise ValueError('task_done() called too many times') + self.all_tasks_done.notify_all() + self.unfinished_tasks = unfinished + self.queue.clear() + self.not_full.notify_all() + + def safe_put(self, *args, **kwargs): + if self.full(): + _ = self.get() + return False + self.put(*args, **kwargs) + return True + + def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict: """ 根据dataset_dir下的文件创建数据集 From d697855abc79f7ce12f76d8655a9ba67c9576c41 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Thu, 28 Jul 2022 13:47:32 +0800 Subject: [PATCH 5/5] Slow MultiThread Version --- efficient_ui.py | 8 ++++---- models.py | 2 +- transmit.py | 46 +++++++++++++++++++++++++++++++--------------- utils.py | 4 ++-- 4 files changed, 38 insertions(+), 22 deletions(-) diff --git a/efficient_ui.py b/efficient_ui.py index dd627b6..0c74260 100644 --- a/efficient_ui.py +++ b/efficient_ui.py @@ -15,7 +15,7 @@ class EfficientUI(object): rgb_fifo_path = "/tmp/dkrgb.fifo" # 创建队列用于链接各个线程 rgb_img_queue, spec_img_queue = Queue(), Queue() - detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue() + detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue() mask_queue = Queue() # 两个接收者,接收光谱和rgb图像 spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 @@ -31,11 +31,11 @@ class EfficientUI(object): # 发送 sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) # 启动所有线程 - spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread') - rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread') + spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread') + rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread') cmd_img_controller.start(name='control_thread') detector.start(name='detector_thread') - sender.start(name='sender_thread') + sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread') def start(self): # 启动图形化 diff --git a/models.py b/models.py index f503b96..abb1a74 100755 --- a/models.py +++ b/models.py @@ -20,7 +20,7 @@ from config import Config from utils import lab_scatter, read_labeled_img, size_threshold -deploy = False +deploy = True if not deploy: print("Training env") from tqdm import tqdm diff --git a/transmit.py b/transmit.py index c0fed63..2632a12 100644 --- a/transmit.py +++ b/transmit.py @@ -1,10 +1,15 @@ import os import threading +import time + from utils import ImgQueue as Queue import numpy as np from config import Config from models import SpecDetector, RgbDetector import typing +import logging +logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', + level=logging.DEBUG) class Transmitter(object): @@ -47,26 +52,33 @@ class Transmitter(object): raise NotImplementedError -class PostProcessMethods: +class BeforeAfterMethods: + @classmethod + def mask_preprocess(cls, mask: np.ndarray): + logging.info(f"Send mask with size {mask.shape}") + return mask.tobytes() + @classmethod def spec_data_post_process(cls, data): if len(data) < 3: threshold = int(float(data)) - print("[INFO] Get Spec threshold: ", threshold) + logging.info(f"Get Spec threshold: {threshold}") return threshold else: spec_img = np.frombuffer(data, dtype=np.float32).\ reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) + logging.info(f"Get SPEC image with size {spec_img.shape}") return spec_img @classmethod def rgb_data_post_process(cls, data): if len(data) < 3: threshold = int(float(data)) - print("[INFO] Get RGB threshold: ", threshold) + logging.info(f"Get RGB threshold: {threshold}") return threshold else: rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + logging.info(f"Get RGB img with size {rgb_img.shape}") return rgb_img @@ -112,9 +124,7 @@ class FifoReceiver(Transmitter): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - if not self._output_queue.safe_put(data): - if self._msg_queue is not None: - self._msg_queue.put('Fifo Receiver的接收者未来得及接收') + self._output_queue.safe_put(data) os.close(input_fifo) self._need_stop.clear() @@ -154,20 +164,16 @@ class FifoSender(Transmitter): :return: """ while not self._need_stop.is_set(): - output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) if self._input_source.empty(): continue data = self._input_source.get() if pre_process is not None: data = pre_process(data) + output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) os.write(output_fifo, data) os.close(output_fifo) self._need_stop.clear() - @staticmethod - def mask_preprocess(mask: np.ndarray): - return mask.tobytes() - def __del__(self): self.stop() if self._running_thread is not None: @@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter): self._subscribers = output def start(self, name='CMD_thread'): - self._server_thread = threading.Thread(target=self._cmd_control_service) + self._server_thread = threading.Thread(target=self._cmd_control_service, name=name) self._server_thread.start() def stop(self): @@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter): continue elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): # 如果是图片,交给预测的人 - for _, subscriber in self._subscribers.items(): - subscriber.put((spec_data, rgb_data)) + for name, subscriber in self._subscribers.items(): + item = (spec_data, rgb_data) + subscriber.safe_put(item) else: # 否则程序出现毁灭性问题,立刻崩 raise Exception("两个相机传回的数据没有对上") @@ -260,12 +267,21 @@ class ThreadDetector(Transmitter): self._predict_thread.start() def predict(self, spec: np.ndarray, rgb: np.ndarray): + logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}') + t1 = time.time() mask = self._spec_detector.predict(spec) + t2 = time.time() + logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms') # rgb识别 mask_rgb = self._rgb_detector.predict(rgb) + t3 = time.time() + logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms') # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + t4 = time.time() + logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') + logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') return mask_result def _predict_server(self): @@ -273,5 +289,5 @@ class ThreadDetector(Transmitter): if not self._input_queue.empty(): spec, rgb = self._input_queue.get() mask = self.predict(spec, rgb) - self._output_queue.put(mask) + self._output_queue.safe_put(mask) self._thread_exit.clear() diff --git a/utils.py b/utils.py index f004242..bf466f8 100755 --- a/utils.py +++ b/utils.py @@ -47,11 +47,11 @@ class ImgQueue(Queue): self.queue.clear() self.not_full.notify_all() - def safe_put(self, *args, **kwargs): + def safe_put(self, item): if self.full(): _ = self.get() return False - self.put(*args, **kwargs) + self.put(item) return True