From a0d14940e850472d5229a12b2765fd97f012a009 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Thu, 28 Jul 2022 12:39:54 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=E7=89=88=E6=9C=AC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
efficient_ui.py | 51 +++++++++++
main.py | 14 ++--
models.py | 5 +-
transmit.py | 218 ++++++++++++++++++++++++++++++++++++++++++------
utils.py | 29 +++++++
5 files changed, 284 insertions(+), 33 deletions(-)
create mode 100644 efficient_ui.py
diff --git a/efficient_ui.py b/efficient_ui.py
new file mode 100644
index 0000000..dd627b6
--- /dev/null
+++ b/efficient_ui.py
@@ -0,0 +1,51 @@
+import cv2
+import numpy as np
+
+import transmit
+from config import Config
+
+from utils import ImgQueue as Queue
+
+
+class EfficientUI(object):
+ def __init__(self):
+ # 相关参数
+ img_fifo_path = "/tmp/dkimg.fifo"
+ mask_fifo_path = "/tmp/dkmask.fifo"
+ rgb_fifo_path = "/tmp/dkrgb.fifo"
+ # 创建队列用于链接各个线程
+ rgb_img_queue, spec_img_queue = Queue(), Queue()
+ detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
+ mask_queue = Queue()
+ # 两个接收者,接收光谱和rgb图像
+ spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
+ rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
+ spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
+ rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
+ # 指令执行与图像流向控制
+ subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
+ cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
+ subscribers=subscribers)
+ # 探测器
+ detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
+ # 发送
+ sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
+ # 启动所有线程
+ spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
+ rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
+ cmd_img_controller.start(name='control_thread')
+ detector.start(name='detector_thread')
+ sender.start(name='sender_thread')
+
+ def start(self):
+ # 启动图形化
+ while True:
+ cv2.imshow('image_show', mat=np.ones((256, 1024)))
+ key_code = cv2.waitKey(30)
+ if key_code == ord("s"):
+ pass
+
+
+if __name__ == '__main__':
+ app = EfficientUI()
+ app.start()
diff --git a/main.py b/main.py
index ee51cf4..9f4a8f6 100755
--- a/main.py
+++ b/main.py
@@ -1,12 +1,15 @@
import os
import time
+from queue import Queue
-import matplotlib.pyplot as plt
import numpy as np
-import scipy.io
+from matplotlib import pyplot as plt
+
+import models
+import transmit
from config import Config
-from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
+from models import RgbDetector, SpecDetector
import cv2
@@ -56,20 +59,19 @@ def main():
mask = spec_detector.predict(img_data)
# rgb识别
mask_rgb = rgb_detector.predict(rgb_data)
-
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
-
# mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
+ print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
- print(f'total time is:{t3 - t1}\n')
+ print(f'total time is:{t3 - t1}')
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
diff --git a/models.py b/models.py
index 181651d..f503b96 100755
--- a/models.py
+++ b/models.py
@@ -5,10 +5,12 @@
# @Software:PyCharm、
import datetime
import pickle
+from queue import Queue
import cv2
import numpy as np
import scipy.io
+import threading
from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
@@ -17,7 +19,8 @@ from sklearn.model_selection import train_test_split
from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold
-deploy = True
+
+deploy = False
if not deploy:
print("Training env")
from tqdm import tqdm
diff --git a/transmit.py b/transmit.py
index b990718..c0fed63 100644
--- a/transmit.py
+++ b/transmit.py
@@ -1,11 +1,13 @@
import os
import threading
-from queue import Queue
+from utils import ImgQueue as Queue
import numpy as np
from config import Config
+from models import SpecDetector, RgbDetector
+import typing
-class Receiver(object):
+class Transmitter(object):
def __init__(self):
self.output = None
@@ -45,17 +47,42 @@ class Receiver(object):
raise NotImplementedError
-class FifoReceiver(Receiver):
- def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
+class PostProcessMethods:
+ @classmethod
+ def spec_data_post_process(cls, data):
+ if len(data) < 3:
+ threshold = int(float(data))
+ print("[INFO] Get Spec threshold: ", threshold)
+ return threshold
+ else:
+ spec_img = np.frombuffer(data, dtype=np.float32).\
+ reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
+ return spec_img
+
+ @classmethod
+ def rgb_data_post_process(cls, data):
+ if len(data) < 3:
+ threshold = int(float(data))
+ print("[INFO] Get RGB threshold: ", threshold)
+ return threshold
+ else:
+ rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
+ return rgb_img
+
+
+class FifoReceiver(Transmitter):
+ def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
super().__init__()
self._input_fifo_path = None
self._output_queue = None
+ self._msg_queue = msg_queue
self._max_len = read_max_num
self.set_source(fifo_path)
self.set_output(output)
self._need_stop = threading.Event()
self._need_stop.clear()
+ self._running_thread = None
def set_source(self, fifo_path: str):
if not os.access(fifo_path, os.F_OK):
@@ -66,9 +93,9 @@ class FifoReceiver(Receiver):
self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'):
- x = threading.Thread(target=self._receive_thread_func,
- name=name, args=(post_process_func, ))
- x.start()
+ self._running_thread = threading.Thread(target=self._receive_thread_func,
+ name=name, args=(post_process_func, ))
+ self._running_thread.start()
def stop(self):
self._need_stop.set()
@@ -85,27 +112,166 @@ class FifoReceiver(Receiver):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
- self._output_queue.put(data)
+ if not self._output_queue.safe_put(data):
+ if self._msg_queue is not None:
+ self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
os.close(input_fifo)
self._need_stop.clear()
- @staticmethod
- def spec_data_post_process(data):
- if len(data) < 3:
- threshold = int(float(data))
- print("[INFO] Get Spec threshold: ", threshold)
- return threshold
- else:
- spec_img = np.frombuffer(data, dtype=np.float32).\
- reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
- return spec_img
+
+class FifoSender(Transmitter):
+ def __init__(self, output_fifo_path: str, source: Queue):
+ super().__init__()
+ self._input_source = None
+ self._output_fifo_path = None
+ self.set_source(source)
+ self.set_output(output_fifo_path)
+ self._need_stop = threading.Event()
+ self._need_stop.clear()
+ self._running_thread = None
+
+ def set_source(self, source: Queue):
+ self._input_source = source
+
+ def set_output(self, output_fifo_path: str):
+ if not os.access(output_fifo_path, os.F_OK):
+ os.mkfifo(output_fifo_path, 0o777)
+ self._output_fifo_path = output_fifo_path
+
+ def start(self, pre_process=None, name='fifo_receiver'):
+ self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
+ args=(pre_process, ))
+ self._running_thread.start()
+
+ def stop(self):
+ self._need_stop.set()
+
+ def _send_thread_func(self, pre_process=None):
+ """
+ 接收线程
+
+ :param pre_process:
+ :return:
+ """
+ while not self._need_stop.is_set():
+ output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
+ if self._input_source.empty():
+ continue
+ data = self._input_source.get()
+ if pre_process is not None:
+ data = pre_process(data)
+ os.write(output_fifo, data)
+ os.close(output_fifo)
+ self._need_stop.clear()
@staticmethod
- def rgb_data_post_process(data):
- if len(data) < 3:
- threshold = int(float(data))
- print("[INFO] Get RGB threshold: ", threshold)
- return threshold
- else:
- rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
- return rgb_img
+ def mask_preprocess(mask: np.ndarray):
+ return mask.tobytes()
+
+ def __del__(self):
+ self.stop()
+ if self._running_thread is not None:
+ self._running_thread.join()
+
+
+class CmdImgSplitMidware(Transmitter):
+ """
+ 用于控制命令和图像的中间件
+ """
+ def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
+ super().__init__()
+ self._rgb_queue = None
+ self._spec_queue = None
+ self._subscribers = None
+ self._server_thread = None
+ self.set_source(rgb_queue, spec_queue)
+ self.set_output(subscribers)
+ self.thread_stop = threading.Event()
+
+ def set_source(self, rgb_queue: Queue, spec_queue: Queue):
+ self._rgb_queue = rgb_queue
+ self._spec_queue = spec_queue
+
+ def set_output(self, output: typing.Dict[str, Queue]):
+ self._subscribers = output
+
+ def start(self, name='CMD_thread'):
+ self._server_thread = threading.Thread(target=self._cmd_control_service)
+ self._server_thread.start()
+
+ def stop(self):
+ self.thread_stop.set()
+
+ def _cmd_control_service(self):
+ while not self.thread_stop.is_set():
+ # 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
+ if self._rgb_queue.empty() or self._spec_queue.empty():
+ continue
+ rgb_data = self._rgb_queue.get()
+ spec_data = self._spec_queue.get()
+ if isinstance(rgb_data, int) and isinstance(spec_data, int):
+ # 看是不是命令需要执行如果是命令,就执行
+ Config.rgb_size_threshold = rgb_data
+ Config.spec_size_threshold = spec_data
+ continue
+ elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
+ # 如果是图片,交给预测的人
+ for _, subscriber in self._subscribers.items():
+ subscriber.put((spec_data, rgb_data))
+ else:
+ # 否则程序出现毁灭性问题,立刻崩
+ raise Exception("两个相机传回的数据没有对上")
+ self.thread_stop.clear()
+
+
+class ImageSaver(Transmitter):
+ """
+ 进行图片存储的中间件
+ """
+ def set_source(self, *args, **kwargs):
+ pass
+
+ def start(self, *args, **kwargs):
+ pass
+
+ def stop(self, *args, **kwargs):
+ pass
+
+
+class ThreadDetector(Transmitter):
+ def __init__(self, input_queue: Queue, output_queue: Queue):
+ super().__init__()
+ self._input_queue, self._output_queue = input_queue, output_queue
+ self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
+ pixel_model_path=Config.pixel_model_path)
+ self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
+ background_model_path=Config.rgb_background_model_path)
+ self._predict_thread = None
+ self._thread_exit = threading.Event()
+
+ def set_source(self, img_queue: Queue):
+ self._input_queue = img_queue
+
+ def stop(self, *args, **kwargs):
+ self._thread_exit.set()
+
+ def start(self, name='predict_thread'):
+ self._predict_thread = threading.Thread(target=self._predict_server, name=name)
+ self._predict_thread.start()
+
+ def predict(self, spec: np.ndarray, rgb: np.ndarray):
+ mask = self._spec_detector.predict(spec)
+ # rgb识别
+ mask_rgb = self._rgb_detector.predict(rgb)
+ # 结果合并
+ mask_result = (mask | mask_rgb).astype(np.uint8)
+ mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
+ return mask_result
+
+ def _predict_server(self):
+ while not self._thread_exit.is_set():
+ if not self._input_queue.empty():
+ spec, rgb = self._input_queue.get()
+ mask = self.predict(spec, rgb)
+ self._output_queue.put(mask)
+ self._thread_exit.clear()
diff --git a/utils.py b/utils.py
index 25607e1..f004242 100755
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
# @Software:PyCharm
import glob
import os
+from queue import Queue
import cv2
import numpy as np
@@ -26,6 +27,34 @@ class MergeDict(dict):
return self
+class ImgQueue(Queue):
+ """
+ A custom queue subclass that provides a :meth:`clear` method.
+ """
+
+ def clear(self):
+ """
+ Clears all items from the queue.
+ """
+
+ with self.mutex:
+ unfinished = self.unfinished_tasks - len(self.queue)
+ if unfinished <= 0:
+ if unfinished < 0:
+ raise ValueError('task_done() called too many times')
+ self.all_tasks_done.notify_all()
+ self.unfinished_tasks = unfinished
+ self.queue.clear()
+ self.not_full.notify_all()
+
+ def safe_put(self, *args, **kwargs):
+ if self.full():
+ _ = self.get()
+ return False
+ self.put(*args, **kwargs)
+ return True
+
+
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
"""
根据dataset_dir下的文件创建数据集