计算误差完成版

This commit is contained in:
li.zhenye 2022-07-29 16:06:11 +08:00
parent 144b88543d
commit 1a19246dc3
3 changed files with 93 additions and 30 deletions

View File

@ -66,11 +66,6 @@ class TestMain:
if get_delta: if get_delta:
spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255 spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
spec_cv = spec_cv.astype(np.uint8) spec_cv = spec_cv.astype(np.uint8)
plt.imshow(spec_cv)
plt.show()
spec_cv = spec_cv.astype(np.uint8)
plt.imshow(spec_cv)
plt.show()
delta = self.calculate_delta(rgb_img, spec_cv) delta = self.calculate_delta(rgb_img, spec_cv)
print(delta) print(delta)
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask, self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
@ -112,25 +107,43 @@ class TestMain:
plt.show() plt.show()
return mask_result return mask_result
def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1):
    """Estimate the pixel offset between the RGB image and the spectral image.

    Both images are converted to grey scale and binarised with Otsu's
    threshold; the spectral mask is resized to the RGB mask's size. Every
    candidate offset inside a window of ``search_area_size`` centred on
    (0, 0) is then scored by summing the bitwise AND of the two cropped
    masks, and the offset with the highest score is returned.

    :param rgb_img: RGB image accepted by ``cv2.cvtColor(..., COLOR_RGB2GRAY)``.
    :param spec_img: pseudo-colour spectral image, same requirement.
    :param search_area_size: ``(size_x, size_y)`` of the offset window; the
        tested offsets run from ``-size // 2`` up to ``size - size // 2 - 1``.
    :param eps: step between tested offsets on both axes.
    :return: ``(delta_x, delta_y)`` of the best scoring offset.

    NOTE(review): ``self.get_cross_area`` is defined outside this block; it
    presumably crops the part of a mask that still overlaps after the shift
    -- confirm against its implementation.
    """
    rgb_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY)
    spec_grey = cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
    _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # cv2.resize takes dsize as (width, height), hence shape[1] before shape[0].
    spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))

    half_x, half_y = search_area_size[0] // 2, search_area_size[1] // 2
    # Cells skipped because eps > 1 keep a score of 0.
    search_area = np.zeros(search_area_size)
    for x in range(0, search_area_size[0], eps):
        for y in range(0, search_area_size[1], eps):
            delta_x, delta_y = x - half_x, y - half_y
            rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
            spec_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
            search_area[x, y] = np.sum(rgb_cross_area & spec_cross_area)

    best_x, best_y = np.unravel_index(np.argmax(search_area), search_area.shape)
    # Each axis is re-centred with its OWN half size; using the y half size
    # for x gives a wrong offset whenever the window is not square.
    delta = (best_x - half_x, best_y - half_y)
    delta_x, delta_y = delta

    rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
    spec_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)

    human_word = "SPEC is " + str(abs(delta_x)) + " pixels "
    human_word += "after " if delta_x >= 0 else "before "
    human_word += "RGB and " + str(abs(delta_y)) + " pixels "
    human_word += "right " if delta_y >= 0 else "left "
    human_word += "the RGB"

    fig, axs = plt.subplots(3, 1)
    axs[0].imshow(rgb_img)
    axs[0].set_title("RGB img")
    axs[1].imshow(spec_img)
    axs[1].set_title("spec img")
    axs[2].imshow(rgb_cross_area & spec_cross_area)
    axs[2].set_title("cross part")
    plt.suptitle(human_word)
    plt.show()
    print(human_word)
    return delta
@staticmethod @staticmethod
@ -148,5 +161,5 @@ class TestMain:
if __name__ == '__main__': if __name__ == '__main__':
testor = TestMain() testor = TestMain()
testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct', testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct',
test_rgb=True, test_spectra=True, get_delta=True) test_rgb=True, test_spectra=True, get_delta=False)

View File

