计算误差完成版

This commit is contained in:
li.zhenye 2022-07-29 16:06:11 +08:00
parent 144b88543d
commit 1a19246dc3
3 changed files with 93 additions and 30 deletions

View File

@ -66,11 +66,6 @@ class TestMain:
if get_delta:
spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
spec_cv = spec_cv.astype(np.uint8)
plt.imshow(spec_cv)
plt.show()
spec_cv = spec_cv.astype(np.uint8)
plt.imshow(spec_cv)
plt.show()
delta = self.calculate_delta(rgb_img, spec_cv)
print(delta)
self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
@ -112,25 +107,43 @@ class TestMain:
plt.show()
return mask_result
def calculate_delta(self, rgb_img, spec_img):
def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1):
rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
_, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
_, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))
fig, axs = plt.subplots(2, 1)
axs[0].imshow(rgb_bin)
axs[1].imshow(spec_bin)
plt.show()
search_area = np.zeros_like(spec_bin)
for x in range(0, spec_bin.shape[0] // 10, 10):
for y in range(0, spec_bin.shape[1] // 10, 10):
delta_x, delta_y = x - spec_bin.shape[0] // 2, y - spec_bin.shape[1] // 2
search_area = np.zeros(search_area_size)
for x in range(0, search_area_size[0], eps):
for y in range(0, search_area_size[1], eps):
delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2
rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area))
search_area[x, y] = response_altitude
delta = np.argmax(search_area)
print(delta)
delta = np.unravel_index(np.argmax(search_area), search_area.shape)
delta = (delta[0] - search_area_size[1]//2, delta[1] - search_area_size[1]//2)
delta_x, delta_y = delta
rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
human_word = "SPEC is " + str(abs(delta_x)) + " pixels "
human_word += 'after' if delta_x >= 0 else ' before '
human_word += "RGB and " + str(abs(delta_y)) + " pixels "
human_word += "right " if delta_y >= 0 else "left "
human_word += "the RGB"
fig, axs = plt.subplots(3, 1)
axs[0].imshow(rgb_img)
axs[0].set_title("RGB img")
axs[1].imshow(spec_img)
axs[1].set_title("spec img")
axs[2].imshow(rgb_cross_area & spce_cross_area)
axs[2].set_title("cross part")
plt.suptitle(human_word)
plt.show()
print(human_word)
return delta
@staticmethod
@ -148,5 +161,5 @@ class TestMain:
if __name__ == '__main__':
testor = TestMain()
testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct',
test_rgb=True, test_spectra=True, get_delta=True)
testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct',
test_rgb=True, test_spectra=True, get_delta=False)

View File

@ -1,15 +1,16 @@
import os
import threading
from multiprocessing import Process, Queue
import time
from utils import ImgQueue as Queue
from utils import ImgQueue as ImgQueue
import numpy as np
from config import Config
from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.INFO)
level=logging.WARNING)
class Transmitter(object):
@ -25,7 +26,7 @@ class Transmitter(object):
"""
raise NotImplementedError
def set_output(self, output: Queue):
def set_output(self, output: ImgQueue):
"""
设置单个输出源
:param output:
@ -83,7 +84,7 @@ class BeforeAfterMethods:
class FifoReceiver(Transmitter):
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
super().__init__()
self._input_fifo_path = None
self._output_queue = None
@ -101,7 +102,7 @@ class FifoReceiver(Transmitter):
os.mkfifo(fifo_path, 0o777)
self._input_fifo_path = fifo_path
def set_output(self, output: Queue):
def set_output(self, output: ImgQueue):
self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'):
@ -130,7 +131,7 @@ class FifoReceiver(Transmitter):
class FifoSender(Transmitter):
def __init__(self, output_fifo_path: str, source: Queue):
def __init__(self, output_fifo_path: str, source: ImgQueue):
super().__init__()
self._input_source = None
self._output_fifo_path = None
@ -140,7 +141,7 @@ class FifoSender(Transmitter):
self._need_stop.clear()
self._running_thread = None
def set_source(self, source: Queue):
def set_source(self, source: ImgQueue):
self._input_source = source
def set_output(self, output_fifo_path: str):
@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter):
"""
用于控制命令和图像的中间件
"""
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
super().__init__()
self._rgb_queue = None
self._spec_queue = None
@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter):
self.set_output(subscribers)
self.thread_stop = threading.Event()
def set_source(self, rgb_queue: Queue, spec_queue: Queue):
def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue):
self._rgb_queue = rgb_queue
self._spec_queue = spec_queue
def set_output(self, output: typing.Dict[str, Queue]):
def set_output(self, output: typing.Dict[str, ImgQueue]):
self._subscribers = output
def start(self, name='CMD_thread'):
@ -246,7 +247,7 @@ class ImageSaver(Transmitter):
class ThreadDetector(Transmitter):
def __init__(self, input_queue: Queue, output_queue: Queue):
def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue):
super().__init__()
self._input_queue, self._output_queue = input_queue, output_queue
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
@ -256,7 +257,7 @@ class ThreadDetector(Transmitter):
self._predict_thread = None
self._thread_exit = threading.Event()
def set_source(self, img_queue: Queue):
def set_source(self, img_queue: ImgQueue):
self._input_queue = img_queue
def stop(self, *args, **kwargs):
@ -291,3 +292,51 @@ class ThreadDetector(Transmitter):
mask = self.predict(spec, rgb)
self._output_queue.safe_put(mask)
self._thread_exit.clear()
class ProcessDetector(Transmitter):
def __init__(self, input_queue: Queue, output_queue: Queue):
super().__init__()
self._input_queue, self._output_queue = input_queue, output_queue
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
pixel_model_path=Config.pixel_model_path)
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
self._predict_thread = None
self._thread_exit = threading.Event()
def set_source(self, img_queue: Queue):
self._input_queue = img_queue
def stop(self, *args, **kwargs):
self._thread_exit.set()
def start(self, name='predict_thread'):
self._predict_thread = Process(target=self._predict_server, name=name, daemon=True)
self._predict_thread.start()
def predict(self, spec: np.ndarray, rgb: np.ndarray):
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
t1 = time.time()
mask = self._spec_detector.predict(spec)
t2 = time.time()
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
# rgb识别
mask_rgb = self._rgb_detector.predict(rgb)
t3 = time.time()
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
return mask_result
def _predict_server(self):
while not self._thread_exit.is_set():
if not self._input_queue.empty():
spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb)
self._output_queue.put(mask)
self._thread_exit.clear()

View File

@ -140,6 +140,7 @@ def size_threshold(img, blk_size, threshold):
return mask
if __name__ == '__main__':
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}