diff --git a/config.py b/config.py index 483c629..2e5266c 100644 --- a/config.py +++ b/config.py @@ -20,7 +20,7 @@ class Config: # 光谱模型参数 blk_size = 4 pixel_model_path = r"./models/dt.p" - blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" + blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model" spec_size_threshold = 3 # rgb模型参数 diff --git a/efficient_ui.py b/efficient_ui.py new file mode 100644 index 0000000..0c74260 --- /dev/null +++ b/efficient_ui.py @@ -0,0 +1,51 @@ +import cv2 +import numpy as np + +import transmit +from config import Config + +from utils import ImgQueue as Queue + + +class EfficientUI(object): + def __init__(self): + # 相关参数 + img_fifo_path = "/tmp/dkimg.fifo" + mask_fifo_path = "/tmp/dkmask.fifo" + rgb_fifo_path = "/tmp/dkrgb.fifo" + # 创建队列用于链接各个线程 + rgb_img_queue, spec_img_queue = Queue(), Queue() + detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue() + mask_queue = Queue() + # 两个接收者,接收光谱和rgb图像 + spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 + rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 + spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len) + rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len) + # 指令执行与图像流向控制 + subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue} + cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue, + subscribers=subscribers) + # 探测器 + detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue) + # 发送 + sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue) + # 启动所有线程 + spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread') + rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread') + cmd_img_controller.start(name='control_thread') + detector.start(name='detector_thread') + sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread') + + def start(self): + # 启动图形化 + while True: + cv2.imshow('image_show', mat=np.ones((256, 1024))) + key_code = cv2.waitKey(30) + if key_code == ord("s"): + pass + + +if __name__ == '__main__': + app = EfficientUI() + app.start() diff --git a/main.py b/main.py index 32602f9..9f4a8f6 100755 --- a/main.py +++ b/main.py @@ -1,16 +1,18 @@ import os import time +from queue import Queue + import numpy as np -import scipy.io +from matplotlib import pyplot as plt + +import models +import transmit from config import Config -from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector +from models import RgbDetector, SpecDetector import cv2 -SAVE_IMG, SAVE_NUM = False, 30 - - def main(): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, @@ -23,8 +25,6 @@ def main(): os.mkfifo(mask_fifo_path, 0o777) if not os.access(rgb_fifo_path, os.F_OK): os.mkfifo(rgb_fifo_path, 0o777) - if SAVE_IMG: - img_list = [] while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) @@ -55,19 +55,12 @@ def main(): img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ .transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - if SAVE_IMG: - SAVE_NUM -= 1 - img_list.append((rgb_data, img_data)) - if SAVE_NUM <= 0: - break # 光谱识别 mask = spec_detector.predict(img_data) # rgb识别 mask_rgb = rgb_detector.predict(rgb_data) - # 结果合并 mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = mask_rgb.astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) t2 = time.time() @@ -79,100 +72,37 @@ def main(): os.close(fd_mask) t3 = time.time() print(f'total time is:{t3 - t1}') - for i, img in enumerate(img_list): - print(f"writing img {i}...") - cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) - np.save(f'./{i}.npy', img[1]) - i += 1 - -def save_main(): - threshold = Config.spec_size_threshold - rgb_threshold = Config.rgb_size_threshold - manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) - tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) - background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) - total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 - total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 - if not os.access(img_fifo_path, os.F_OK): - os.mkfifo(img_fifo_path, 0o777) - if not os.access(mask_fifo_path, os.F_OK): - os.mkfifo(mask_fifo_path, 0o777) - if not os.access(rgb_fifo_path, os.F_OK): - os.mkfifo(rgb_fifo_path, 0o777) - img_list = [] - idx = 0 - while idx <= 30: - idx += 1 - fd_img = os.open(img_fifo_path, os.O_RDONLY) - fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) - data = os.read(fd_img, total_len) - - # 读取(开启一个管道) - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get threshold: ", threshold) - continue +def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None): + if os.path.isdir(buffer_path): + buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')] + else: + buffer_names = [buffer_path, ] + for buffer_name in buffer_names: + with open(os.path.join(buffer_path, buffer_name), 'rb') as f: + data = f.read() + img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \ + .transpose(0, 2, 1) + if selected_bands is not None: + img = img[..., selected_bands] + if img.shape[0] == 1: + img = img[0, ...] + if not no_mask: + mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '') + with open(os.path.join(buffer_path, mask_name), 'rb') as f: + data = f.read() + mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1)) else: - data_total = data - rgb_data = os.read(fd_rgb, total_rgb) - if len(rgb_data) < 3: - rgb_threshold = int(float(rgb_data)) - print(rgb_threshold) - continue - else: - rgb_data_total = rgb_data - os.close(fd_img) - os.close(fd_rgb) - - # 识别 - t1 = time.time() - img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \ - transpose(0, 2, 1) - rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - img_list.append((rgb_data.copy(), img_data.copy())) - - pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1) - blk_predict_result = manual_tree.blk_predict(data=img_data) - rgb_data = tobacco_detector.pretreatment(rgb_data) - # print(rgb_data.shape) - rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low, - threshold_high=Config.threshold_high) | - tobacco_detector.swell(tobacco_detector.predict(rgb_data, - threshold_low=Config.threshold_low, - threshold_high=Config.threshold_high))) - mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ - .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ - .sum(axis=1) - mask_rgb[mask_rgb <= rgb_threshold] = 0 - mask_rgb[mask_rgb > rgb_threshold] = 1 - mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) - mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \ - .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \ - .sum(axis=1) - mask[mask <= threshold] = 0 - mask[mask > threshold] = 1 - mask_result = (mask | mask_rgb).astype(np.uint8) - # mask_result = mask_rgb - mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) - t2 = time.time() - print(f'rgb len = {len(rgb_data)}') - - # 写出 - fd_mask = os.open(mask_fifo_path, os.O_WRONLY) - os.write(fd_mask, mask_result.tobytes()) - os.close(fd_mask) - t3 = time.time() - print(f'total time is:{t3 - t1}') - i = 0 - print("Stop Serving") - for img in img_list: - print(f"writing img {i}...") - cv2.imwrite(f"./{i}.png", img[0][..., ::-1]) - np.save(f'./{i}.npy', img[1]) - i += 1 - print("save success") + mask_name = "no mask" + mask = np.zeros_like(img) + # mask = cv2.resize(mask, (1024, 256)) + fig, axs = plt.subplots(2, 1) + axs[0].matshow(img) + axs[0].set_title(buffer_name) + axs[1].imshow(mask) + axs[1].set_title(mask_name) + plt.show() if __name__ == '__main__': @@ -183,3 +113,5 @@ if __name__ == '__main__': # 主函数 main() + # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024, + # selected_bands=[380, 300, 200]) diff --git a/models.py b/models.py index 95d7f4a..abb1a74 100755 --- a/models.py +++ b/models.py @@ -5,10 +5,12 @@ # @Software:PyCharm、 import datetime import pickle +from queue import Queue import cv2 import numpy as np import scipy.io +import threading from scipy.ndimage import binary_dilation from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import classification_report @@ -17,6 +19,7 @@ from sklearn.model_selection import train_test_split from config import Config from utils import lab_scatter, read_labeled_img, size_threshold + deploy = True if not deploy: print("Training env") @@ -415,7 +418,7 @@ class SpecDetector(Detector): # 烟梗mask中将背景赋值为0,将烟梗赋值为2 yellow_things[yellow_things] = tobacco yellow_things = yellow_things + 0 - # yellow_things = binary_dilation(yellow_things, iterations=iteration) + yellow_things = binary_dilation(yellow_things, iterations=iteration) yellow_things = yellow_things + 0 yellow_things[yellow_things == 1] = 2 diff --git a/transmit.py b/transmit.py index b990718..2632a12 100644 --- a/transmit.py +++ b/transmit.py @@ -1,11 +1,18 @@ import os import threading -from queue import Queue +import time + +from utils import ImgQueue as Queue import numpy as np from config import Config +from models import SpecDetector, RgbDetector +import typing +import logging +logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', + level=logging.DEBUG) -class Receiver(object): +class Transmitter(object): def __init__(self): self.output = None @@ -45,17 +52,49 @@ class Receiver(object): raise NotImplementedError -class FifoReceiver(Receiver): - def __init__(self, fifo_path: str, output: Queue, read_max_num: int): +class BeforeAfterMethods: + @classmethod + def mask_preprocess(cls, mask: np.ndarray): + logging.info(f"Send mask with size {mask.shape}") + return mask.tobytes() + + @classmethod + def spec_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + logging.info(f"Get Spec threshold: {threshold}") + return threshold + else: + spec_img = np.frombuffer(data, dtype=np.float32).\ + reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) + logging.info(f"Get SPEC image with size {spec_img.shape}") + return spec_img + + @classmethod + def rgb_data_post_process(cls, data): + if len(data) < 3: + threshold = int(float(data)) + logging.info(f"Get RGB threshold: {threshold}") + return threshold + else: + rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) + logging.info(f"Get RGB img with size {rgb_img.shape}") + return rgb_img + + +class FifoReceiver(Transmitter): + def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None): super().__init__() self._input_fifo_path = None self._output_queue = None + self._msg_queue = msg_queue self._max_len = read_max_num self.set_source(fifo_path) self.set_output(output) self._need_stop = threading.Event() self._need_stop.clear() + self._running_thread = None def set_source(self, fifo_path: str): if not os.access(fifo_path, os.F_OK): @@ -66,9 +105,9 @@ class FifoReceiver(Receiver): self._output_queue = output def start(self, post_process_func=None, name='fifo_receiver'): - x = threading.Thread(target=self._receive_thread_func, - name=name, args=(post_process_func, )) - x.start() + self._running_thread = threading.Thread(target=self._receive_thread_func, + name=name, args=(post_process_func, )) + self._running_thread.start() def stop(self): self._need_stop.set() @@ -85,27 +124,170 @@ class FifoReceiver(Receiver): data = os.read(input_fifo, self._max_len) if post_process_func is not None: data = post_process_func(data) - self._output_queue.put(data) + self._output_queue.safe_put(data) os.close(input_fifo) self._need_stop.clear() - @staticmethod - def spec_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get Spec threshold: ", threshold) - return threshold - else: - spec_img = np.frombuffer(data, dtype=np.float32).\ - reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1) - return spec_img - @staticmethod - def rgb_data_post_process(data): - if len(data) < 3: - threshold = int(float(data)) - print("[INFO] Get RGB threshold: ", threshold) - return threshold - else: - rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - return rgb_img +class FifoSender(Transmitter): + def __init__(self, output_fifo_path: str, source: Queue): + super().__init__() + self._input_source = None + self._output_fifo_path = None + self.set_source(source) + self.set_output(output_fifo_path) + self._need_stop = threading.Event() + self._need_stop.clear() + self._running_thread = None + + def set_source(self, source: Queue): + self._input_source = source + + def set_output(self, output_fifo_path: str): + if not os.access(output_fifo_path, os.F_OK): + os.mkfifo(output_fifo_path, 0o777) + self._output_fifo_path = output_fifo_path + + def start(self, pre_process=None, name='fifo_receiver'): + self._running_thread = threading.Thread(target=self._send_thread_func, name=name, + args=(pre_process, )) + self._running_thread.start() + + def stop(self): + self._need_stop.set() + + def _send_thread_func(self, pre_process=None): + """ + 接收线程 + + :param pre_process: + :return: + """ + while not self._need_stop.is_set(): + if self._input_source.empty(): + continue + data = self._input_source.get() + if pre_process is not None: + data = pre_process(data) + output_fifo = os.open(self._output_fifo_path, os.O_WRONLY) + os.write(output_fifo, data) + os.close(output_fifo) + self._need_stop.clear() + + def __del__(self): + self.stop() + if self._running_thread is not None: + self._running_thread.join() + + +class CmdImgSplitMidware(Transmitter): + """ + 用于控制命令和图像的中间件 + """ + def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue): + super().__init__() + self._rgb_queue = None + self._spec_queue = None + self._subscribers = None + self._server_thread = None + self.set_source(rgb_queue, spec_queue) + self.set_output(subscribers) + self.thread_stop = threading.Event() + + def set_source(self, rgb_queue: Queue, spec_queue: Queue): + self._rgb_queue = rgb_queue + self._spec_queue = spec_queue + + def set_output(self, output: typing.Dict[str, Queue]): + self._subscribers = output + + def start(self, name='CMD_thread'): + self._server_thread = threading.Thread(target=self._cmd_control_service, name=name) + self._server_thread.start() + + def stop(self): + self.thread_stop.set() + + def _cmd_control_service(self): + while not self.thread_stop.is_set(): + # 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时 + if self._rgb_queue.empty() or self._spec_queue.empty(): + continue + rgb_data = self._rgb_queue.get() + spec_data = self._spec_queue.get() + if isinstance(rgb_data, int) and isinstance(spec_data, int): + # 看是不是命令需要执行如果是命令,就执行 + Config.rgb_size_threshold = rgb_data + Config.spec_size_threshold = spec_data + continue + elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray): + # 如果是图片,交给预测的人 + for name, subscriber in self._subscribers.items(): + item = (spec_data, rgb_data) + subscriber.safe_put(item) + else: + # 否则程序出现毁灭性问题,立刻崩 + raise Exception("两个相机传回的数据没有对上") + self.thread_stop.clear() + + +class ImageSaver(Transmitter): + """ + 进行图片存储的中间件 + """ + def set_source(self, *args, **kwargs): + pass + + def start(self, *args, **kwargs): + pass + + def stop(self, *args, **kwargs): + pass + + +class ThreadDetector(Transmitter): + def __init__(self, input_queue: Queue, output_queue: Queue): + super().__init__() + self._input_queue, self._output_queue = input_queue, output_queue + self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, + pixel_model_path=Config.pixel_model_path) + self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, + background_model_path=Config.rgb_background_model_path) + self._predict_thread = None + self._thread_exit = threading.Event() + + def set_source(self, img_queue: Queue): + self._input_queue = img_queue + + def stop(self, *args, **kwargs): + self._thread_exit.set() + + def start(self, name='predict_thread'): + self._predict_thread = threading.Thread(target=self._predict_server, name=name) + self._predict_thread.start() + + def predict(self, spec: np.ndarray, rgb: np.ndarray): + logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}') + t1 = time.time() + mask = self._spec_detector.predict(spec) + t2 = time.time() + logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms') + # rgb识别 + mask_rgb = self._rgb_detector.predict(rgb) + t3 = time.time() + logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms') + # 结果合并 + mask_result = (mask | mask_rgb).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + t4 = time.time() + logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') + logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') + return mask_result + + def _predict_server(self): + while not self._thread_exit.is_set(): + if not self._input_queue.empty(): + spec, rgb = self._input_queue.get() + mask = self.predict(spec, rgb) + self._output_queue.safe_put(mask) + self._thread_exit.clear() diff --git a/utils.py b/utils.py index 25607e1..bf466f8 100755 --- a/utils.py +++ b/utils.py @@ -5,6 +5,7 @@ # @Software:PyCharm import glob import os +from queue import Queue import cv2 import numpy as np @@ -26,6 +27,34 @@ class MergeDict(dict): return self +class ImgQueue(Queue): + """ + A custom queue subclass that provides a :meth:`clear` method. + """ + + def clear(self): + """ + Clears all items from the queue. + """ + + with self.mutex: + unfinished = self.unfinished_tasks - len(self.queue) + if unfinished <= 0: + if unfinished < 0: + raise ValueError('task_done() called too many times') + self.all_tasks_done.notify_all() + self.unfinished_tasks = unfinished + self.queue.clear() + self.not_full.notify_all() + + def safe_put(self, item): + if self.full(): + _ = self.get() + return False + self.put(item) + return True + + def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict: """ 根据dataset_dir下的文件创建数据集