From 1a19246dc3cefd6efd3d60db71751a2eca4eba6d Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Fri, 29 Jul 2022 16:06:11 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AE=A1=E7=AE=97=E8=AF=AF=E5=B7=AE=E5=AE=8C?= =?UTF-8?q?=E6=88=90=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main_test.py | 49 ++++++++++++++++++++++------------- transmit.py | 73 +++++++++++++++++++++++++++++++++++++++++++--------- utils.py | 1 + 3 files changed, 93 insertions(+), 30 deletions(-) diff --git a/main_test.py b/main_test.py index 8ef72f9..109ec0d 100644 --- a/main_test.py +++ b/main_test.py @@ -66,11 +66,6 @@ class TestMain: if get_delta: spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255 spec_cv = spec_cv.astype(np.uint8) - plt.imshow(spec_cv) - plt.show() - spec_cv = spec_cv.astype(np.uint8) - plt.imshow(spec_cv) - plt.show() delta = self.calculate_delta(rgb_img, spec_cv) print(delta) self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask, @@ -112,25 +107,43 @@ class TestMain: plt.show() return mask_result - def calculate_delta(self, rgb_img, spec_img): + def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1): rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY) _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU) spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0])) - fig, axs = plt.subplots(2, 1) - axs[0].imshow(rgb_bin) - axs[1].imshow(spec_bin) - plt.show() - search_area = np.zeros_like(spec_bin) - for x in range(0, spec_bin.shape[0] // 10, 10): - for y in range(0, spec_bin.shape[1] // 10, 10): - delta_x, delta_y = x - spec_bin.shape[0] // 2, y - spec_bin.shape[1] // 2 + search_area = np.zeros(search_area_size) + for x in range(0, search_area_size[0], eps): + for y in range(0, search_area_size[1], eps): + delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2 rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y) spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y) response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area)) search_area[x, y] = response_altitude - delta = np.argmax(search_area) - print(delta) + delta = np.unravel_index(np.argmax(search_area), search_area.shape) + delta = (delta[0] - search_area_size[1]//2, delta[1] - search_area_size[1]//2) + delta_x, delta_y = delta + + rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y) + spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y) + + human_word = "SPEC is " + str(abs(delta_x)) + " pixels " + human_word += 'after' if delta_x >= 0 else ' before ' + human_word += "RGB and " + str(abs(delta_y)) + " pixels " + human_word += "right " if delta_y >= 0 else "left " + human_word += "the RGB" + + fig, axs = plt.subplots(3, 1) + axs[0].imshow(rgb_img) + axs[0].set_title("RGB img") + axs[1].imshow(spec_img) + axs[1].set_title("spec img") + axs[2].imshow(rgb_cross_area & spce_cross_area) + axs[2].set_title("cross part") + plt.suptitle(human_word) + plt.show() + + print(human_word) return delta @staticmethod @@ -148,5 +161,5 @@ class TestMain: if __name__ == '__main__': testor = TestMain() - testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct', - test_rgb=True, test_spectra=True, get_delta=True) + testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct', + test_rgb=True, test_spectra=True, get_delta=False) diff --git a/transmit.py b/transmit.py index 099befd..a118a1a 100644 --- a/transmit.py +++ b/transmit.py @@ -1,15 +1,16 @@ import os import threading +from multiprocessing import Process, Queue import time -from utils import ImgQueue as Queue +from utils import ImgQueue as ImgQueue import numpy as np from config import Config from models import SpecDetector, RgbDetector import typing import logging logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', - level=logging.INFO) + level=logging.WARNING) class Transmitter(object): @@ -25,7 +26,7 @@ class Transmitter(object): """ raise NotImplementedError - def set_output(self, output: Queue): + def set_output(self, output: ImgQueue): """ 设置单个输出源 :param output: @@ -83,7 +84,7 @@ class BeforeAfterMethods: class FifoReceiver(Transmitter): - def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None): + def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None): super().__init__() self._input_fifo_path = None self._output_queue = None @@ -101,7 +102,7 @@ class FifoReceiver(Transmitter): os.mkfifo(fifo_path, 0o777) self._input_fifo_path = fifo_path - def set_output(self, output: Queue): + def set_output(self, output: ImgQueue): self._output_queue = output def start(self, post_process_func=None, name='fifo_receiver'): @@ -130,7 +131,7 @@ class FifoReceiver(Transmitter): class FifoSender(Transmitter): - def __init__(self, output_fifo_path: str, source: Queue): + def __init__(self, output_fifo_path: str, source: ImgQueue): super().__init__() self._input_source = None self._output_fifo_path = None @@ -140,7 +141,7 @@ class FifoSender(Transmitter): self._need_stop.clear() self._running_thread = None - def set_source(self, source: Queue): + def set_source(self, source: ImgQueue): self._input_source = source def set_output(self, output_fifo_path: str): @@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter): """ 用于控制命令和图像的中间件 """ - def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue): + def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue): super().__init__() self._rgb_queue = None self._spec_queue = None @@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter): self.set_output(subscribers) self.thread_stop = threading.Event() - def set_source(self, rgb_queue: Queue, spec_queue: Queue): + def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue): self._rgb_queue = rgb_queue self._spec_queue = spec_queue - def set_output(self, output: typing.Dict[str, Queue]): + def set_output(self, output: typing.Dict[str, ImgQueue]): self._subscribers = output def start(self, name='CMD_thread'): @@ -246,7 +247,7 @@ class ImageSaver(Transmitter): class ThreadDetector(Transmitter): - def __init__(self, input_queue: Queue, output_queue: Queue): + def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue): super().__init__() self._input_queue, self._output_queue = input_queue, output_queue self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, @@ -256,7 +257,7 @@ class ThreadDetector(Transmitter): self._predict_thread = None self._thread_exit = threading.Event() - def set_source(self, img_queue: Queue): + def set_source(self, img_queue: ImgQueue): self._input_queue = img_queue def stop(self, *args, **kwargs): @@ -291,3 +292,51 @@ class ThreadDetector(Transmitter): mask = self.predict(spec, rgb) self._output_queue.safe_put(mask) self._thread_exit.clear() + + +class ProcessDetector(Transmitter): + def __init__(self, input_queue: Queue, output_queue: Queue): + super().__init__() + self._input_queue, self._output_queue = input_queue, output_queue + self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, + pixel_model_path=Config.pixel_model_path) + self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, + background_model_path=Config.rgb_background_model_path) + self._predict_thread = None + self._thread_exit = threading.Event() + + def set_source(self, img_queue: Queue): + self._input_queue = img_queue + + def stop(self, *args, **kwargs): + self._thread_exit.set() + + def start(self, name='predict_thread'): + self._predict_thread = Process(target=self._predict_server, name=name, daemon=True) + self._predict_thread.start() + + def predict(self, spec: np.ndarray, rgb: np.ndarray): + logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}') + t1 = time.time() + mask = self._spec_detector.predict(spec) + t2 = time.time() + logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms') + # rgb识别 + mask_rgb = self._rgb_detector.predict(rgb) + t3 = time.time() + logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms') + # 结果合并 + mask_result = (mask | mask_rgb).astype(np.uint8) + mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) + t4 = time.time() + logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms') + logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms') + return mask_result + + def _predict_server(self): + while not self._thread_exit.is_set(): + if not self._input_queue.empty(): + spec, rgb = self._input_queue.get() + mask = self.predict(spec, rgb) + self._output_queue.put(mask) + self._thread_exit.clear() \ No newline at end of file diff --git a/utils.py b/utils.py index 9359bc4..2800eaf 100755 --- a/utils.py +++ b/utils.py @@ -140,6 +140,7 @@ def size_threshold(img, blk_size, threshold): return mask + if __name__ == '__main__': color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}