@ -1,15 +1,16 @@
import os import os
import threading import threading
from multiprocessing import Process, Queue
import time import time
from utils import ImgQueue as Queue from utils import ImgQueue as ImgQueue
import numpy as np import numpy as np
from config import Config from config import Config
from models import SpecDetector, RgbDetector from models import SpecDetector, RgbDetector
import typing import typing
import logging import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s', logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.INFO) level=logging.WARNING)
class Transmitter(object): class Transmitter(object):
@ -25,7 +26,7 @@ class Transmitter(object):
""" """
raise NotImplementedError raise NotImplementedError
def set_output(self, output: Queue): def set_output(self, output: ImgQueue):
""" """
设置单个输出源 设置单个输出源
:param output: :param output:
@ -83,7 +84,7 @@ class BeforeAfterMethods:
class FifoReceiver(Transmitter): class FifoReceiver(Transmitter):
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None): def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
super().__init__() super().__init__()
self._input_fifo_path = None self._input_fifo_path = None
self._output_queue = None self._output_queue = None
@ -101,7 +102,7 @@ class FifoReceiver(Transmitter):
os.mkfifo(fifo_path, 0o777) os.mkfifo(fifo_path, 0o777)
self._input_fifo_path = fifo_path self._input_fifo_path = fifo_path
def set_output(self, output: Queue): def set_output(self, output: ImgQueue):
self._output_queue = output self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'): def start(self, post_process_func=None, name='fifo_receiver'):
@ -130,7 +131,7 @@ class FifoReceiver(Transmitter):
class FifoSender(Transmitter): class FifoSender(Transmitter):
def __init__(self, output_fifo_path: str, source: Queue): def __init__(self, output_fifo_path: str, source: ImgQueue):
super().__init__() super().__init__()
self._input_source = None self._input_source = None
self._output_fifo_path = None self._output_fifo_path = None
@ -140,7 +141,7 @@ class FifoSender(Transmitter):
self._need_stop.clear() self._need_stop.clear()
self._running_thread = None self._running_thread = None
def set_source(self, source: Queue): def set_source(self, source: ImgQueue):
self._input_source = source self._input_source = source
def set_output(self, output_fifo_path: str): def set_output(self, output_fifo_path: str):
@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter):
""" """
用于控制命令和图像的中间件 用于控制命令和图像的中间件
""" """
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue): def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
super().__init__() super().__init__()
self._rgb_queue = None self._rgb_queue = None
self._spec_queue = None self._spec_queue = None
@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter):
self.set_output(subscribers) self.set_output(subscribers)
self.thread_stop = threading.Event() self.thread_stop = threading.Event()
def set_source(self, rgb_queue: Queue, spec_queue: Queue): def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue):
self._rgb_queue = rgb_queue self._rgb_queue = rgb_queue
self._spec_queue = spec_queue self._spec_queue = spec_queue
def set_output(self, output: typing.Dict[str, Queue]): def set_output(self, output: typing.Dict[str, ImgQueue]):
self._subscribers = output self._subscribers = output
def start(self, name='CMD_thread'): def start(self, name='CMD_thread'):
@ -246,7 +247,7 @@ class ImageSaver(Transmitter):
class ThreadDetector(Transmitter): class ThreadDetector(Transmitter):
def __init__(self, input_queue: Queue, output_queue: Queue): def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue):
super().__init__() super().__init__()
self._input_queue, self._output_queue = input_queue, output_queue self._input_queue, self._output_queue = input_queue, output_queue
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
@ -256,7 +257,7 @@ class ThreadDetector(Transmitter):
self._predict_thread = None self._predict_thread = None
self._thread_exit = threading.Event() self._thread_exit = threading.Event()
def set_source(self, img_queue: Queue): def set_source(self, img_queue: ImgQueue):
self._input_queue = img_queue self._input_queue = img_queue
def stop(self, *args, **kwargs): def stop(self, *args, **kwargs):
@ -291,3 +292,51 @@ class ThreadDetector(Transmitter):
mask = self.predict(spec, rgb) mask = self.predict(spec, rgb)
self._output_queue.safe_put(mask) self._output_queue.safe_put(mask)
self._thread_exit.clear() self._thread_exit.clear()
class ProcessDetector(Transmitter):
    """Detector that runs its prediction loop in a separate process.

    Pairs of ``(spec, rgb)`` images are taken from ``input_queue``, passed
    through the spectral and RGB detectors, and the merged mask is put on
    ``output_queue``.
    """

    def __init__(self, input_queue: Queue, output_queue: Queue):
        """Store the queues, build both detectors and create the stop flag.

        :param input_queue: queue yielding ``(spec, rgb)`` tuples.
        :param output_queue: queue receiving the merged mask.
        """
        # Imported here so the block does not depend on the file-level
        # import list being extended.
        from multiprocessing import Event

        super().__init__()
        self._input_queue, self._output_queue = input_queue, output_queue
        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
                                           pixel_model_path=Config.pixel_model_path)
        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
                                         background_model_path=Config.rgb_background_model_path)
        # Attribute name kept for parity with ThreadDetector; it holds a Process.
        self._predict_thread = None
        # The flag must be a multiprocessing Event: a threading.Event is only
        # copied into the child process, so stop() called in the parent would
        # never be seen by the worker loop.
        self._thread_exit = Event()

    def set_source(self, img_queue: Queue):
        """Replace the input queue."""
        self._input_queue = img_queue

    def stop(self, *args, **kwargs):
        """Ask the worker loop to exit after its current iteration."""
        self._thread_exit.set()

    def start(self, name='predict_thread'):
        """Launch the prediction loop in a daemon process.

        :param name: name given to the worker process.
        """
        self._predict_thread = Process(target=self._predict_server, name=name, daemon=True)
        self._predict_thread.start()

    def predict(self, spec: np.ndarray, rgb: np.ndarray):
        """Run both detectors and return the merged, up-scaled mask.

        :param spec: spectral image handed to the spectral detector.
        :param rgb: RGB image handed to the RGB detector.
        :return: ``uint8`` mask, the OR of both detector masks with every
            element repeated ``Config.blk_size`` times along axis 0 and axis 1.
        """
        logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
        t1 = time.time()
        mask = self._spec_detector.predict(spec)
        t2 = time.time()
        logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
        # RGB detection
        mask_rgb = self._rgb_detector.predict(rgb)
        t3 = time.time()
        logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
        # Merge the two results
        mask_result = (mask | mask_rgb).astype(np.uint8)
        mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
        t4 = time.time()
        logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
        return mask_result

    def _predict_server(self):
        """Worker loop: predict on every queued image pair until stopped."""
        # NOTE(review): this polls empty() without sleeping, so an idle worker
        # spins; kept as-is to mirror ThreadDetector's loop.
        while not self._thread_exit.is_set():
            if not self._input_queue.empty():
                spec, rgb = self._input_queue.get()
                mask = self.predict(spec, rgb)
                self._output_queue.put(mask)
        # Reset the shared flag so the detector can be started again.
        self._thread_exit.clear()

View File

@ -140,6 +140,7 @@ def size_threshold(img, blk_size, threshold):
return mask return mask
if __name__ == '__main__': if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"} (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}