From 1a19246dc3cefd6efd3d60db71751a2eca4eba6d Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Fri, 29 Jul 2022 16:06:11 +0800
Subject: [PATCH] =?UTF-8?q?=E8=AE=A1=E7=AE=97=E8=AF=AF=E5=B7=AE=E5=AE=8C?=
=?UTF-8?q?=E6=88=90=E7=89=88?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
main_test.py | 49 ++++++++++++++++++++++-------------
transmit.py | 73 +++++++++++++++++++++++++++++++++++++++++++---------
utils.py | 1 +
3 files changed, 93 insertions(+), 30 deletions(-)
diff --git a/main_test.py b/main_test.py
index 8ef72f9..109ec0d 100644
--- a/main_test.py
+++ b/main_test.py
@@ -66,11 +66,6 @@ class TestMain:
if get_delta:
spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
spec_cv = spec_cv.astype(np.uint8)
- plt.imshow(spec_cv)
- plt.show()
- spec_cv = spec_cv.astype(np.uint8)
- plt.imshow(spec_cv)
- plt.show()
delta = self.calculate_delta(rgb_img, spec_cv)
print(delta)
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
@@ -112,25 +107,43 @@ class TestMain:
plt.show()
return mask_result
-    def calculate_delta(self, rgb_img, spec_img):
+    def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1):
+        """Find the shift that maximises the overlap of the Otsu masks of the RGB and spec images.
+
+        :param rgb_img: RGB image (converted with cv2.COLOR_RGB2GRAY before thresholding)
+        :param spec_img: 3-channel spec image; its binary mask is resized to the RGB mask's size
+        :param search_area_size: number of candidate shifts along each axis, centred on zero
+        :param eps: step between candidate shifts
+        :return: (delta_x, delta_y), the candidate shift with the largest overlap
+        """
         rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
         _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
         _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
         spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))
-        fig, axs = plt.subplots(2, 1)
-        axs[0].imshow(rgb_bin)
-        axs[1].imshow(spec_bin)
-        plt.show()
-        search_area = np.zeros_like(spec_bin)
-        for x in range(0, spec_bin.shape[0] // 10, 10):
-            for y in range(0, spec_bin.shape[1] // 10, 10):
-                delta_x, delta_y = x - spec_bin.shape[0] // 2, y - spec_bin.shape[1] // 2
+        search_area = np.zeros(search_area_size)
+        for x in range(0, search_area_size[0], eps):
+            for y in range(0, search_area_size[1], eps):
+                delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2
                 rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
                 spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
                 response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area))
                 search_area[x, y] = response_altitude
-        delta = np.argmax(search_area)
-        print(delta)
+        best = np.unravel_index(np.argmax(search_area), search_area.shape)
+        delta = (best[0] - search_area_size[0] // 2, best[1] - search_area_size[1] // 2)
+        delta_x, delta_y = delta
+        # Recompute the overlap at the winning shift so it can be shown in the figure below.
+        rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
+        spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
+        human_word = (f"SPEC is {abs(delta_x)} pixels {'after' if delta_x >= 0 else 'before'} RGB and "
+                      f"{abs(delta_y)} pixels {'right' if delta_y >= 0 else 'left'} the RGB")
+        _, axs = plt.subplots(3, 1)
+        panels = (("RGB img", rgb_img), ("spec img", spec_img), ("cross part", rgb_cross_area & spce_cross_area))
+        for ax, (title, img) in zip(axs, panels):
+            ax.imshow(img)
+            ax.set_title(title)
+        plt.suptitle(human_word)
+        plt.show()
+        print(human_word)
         return delta
@staticmethod
@@ -148,5 +161,5 @@ class TestMain:
if __name__ == '__main__':
testor = TestMain()
- testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct',
- test_rgb=True, test_spectra=True, get_delta=True)
+ testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct',
+ test_rgb=True, test_spectra=True, get_delta=False)
diff --git a/transmit.py b/transmit.py
index 099befd..a118a1a 100644
--- a/transmit.py
+++ b/transmit.py
@@ -1,15 +1,16 @@
 import os
 import threading
+from multiprocessing import Event as ProcessEvent, Process, Queue
 import time
-from utils import ImgQueue as Queue
+from utils import ImgQueue as ImgQueue
 import numpy as np
 from config import Config
 from models import SpecDetector, RgbDetector
 import typing
 import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
