Mirror of https://github.com/NanjingForestryUniversity/supermachine-tobacco.git (synced 2025-11-08 14:23:55 +00:00)
计算误差完成版 (offset-error calculation, completed version)
Commit: 1a19246dc3
Parent: 144b88543d
main_test.py (49 changed lines)
@@ -66,11 +66,6 @@ class TestMain:
         if get_delta:
             spec_cv = np.clip(spec_img[..., [21, 3, 0]], a_min=0, a_max=1) * 255
             spec_cv = spec_cv.astype(np.uint8)
-            plt.imshow(spec_cv)
-            plt.show()
-            spec_cv = spec_cv.astype(np.uint8)
-            plt.imshow(spec_cv)
-            plt.show()
             delta = self.calculate_delta(rgb_img, spec_cv)
             print(delta)
         self.merge(rgb_img=rgb_img, rgb_mask=rgb_mask,
@@ -112,25 +107,43 @@ class TestMain:
         plt.show()
         return mask_result
 
-    def calculate_delta(self, rgb_img, spec_img):
+    def calculate_delta(self, rgb_img, spec_img, search_area_size=(200, 200), eps=1):
         rgb_grey, spec_grey = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2GRAY), cv2.cvtColor(spec_img, cv2.COLOR_RGB2GRAY)
         _, rgb_bin = cv2.threshold(rgb_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
         _, spec_bin = cv2.threshold(spec_grey, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
         spec_bin = cv2.resize(spec_bin, dsize=(rgb_bin.shape[1], rgb_bin.shape[0]))
-        fig, axs = plt.subplots(2, 1)
-        axs[0].imshow(rgb_bin)
-        axs[1].imshow(spec_bin)
-        plt.show()
-        search_area = np.zeros_like(spec_bin)
-        for x in range(0, spec_bin.shape[0] // 10, 10):
-            for y in range(0, spec_bin.shape[1] // 10, 10):
-                delta_x, delta_y = x - spec_bin.shape[0] // 2, y - spec_bin.shape[1] // 2
+        search_area = np.zeros(search_area_size)
+        for x in range(0, search_area_size[0], eps):
+            for y in range(0, search_area_size[1], eps):
+                delta_x, delta_y = x - search_area_size[0] // 2, y - search_area_size[1] // 2
                 rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
                 spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
                 response_altitude = np.sum(np.sum(rgb_cross_area & spce_cross_area))
                 search_area[x, y] = response_altitude
-        delta = np.argmax(search_area)
-        print(delta)
+        delta = np.unravel_index(np.argmax(search_area), search_area.shape)
+        delta = (delta[0] - search_area_size[1]//2, delta[1] - search_area_size[1]//2)
+        delta_x, delta_y = delta
+
+        rgb_cross_area = self.get_cross_area(rgb_bin, delta_x, delta_y)
+        spce_cross_area = self.get_cross_area(spec_bin, -delta_x, -delta_y)
+
+        human_word = "SPEC is " + str(abs(delta_x)) + " pixels "
+        human_word += 'after' if delta_x >= 0 else ' before '
+        human_word += "RGB and " + str(abs(delta_y)) + " pixels "
+        human_word += "right " if delta_y >= 0 else "left "
+        human_word += "the RGB"
+
+        fig, axs = plt.subplots(3, 1)
+        axs[0].imshow(rgb_img)
+        axs[0].set_title("RGB img")
+        axs[1].imshow(spec_img)
+        axs[1].set_title("spec img")
+        axs[2].imshow(rgb_cross_area & spce_cross_area)
+        axs[2].set_title("cross part")
+        plt.suptitle(human_word)
+        plt.show()
+
+        print(human_word)
         return delta
 
     @staticmethod
@@ -148,5 +161,5 @@ class TestMain:
 
 if __name__ == '__main__':
     testor = TestMain()
-    testor.pony_run(test_path=r'E:\zhouchao\728-tobacco\correct',
-                    test_rgb=True, test_spectra=True, get_delta=True)
+    testor.pony_run(test_path=r'/Volumes/LENOVO_USB_HDD/zhouchao/728-tobacco/correct',
+                    test_rgb=True, test_spectra=True, get_delta=False)
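The core of this commit is the reworked calculate_delta: instead of scanning only the first tenth of the image in 10-pixel jumps, it now sweeps a fixed search window (search_area_size, default 200x200, step eps) of candidate (delta_x, delta_y) offsets, scores each offset by how many foreground pixels the Otsu-binarised RGB and spectral masks share after shifting, recentres the argmax to get the pixel offset of the spectral image relative to the RGB image, and then reports the result as a human-readable sentence plus a three-panel plot of the two inputs and their overlap. Below is a minimal standalone sketch of that search, assuming both binary masks already have the same shape; shift_crop is a hypothetical stand-in for the repo's TestMain.get_cross_area, whose body is not shown in this diff.

import numpy as np

def shift_crop(mask, dx, dy):
    # Hypothetical stand-in for TestMain.get_cross_area: keep the part of
    # `mask` that stays in frame after shifting it by (dx, dy).
    h, w = mask.shape
    return mask[max(dx, 0):h + min(dx, 0), max(dy, 0):w + min(dy, 0)]

def find_offset(rgb_bin, spec_bin, search_area_size=(200, 200), eps=1):
    # Brute-force alignment search: score every candidate offset by the number
    # of pixels that are foreground in both binary masks, then keep the best.
    scores = np.zeros(search_area_size)
    for x in range(0, search_area_size[0], eps):
        for y in range(0, search_area_size[1], eps):
            dx, dy = x - search_area_size[0] // 2, y - search_area_size[1] // 2
            overlap = shift_crop(rgb_bin, dx, dy) & shift_crop(spec_bin, -dx, -dy)
            scores[x, y] = np.sum(overlap)
    best_x, best_y = np.unravel_index(np.argmax(scores), scores.shape)
    return best_x - search_area_size[0] // 2, best_y - search_area_size[1] // 2

One detail worth noting in the committed version: the recentring step subtracts search_area_size[1] // 2 from both coordinates, which matches the sketch above only when the search window is square, as it is with the 200x200 default.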
transmit.py (73 changed lines)
@@ -1,15 +1,16 @@
 import os
 import threading
+from multiprocessing import Process, Queue
 import time
 
-from utils import ImgQueue as Queue
+from utils import ImgQueue as ImgQueue
 import numpy as np
 from config import Config
 from models import SpecDetector, RgbDetector
 import typing
 import logging
 logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
-                    level=logging.INFO)
+                    level=logging.WARNING)
 
 
 class Transmitter(object):
@@ -25,7 +26,7 @@ class Transmitter(object):
         """
         raise NotImplementedError
 
-    def set_output(self, output: Queue):
+    def set_output(self, output: ImgQueue):
         """
         设置单个输出源
         :param output:
@@ -83,7 +84,7 @@ class BeforeAfterMethods:
 
 
 class FifoReceiver(Transmitter):
-    def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
+    def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
         super().__init__()
         self._input_fifo_path = None
         self._output_queue = None
@@ -101,7 +102,7 @@ class FifoReceiver(Transmitter):
             os.mkfifo(fifo_path, 0o777)
         self._input_fifo_path = fifo_path
 
-    def set_output(self, output: Queue):
+    def set_output(self, output: ImgQueue):
         self._output_queue = output
 
     def start(self, post_process_func=None, name='fifo_receiver'):
@@ -130,7 +131,7 @@ class FifoReceiver(Transmitter):
 
 
 class FifoSender(Transmitter):
-    def __init__(self, output_fifo_path: str, source: Queue):
+    def __init__(self, output_fifo_path: str, source: ImgQueue):
         super().__init__()
         self._input_source = None
         self._output_fifo_path = None
@@ -140,7 +141,7 @@ class FifoSender(Transmitter):
         self._need_stop.clear()
         self._running_thread = None
 
-    def set_source(self, source: Queue):
+    def set_source(self, source: ImgQueue):
         self._input_source = source
 
     def set_output(self, output_fifo_path: str):
@@ -184,7 +185,7 @@ class CmdImgSplitMidware(Transmitter):
     """
     用于控制命令和图像的中间件
     """
-    def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
+    def __init__(self, subscribers: typing.Dict[str, ImgQueue], rgb_queue: ImgQueue, spec_queue: ImgQueue):
         super().__init__()
         self._rgb_queue = None
         self._spec_queue = None
@@ -194,11 +195,11 @@ class CmdImgSplitMidware(Transmitter):
         self.set_output(subscribers)
         self.thread_stop = threading.Event()
 
-    def set_source(self, rgb_queue: Queue, spec_queue: Queue):
+    def set_source(self, rgb_queue: ImgQueue, spec_queue: ImgQueue):
         self._rgb_queue = rgb_queue
         self._spec_queue = spec_queue
 
-    def set_output(self, output: typing.Dict[str, Queue]):
+    def set_output(self, output: typing.Dict[str, ImgQueue]):
         self._subscribers = output
 
     def start(self, name='CMD_thread'):
@@ -246,7 +247,7 @@ class ImageSaver(Transmitter):
 
 
 class ThreadDetector(Transmitter):
-    def __init__(self, input_queue: Queue, output_queue: Queue):
+    def __init__(self, input_queue: ImgQueue, output_queue: ImgQueue):
         super().__init__()
         self._input_queue, self._output_queue = input_queue, output_queue
         self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
@@ -256,7 +257,7 @@ class ThreadDetector(Transmitter):
         self._predict_thread = None
         self._thread_exit = threading.Event()
 
-    def set_source(self, img_queue: Queue):
+    def set_source(self, img_queue: ImgQueue):
         self._input_queue = img_queue
 
     def stop(self, *args, **kwargs):
@@ -291,3 +292,51 @@ class ThreadDetector(Transmitter):
             mask = self.predict(spec, rgb)
             self._output_queue.safe_put(mask)
         self._thread_exit.clear()
+
+
+class ProcessDetector(Transmitter):
+    def __init__(self, input_queue: Queue, output_queue: Queue):
+        super().__init__()
+        self._input_queue, self._output_queue = input_queue, output_queue
+        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
+                                           pixel_model_path=Config.pixel_model_path)
+        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
+                                         background_model_path=Config.rgb_background_model_path)
+        self._predict_thread = None
+        self._thread_exit = threading.Event()
+
+    def set_source(self, img_queue: Queue):
+        self._input_queue = img_queue
+
+    def stop(self, *args, **kwargs):
+        self._thread_exit.set()
+
+    def start(self, name='predict_thread'):
+        self._predict_thread = Process(target=self._predict_server, name=name, daemon=True)
+        self._predict_thread.start()
+
+    def predict(self, spec: np.ndarray, rgb: np.ndarray):
+        logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
+        t1 = time.time()
+        mask = self._spec_detector.predict(spec)
+        t2 = time.time()
+        logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
+        # rgb识别
+        mask_rgb = self._rgb_detector.predict(rgb)
+        t3 = time.time()
+        logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
+        # 结果合并
+        mask_result = (mask | mask_rgb).astype(np.uint8)
+        mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
+        t4 = time.time()
+        logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
+        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
+        return mask_result
+
+    def _predict_server(self):
+        while not self._thread_exit.is_set():
+            if not self._input_queue.empty():
+                spec, rgb = self._input_queue.get()
+                mask = self.predict(spec, rgb)
+                self._output_queue.put(mask)
+        self._thread_exit.clear()
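Besides the Queue-to-ImgQueue type-hint cleanup and the logging level change (INFO to WARNING), this file gains ProcessDetector, which keeps the ThreadDetector interface but runs the predict loop in a multiprocessing.Process, so the spectral and RGB detectors no longer compete with the FIFO I/O threads for the GIL; that is why from multiprocessing import Process, Queue now appears in the imports. A minimal, self-contained sketch of the same worker-process pattern follows; the dummy mask stands in for the project's SpecDetector/RgbDetector models and the block merge, and the array shapes are illustrative only.

import time
import numpy as np
from multiprocessing import Process, Queue, Event

def predict_server(input_q, output_q, stop_event):
    # Worker loop run in the child process: pull (spec, rgb) frames, compute a
    # mask (placeholder computation here), and push the result back to the parent.
    while not stop_event.is_set():
        if input_q.empty():
            time.sleep(0.001)
            continue
        spec, rgb = input_q.get()
        mask = (rgb.mean(axis=-1) > 127).astype(np.uint8)  # placeholder for the real detectors
        output_q.put(mask)

if __name__ == '__main__':
    input_q, output_q, stop_event = Queue(), Queue(), Event()
    worker = Process(target=predict_server, args=(input_q, output_q, stop_event), daemon=True)
    worker.start()
    spec = np.random.rand(256, 256, 22).astype(np.float32)
    rgb = np.random.randint(0, 255, (1024, 1024, 3), dtype=np.uint8)
    input_q.put((spec, rgb))
    print(output_q.get().shape)   # blocks until the worker has produced a mask
    stop_event.set()
    worker.join(timeout=1)

One thing to watch when adopting this pattern: signalling across processes needs multiprocessing primitives, so a multiprocessing.Event (as in the sketch) is what reliably reaches the child once it has been forked or spawned, and multiprocessing.Queue objects are what carry the frames across the process boundary.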
utils.py (1 changed line)
@@ -140,6 +140,7 @@ def size_threshold(img, blk_size, threshold):
     return mask
 
 
+
 if __name__ == '__main__':
     color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
                   (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}