mirror of https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Completed version of the offset (error) calculation
This commit is contained in:
parent 144b88543d
commit 1a19246dc3
main_test.py (49 lines changed)
@@ -66,11 +66,6 @@ class TestMain:
        if get_delta:
            spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
            spec_cv = spec_cv.astype(np.uint8)
            plt.imshow(spec_cv)
            plt.show()
            spec_cv = spec_cv.astype(np.uint8)
            plt.imshow(spec_cv)
            plt.show()
            delta = self.calculate_delta(rgb_img, spec_cv)
            print(delta)
            self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
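A note on the spec_cv lines above: they build a displayable pseudo-RGB preview of the spectral cube by picking three bands, clipping to [0, 1] and scaling to 8-bit. A minimal standalone sketch of the same idea, assuming spec_img is a float reflectance cube with values roughly in [0, 1] and that bands 21, 3 and 0 stand in for R, G and B (both assumptions; the actual band meaning depends on the camera):

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical spectral cube: height x width x bands, float values in [0, 1].
spec_img = np.random.rand(256, 256, 22).astype(np.float32)

# Pick three bands as a pseudo-RGB view, clip to [0, 1] and scale to 8-bit.
spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
spec_cv = spec_cv.astype(np.uint8)

plt.imshow(spec_cv)
plt.title('pseudo-RGB preview of the spectral cube')
plt.show()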
@@ -112,25 +107,43 @@ class TestMain:
        plt.show()
        return mask_result

    def calculate_delta(self, rgb_img, spec_img):
    def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1):
        rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
        _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))
        fig, axs = plt.subplots(2, 1)
        axs[0].imshow(rgb_bin)
        axs[1].imshow(spec_bin)
        plt.show()
        search_area = np.zeros_like(spec_bin)
        for x in range(0, spec_bin.shape[0] // 10, 10):
            for y in range(0, spec_bin.shape[1] // 10, 10):
                delta_x, delta_y = x - spec_bin.shape[0] // 2, y - spec_bin.shape[1] // 2
        search_area = np.zeros(search_area_size)
        for x in range(0, search_area_size[0], eps):
            for y in range(0, search_area_size[1], eps):
                delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2
                rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
                spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
                response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area))
                search_area[x, y] = response_altitude
        delta = np.argmax(search_area)
        print(delta)
        delta = np.unravel_index(np.argmax(search_area), search_area.shape)
        delta = (delta[0] - search_area_size[1]//2, delta[1] - search_area_size[1]//2)
        delta_x, delta_y = delta

        rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
        spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)

        human_word = "SPEC is " + str(abs(delta_x)) + " pixels "
        human_word += 'after' if delta_x >= 0 else ' before '
        human_word += "RGB and " + str(abs(delta_y)) + " pixels "
        human_word += "right " if delta_y >= 0 else "left "
        human_word += "the RGB"

        fig, axs = plt.subplots(3, 1)
        axs[0].imshow(rgb_img)
        axs[0].set_title("RGB img")
        axs[1].imshow(spec_img)
        axs[1].set_title("spec img")
        axs[2].imshow(rgb_cross_area & spce_cross_area)
        axs[2].set_title("cross part")
        plt.suptitle(human_word)
        plt.show()

        print(human_word)
        return delta

    @staticmethod
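get_cross_area is called in the search loop above but its body sits outside this hunk (only the @staticmethod decorator of the next method is visible). The sketch below is a hypothetical reconstruction of what such a helper could look like: for a candidate shift (delta_x, delta_y) it crops the part of the binary image that overlaps the other image, so the two crops returned for opposite shifts have identical shapes and can be AND-ed and summed as an overlap score. The cropping behaviour and the test arrays are assumptions, not the repository's actual implementation.

import numpy as np

def get_cross_area(img, delta_x, delta_y):
    # Hypothetical reconstruction: return the region of `img` that overlaps a
    # copy of the other image shifted by (delta_x, delta_y); calling it with
    # the opposite shift on the second image gives a crop of the same shape.
    h, w = img.shape[:2]
    rows = slice(delta_x, h) if delta_x >= 0 else slice(0, h + delta_x)
    cols = slice(delta_y, w) if delta_y >= 0 else slice(0, w + delta_y)
    return img[rows, cols]

# Overlap score for one candidate shift, mirroring the inner loop above.
a = (np.random.rand(120, 160) > 0.5).astype(np.uint8) * 255
b = (np.random.rand(120, 160) > 0.5).astype(np.uint8) * 255
score = np.sum(get_cross_area(a, 7, -3) & get_cross_area(b, -7, 3))
print(score)

The new code then takes np.unravel_index of the argmax to recover the (x, y) cell with the highest overlap and converts it back into a signed offset by subtracting half the search window (note both components subtract search_area_size[1] // 2, which only matters if the window is not square).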
@@ -148,5 +161,5 @@ class TestMain:

if __name__ == '__main__':
    testor = TestMain()
    testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct',
                    test_rgb=True, test_spectra=True, get_delta=True)
    testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct',
                    test_rgb=True, test_spectra=True, get_delta=False)
transmit.py (73 lines changed)
@@ -1,15 +1,16 @@
import os
import threading
from multiprocessing import Process, Queue
import time

from utils import ImgQueue as Queue
from utils import ImgQueue as ImgQueue
import numpy as np
from config import Config
from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
                    level=logging.INFO)
                    level=logging.WARNING)


class Transmitter(object):
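Side effect of the logging change above: with the root level raised from INFO to WARNING, the per-frame logging.info timing messages emitted in the detectors' predict methods are suppressed, while warnings and errors still reach the console. A minimal standalone illustration (not project code):

import logging

logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
                    level=logging.WARNING)

logging.info('Detector finish spec predict within 12.34ms')  # filtered out at WARNING
logging.warning('frame queue is backing up')                 # still printed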
@@ -25,7 +26,7 @@ class Transmitter(object):
        """
        raise NotImplementedError

    def set_output(self, output: Queue):
    def set_output(self, output: ImgQueue):
        """
        Set a single output source.
        :param output:
@@ -83,7 +84,7 @@ class BeforeAfterMethods:


class FifoReceiver(Transmitter):
    def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
    def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
        super().__init__()
        self._input_fifo_path = None
        self._output_queue = None
@@ -101,7 +102,7 @@ class FifoReceiver(Transmitter):
            os.mkfifo(fifo_path, 0o777)
        self._input_fifo_path = fifo_path

    def set_output(self, output: Queue):
    def set_output(self, output: ImgQueue):
        self._output_queue = output

    def start(self, post_process_func=None, name='fifo_receiver'):
@@ -130,7 +131,7 @@ class FifoReceiver(Transmitter):


class FifoSender(Transmitter):
    def __init__(self, output_fifo_path: str, source: Queue):
    def __init__(self, output_fifo_path: str, source: ImgQueue):
        super().__init__()
        self._input_source = None
        self._output_fifo_path = None
@@ -140,7 +141,7 @@ class FifoSender(Transmitter):
        self._need_stop.clear()
        self._running_thread = None

    def set_source(self, source: Queue):
    def set_source(self, source: ImgQueue):
        self._input_source = source

    def set_output(self, output_fifo_path: str):
@@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter):
    """
    Middleware for splitting control commands and images.
    """
    def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
    def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
        super().__init__()
        self._rgb_queue = None
        self._spec_queue = None
@@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter):
        self.set_output(subscribers)
        self.thread_stop = threading.Event()

    def set_source(self, rgb_queue: Queue, spec_queue: Queue):
    def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue):
        self._rgb_queue = rgb_queue
        self._spec_queue = spec_queue

    def set_output(self, output: typing.Dict[str, Queue]):
    def set_output(self, output: typing.Dict[str, ImgQueue]):
        self._subscribers = output

    def start(self, name='CMD_thread'):
@@ -246,7 +247,7 @@ class ImageSaver(Transmitter):


class ThreadDetector(Transmitter):
    def __init__(self, input_queue: Queue, output_queue: Queue):
    def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue):
        super().__init__()
        self._input_queue, self._output_queue = input_queue, output_queue
        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
@@ -256,7 +257,7 @@ class ThreadDetector(Transmitter):
        self._predict_thread = None
        self._thread_exit = threading.Event()

    def set_source(self, img_queue: Queue):
    def set_source(self, img_queue: ImgQueue):
        self._input_queue = img_queue

    def stop(self, *args, **kwargs):
@@ -291,3 +292,51 @@ class ThreadDetector(Transmitter):
                mask = self.predict(spec, rgb)
                self._output_queue.safe_put(mask)
        self._thread_exit.clear()


class ProcessDetector(Transmitter):
    def __init__(self, input_queue: Queue, output_queue: Queue):
        super().__init__()
        self._input_queue, self._output_queue = input_queue, output_queue
        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
                                           pixel_model_path=Config.pixel_model_path)
        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
                                         background_model_path=Config.rgb_background_model_path)
        self._predict_thread = None
        self._thread_exit = threading.Event()

    def set_source(self, img_queue: Queue):
        self._input_queue = img_queue

    def stop(self, *args, **kwargs):
        self._thread_exit.set()

    def start(self, name='predict_thread'):
        self._predict_thread = Process(target=self._predict_server, name=name, daemon=True)
        self._predict_thread.start()

    def predict(self, spec: np.ndarray, rgb: np.ndarray):
        logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
        t1 = time.time()
        mask = self._spec_detector.predict(spec)
        t2 = time.time()
        logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
        # RGB detection
        mask_rgb = self._rgb_detector.predict(rgb)
        t3 = time.time()
        logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
        # merge the results
        mask_result = (mask | mask_rgb).astype(np.uint8)
        mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
        t4 = time.time()
        logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
        return mask_result

    def _predict_server(self):
        while not self._thread_exit.is_set():
            if not self._input_queue.empty():
                spec, rgb = self._input_queue.get()
                mask = self.predict(spec, rgb)
                self._output_queue.put(mask)
        self._thread_exit.clear()
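ProcessDetector mirrors ThreadDetector but runs _predict_server in a separate multiprocessing.Process and publishes results with a plain Queue.put instead of safe_put. Below is a stripped-down, self-contained sketch of the same consumer pattern; it swaps the repository's SpecDetector/RgbDetector for a trivial thresholding stand-in (predict_stub), assumes a small integer block size in place of Config.blk_size, and uses a None sentinel for shutdown instead of the _thread_exit event, so it illustrates the structure rather than the real models.

import time
from multiprocessing import Process, Queue

import numpy as np

BLK_SIZE = 4  # stand-in for Config.blk_size


def predict_stub(spec, rgb):
    # Stand-in for the spec/RGB detectors: build block-level masks, OR them,
    # then expand each block back to pixel resolution with repeat(), as in
    # ProcessDetector.predict above.
    h, w = spec.shape[:2]                                    # block grid size
    mask_spec = (spec.mean(axis=-1) > 0.5).astype(np.uint8)
    rgb_blocks = rgb.mean(axis=-1).reshape(h, BLK_SIZE, w, BLK_SIZE).mean(axis=(1, 3))
    mask_rgb = (rgb_blocks > 127).astype(np.uint8)
    mask = (mask_spec | mask_rgb).astype(np.uint8)
    return mask.repeat(BLK_SIZE, axis=0).repeat(BLK_SIZE, axis=1)


def predict_server(input_queue, output_queue):
    # Same loop shape as ProcessDetector._predict_server: poll, predict, publish.
    while True:
        spec, rgb = input_queue.get()
        if spec is None:                                     # shutdown sentinel
            break
        t0 = time.time()
        output_queue.put(predict_stub(spec, rgb))
        print(f'predict took {(time.time() - t0) * 1000:.2f}ms')


if __name__ == '__main__':
    in_q, out_q = Queue(), Queue()
    worker = Process(target=predict_server, args=(in_q, out_q), daemon=True)
    worker.start()

    spec = np.random.rand(64, 64, 22).astype(np.float32)
    rgb = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
    in_q.put((spec, rgb))
    print(out_q.get().shape)    # (256, 256): the 64x64 block mask upsampled

    in_q.put((None, None))
    worker.join()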