- level=logging.INFO)
+ level=logging.WARNING)
class Transmitter(object):
@@ -25,7 +26,7 @@ class Transmitter(object):
"""
raise NotImplementedError
- def set_output(self, output: Queue):
+ def set_output(self, output: ImgQueue):
"""
设置单个输出源
:param output:
@@ -83,7 +84,7 @@ class BeforeAfterMethods:
class FifoReceiver(Transmitter):
- def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
+ def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
super().__init__()
self._input_fifo_path = None
self._output_queue = None
@@ -101,7 +102,7 @@ class FifoReceiver(Transmitter):
os.mkfifo(fifo_path, 0o777)
self._input_fifo_path = fifo_path
- def set_output(self, output: Queue):
+ def set_output(self, output: ImgQueue):
self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'):
@@ -130,7 +131,7 @@ class FifoReceiver(Transmitter):
class FifoSender(Transmitter):
- def __init__(self, output_fifo_path: str, source: Queue):
+ def __init__(self, output_fifo_path: str, source: ImgQueue):
super().__init__()
self._input_source = None
self._output_fifo_path = None
@@ -140,7 +141,7 @@ class FifoSender(Transmitter):
self._need_stop.clear()
self._running_thread = None
- def set_source(self, source: Queue):
+ def set_source(self, source: ImgQueue):
self._input_source = source
def set_output(self, output_fifo_path: str):
@@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter):
"""
用于控制命令和图像的中间件
"""
- def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
+ def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
super().__init__()
self._rgb_queue = None
self._spec_queue = None
@@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter):
self.set_output(subscribers)
self.thread_stop = threading.Event()
- def set_source(self, rgb_queue: Queue, spec_queue: Queue):
+ def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue):
self._rgb_queue = rgb_queue
self._spec_queue = spec_queue
- def set_output(self, output: typing.Dict[str, Queue]):
+ def set_output(self, output: typing.Dict[str, ImgQueue]):
self._subscribers = output
def start(self, name='CMD_thread'):
@@ -246,7 +247,7 @@ class ImageSaver(Transmitter):
class ThreadDetector(Transmitter):
- def __init__(self, input_queue: Queue, output_queue: Queue):
+ def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue):
super().__init__()
self._input_queue, self._output_queue = input_queue, output_queue
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
@@ -256,7 +257,7 @@ class ThreadDetector(Transmitter):
self._predict_thread = None
self._thread_exit = threading.Event()
- def set_source(self, img_queue: Queue):
+ def set_source(self, img_queue: ImgQueue):
self._input_queue = img_queue
def stop(self, *args, **kwargs):
@@ -291,3 +292,51 @@ class ThreadDetector(Transmitter):
mask = self.predict(spec, rgb)
self._output_queue.safe_put(mask)
self._thread_exit.clear()
+
+
+class ProcessDetector(Transmitter):
+    """Runs the spec and RGB detectors in a child process, reading (spec, rgb) pairs and writing masks."""
+
+    def __init__(self, input_queue: Queue, output_queue: Queue):
+        super().__init__()
+        self._input_queue, self._output_queue = input_queue, output_queue
+        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
+                                           pixel_model_path=Config.pixel_model_path)
+        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
+                                         background_model_path=Config.rgb_background_model_path)
+        self._predict_thread = None
+        self._thread_exit = ProcessEvent()  # multiprocessing event: a threading.Event is not shared with the child
+
+    def set_source(self, img_queue: Queue):
+        self._input_queue = img_queue
+
+    def stop(self, *args, **kwargs):
+        self._thread_exit.set()  # checked by the loop in _predict_server
+
+    def start(self, name='predict_thread'):
+        self._predict_thread = Process(target=self._predict_server, name=name, daemon=True)
+        self._predict_thread.start()
+
+    def predict(self, spec: np.ndarray, rgb: np.ndarray):
+        logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
+        t1 = time.time()
+        mask = self._spec_detector.predict(spec)
+        t2 = time.time()
+        logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
+        # RGB detection
+        mask_rgb = self._rgb_detector.predict(rgb)
+        t3 = time.time()
+        logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
+        # merge both masks, then expand every cell to blk_size x blk_size (repeat keeps the uint8 dtype)
+        mask_result = (mask | mask_rgb).astype(np.uint8).repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1)
+        t4 = time.time()
+        logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
+        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+        return mask_result
+
+    def _predict_server(self):
+        while not self._thread_exit.is_set():
+            if not self._input_queue.empty():  # NOTE(review): polls without sleeping while the queue is empty
+                spec, rgb = self._input_queue.get()
+                self._output_queue.put(self.predict(spec, rgb))
+        self._thread_exit.clear()
\ No newline at end of file
diff --git a/utils.py b/utils.py
index 9359bc4..2800eaf 100755
--- a/utils.py
+++ b/utils.py
@@ -140,6 +140,7 @@ def size_threshold(img, blk_size, threshold):
return mask
+
if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}