mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Merge branch 'master' of https://github.com/FEIJINTI/tobacco_color
This commit is contained in:
commit
cdbe341591
@ -20,7 +20,7 @@ class Config:
|
|||||||
# 光谱模型参数
|
# 光谱模型参数
|
||||||
blk_size = 4
|
blk_size = 4
|
||||||
pixel_model_path = r"./models/dt.p"
|
pixel_model_path = r"./models/dt.p"
|
||||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
|
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
|
||||||
spec_size_threshold = 3
|
spec_size_threshold = 3
|
||||||
|
|
||||||
# rgb模型参数
|
# rgb模型参数
|
||||||
|
|||||||
51
efficient_ui.py
Normal file
51
efficient_ui.py
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
import transmit
|
||||||
|
from config import Config
|
||||||
|
|
||||||
|
from utils import ImgQueue as Queue
|
||||||
|
|
||||||
|
|
||||||
|
class EfficientUI(object):
|
||||||
|
def __init__(self):
|
||||||
|
# 相关参数
|
||||||
|
img_fifo_path = "/tmp/dkimg.fifo"
|
||||||
|
mask_fifo_path = "/tmp/dkmask.fifo"
|
||||||
|
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
||||||
|
# 创建队列用于链接各个线程
|
||||||
|
rgb_img_queue, spec_img_queue = Queue(), Queue()
|
||||||
|
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
|
||||||
|
mask_queue = Queue()
|
||||||
|
# 两个接收者,接收光谱和rgb图像
|
||||||
|
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
||||||
|
rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
||||||
|
spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
|
||||||
|
rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
|
||||||
|
# 指令执行与图像流向控制
|
||||||
|
subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
|
||||||
|
cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
|
||||||
|
subscribers=subscribers)
|
||||||
|
# 探测器
|
||||||
|
detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
|
||||||
|
# 发送
|
||||||
|
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
|
||||||
|
# 启动所有线程
|
||||||
|
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
|
||||||
|
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
|
||||||
|
cmd_img_controller.start(name='control_thread')
|
||||||
|
detector.start(name='detector_thread')
|
||||||
|
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
# 启动图形化
|
||||||
|
while True:
|
||||||
|
cv2.imshow('image_show', mat=np.ones((256, 1024)))
|
||||||
|
key_code = cv2.waitKey(30)
|
||||||
|
if key_code == ord("s"):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
app = EfficientUI()
|
||||||
|
app.start()
|
||||||
142
main.py
142
main.py
@ -1,16 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io
|
from matplotlib import pyplot as plt
|
||||||
|
|
||||||
|
import models
|
||||||
|
import transmit
|
||||||
|
|
||||||
from config import Config
|
from config import Config
|
||||||
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
|
from models import RgbDetector, SpecDetector
|
||||||
import cv2
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
SAVE_IMG, SAVE_NUM = False, 30
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
||||||
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||||
@ -23,8 +25,6 @@ def main():
|
|||||||
os.mkfifo(mask_fifo_path, 0o777)
|
os.mkfifo(mask_fifo_path, 0o777)
|
||||||
if not os.access(rgb_fifo_path, os.F_OK):
|
if not os.access(rgb_fifo_path, os.F_OK):
|
||||||
os.mkfifo(rgb_fifo_path, 0o777)
|
os.mkfifo(rgb_fifo_path, 0o777)
|
||||||
if SAVE_IMG:
|
|
||||||
img_list = []
|
|
||||||
while True:
|
while True:
|
||||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
||||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
||||||
@ -55,19 +55,12 @@ def main():
|
|||||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
|
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
|
||||||
.transpose(0, 2, 1)
|
.transpose(0, 2, 1)
|
||||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||||
if SAVE_IMG:
|
|
||||||
SAVE_NUM -= 1
|
|
||||||
img_list.append((rgb_data, img_data))
|
|
||||||
if SAVE_NUM <= 0:
|
|
||||||
break
|
|
||||||
# 光谱识别
|
# 光谱识别
|
||||||
mask = spec_detector.predict(img_data)
|
mask = spec_detector.predict(img_data)
|
||||||
# rgb识别
|
# rgb识别
|
||||||
mask_rgb = rgb_detector.predict(rgb_data)
|
mask_rgb = rgb_detector.predict(rgb_data)
|
||||||
|
|
||||||
# 结果合并
|
# 结果合并
|
||||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||||
|
|
||||||
# mask_result = mask_rgb.astype(np.uint8)
|
# mask_result = mask_rgb.astype(np.uint8)
|
||||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||||
t2 = time.time()
|
t2 = time.time()
|
||||||
@ -79,100 +72,37 @@ def main():
|
|||||||
os.close(fd_mask)
|
os.close(fd_mask)
|
||||||
t3 = time.time()
|
t3 = time.time()
|
||||||
print(f'total time is:{t3 - t1}')
|
print(f'total time is:{t3 - t1}')
|
||||||
for i, img in enumerate(img_list):
|
|
||||||
print(f"writing img {i}...")
|
|
||||||
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
|
|
||||||
np.save(f'./{i}.npy', img[1])
|
|
||||||
i += 1
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def save_main():
|
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
|
||||||
threshold = Config.spec_size_threshold
|
if os.path.isdir(buffer_path):
|
||||||
rgb_threshold = Config.rgb_size_threshold
|
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
|
||||||
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
|
else:
|
||||||
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
|
buffer_names = [buffer_path, ]
|
||||||
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
|
for buffer_name in buffer_names:
|
||||||
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
|
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
|
||||||
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
|
data = f.read()
|
||||||
if not os.access(img_fifo_path, os.F_OK):
|
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
|
||||||
os.mkfifo(img_fifo_path, 0o777)
|
.transpose(0, 2, 1)
|
||||||
if not os.access(mask_fifo_path, os.F_OK):
|
if selected_bands is not None:
|
||||||
os.mkfifo(mask_fifo_path, 0o777)
|
img = img[..., selected_bands]
|
||||||
if not os.access(rgb_fifo_path, os.F_OK):
|
if img.shape[0] == 1:
|
||||||
os.mkfifo(rgb_fifo_path, 0o777)
|
img = img[0, ...]
|
||||||
img_list = []
|
if not no_mask:
|
||||||
idx = 0
|
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
|
||||||
while idx <= 30:
|
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
|
||||||
idx += 1
|
data = f.read()
|
||||||
fd_img = os.open(img_fifo_path, os.O_RDONLY)
|
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
|
||||||
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
|
|
||||||
data = os.read(fd_img, total_len)
|
|
||||||
|
|
||||||
# 读取(开启一个管道)
|
|
||||||
if len(data) < 3:
|
|
||||||
threshold = int(float(data))
|
|
||||||
print("[INFO] Get threshold: ", threshold)
|
|
||||||
continue
|
|
||||||
else:
|
else:
|
||||||
data_total = data
|
mask_name = "no mask"
|
||||||
rgb_data = os.read(fd_rgb, total_rgb)
|
mask = np.zeros_like(img)
|
||||||
if len(rgb_data) < 3:
|
# mask = cv2.resize(mask, (1024, 256))
|
||||||
rgb_threshold = int(float(rgb_data))
|
fig, axs = plt.subplots(2, 1)
|
||||||
print(rgb_threshold)
|
axs[0].matshow(img)
|
||||||
continue
|
axs[0].set_title(buffer_name)
|
||||||
else:
|
axs[1].imshow(mask)
|
||||||
rgb_data_total = rgb_data
|
axs[1].set_title(mask_name)
|
||||||
os.close(fd_img)
|
plt.show()
|
||||||
os.close(fd_rgb)
|
|
||||||
|
|
||||||
# 识别
|
|
||||||
t1 = time.time()
|
|
||||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
|
|
||||||
transpose(0, 2, 1)
|
|
||||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
|
||||||
img_list.append((rgb_data.copy(), img_data.copy()))
|
|
||||||
|
|
||||||
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
|
||||||
blk_predict_result = manual_tree.blk_predict(data=img_data)
|
|
||||||
rgb_data = tobacco_detector.pretreatment(rgb_data)
|
|
||||||
# print(rgb_data.shape)
|
|
||||||
rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
|
|
||||||
threshold_high=Config.threshold_high) |
|
|
||||||
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
|
|
||||||
threshold_low=Config.threshold_low,
|
|
||||||
threshold_high=Config.threshold_high)))
|
|
||||||
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
|
||||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
|
||||||
.sum(axis=1)
|
|
||||||
mask_rgb[mask_rgb <= rgb_threshold] = 0
|
|
||||||
mask_rgb[mask_rgb > rgb_threshold] = 1
|
|
||||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
|
||||||
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
|
|
||||||
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
|
|
||||||
.sum(axis=1)
|
|
||||||
mask[mask <= threshold] = 0
|
|
||||||
mask[mask > threshold] = 1
|
|
||||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
|
||||||
# mask_result = mask_rgb
|
|
||||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
|
||||||
t2 = time.time()
|
|
||||||
print(f'rgb len = {len(rgb_data)}')
|
|
||||||
|
|
||||||
# 写出
|
|
||||||
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
|
|
||||||
os.write(fd_mask, mask_result.tobytes())
|
|
||||||
os.close(fd_mask)
|
|
||||||
t3 = time.time()
|
|
||||||
print(f'total time is:{t3 - t1}')
|
|
||||||
i = 0
|
|
||||||
print("Stop Serving")
|
|
||||||
for img in img_list:
|
|
||||||
print(f"writing img {i}...")
|
|
||||||
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
|
|
||||||
np.save(f'./{i}.npy', img[1])
|
|
||||||
i += 1
|
|
||||||
print("save success")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -183,3 +113,5 @@ if __name__ == '__main__':
|
|||||||
|
|
||||||
# 主函数
|
# 主函数
|
||||||
main()
|
main()
|
||||||
|
# read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
|
||||||
|
# selected_bands=[380, 300, 200])
|
||||||
|
|||||||
@ -5,10 +5,12 @@
|
|||||||
# @Software:PyCharm、
|
# @Software:PyCharm、
|
||||||
import datetime
|
import datetime
|
||||||
import pickle
|
import pickle
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import scipy.io
|
import scipy.io
|
||||||
|
import threading
|
||||||
from scipy.ndimage import binary_dilation
|
from scipy.ndimage import binary_dilation
|
||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
@ -17,6 +19,7 @@ from sklearn.model_selection import train_test_split
|
|||||||
from config import Config
|
from config import Config
|
||||||
from utils import lab_scatter, read_labeled_img, size_threshold
|
from utils import lab_scatter, read_labeled_img, size_threshold
|
||||||
|
|
||||||
|
|
||||||
deploy = True
|
deploy = True
|
||||||
if not deploy:
|
if not deploy:
|
||||||
print("Training env")
|
print("Training env")
|
||||||
@ -415,7 +418,7 @@ class SpecDetector(Detector):
|
|||||||
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
|
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
|
||||||
yellow_things[yellow_things] = tobacco
|
yellow_things[yellow_things] = tobacco
|
||||||
yellow_things = yellow_things + 0
|
yellow_things = yellow_things + 0
|
||||||
# yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
yellow_things = binary_dilation(yellow_things, iterations=iteration)
|
||||||
yellow_things = yellow_things + 0
|
yellow_things = yellow_things + 0
|
||||||
yellow_things[yellow_things == 1] = 2
|
yellow_things[yellow_things == 1] = 2
|
||||||
|
|
||||||
|
|||||||
236
transmit.py
236
transmit.py
@ -1,11 +1,18 @@
|
|||||||
import os
|
import os
|
||||||
import threading
|
import threading
|
||||||
from queue import Queue
|
import time
|
||||||
|
|
||||||
|
from utils import ImgQueue as Queue
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from config import Config
|
from config import Config
|
||||||
|
from models import SpecDetector, RgbDetector
|
||||||
|
import typing
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||||
|
level=logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
class Receiver(object):
|
class Transmitter(object):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.output = None
|
self.output = None
|
||||||
|
|
||||||
@ -45,17 +52,49 @@ class Receiver(object):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
class FifoReceiver(Receiver):
|
class BeforeAfterMethods:
|
||||||
def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
|
@classmethod
|
||||||
|
def mask_preprocess(cls, mask: np.ndarray):
|
||||||
|
logging.info(f"Send mask with size {mask.shape}")
|
||||||
|
return mask.tobytes()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def spec_data_post_process(cls, data):
|
||||||
|
if len(data) < 3:
|
||||||
|
threshold = int(float(data))
|
||||||
|
logging.info(f"Get Spec threshold: {threshold}")
|
||||||
|
return threshold
|
||||||
|
else:
|
||||||
|
spec_img = np.frombuffer(data, dtype=np.float32).\
|
||||||
|
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
||||||
|
logging.info(f"Get SPEC image with size {spec_img.shape}")
|
||||||
|
return spec_img
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def rgb_data_post_process(cls, data):
|
||||||
|
if len(data) < 3:
|
||||||
|
threshold = int(float(data))
|
||||||
|
logging.info(f"Get RGB threshold: {threshold}")
|
||||||
|
return threshold
|
||||||
|
else:
|
||||||
|
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||||
|
logging.info(f"Get RGB img with size {rgb_img.shape}")
|
||||||
|
return rgb_img
|
||||||
|
|
||||||
|
|
||||||
|
class FifoReceiver(Transmitter):
|
||||||
|
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._input_fifo_path = None
|
self._input_fifo_path = None
|
||||||
self._output_queue = None
|
self._output_queue = None
|
||||||
|
self._msg_queue = msg_queue
|
||||||
self._max_len = read_max_num
|
self._max_len = read_max_num
|
||||||
|
|
||||||
self.set_source(fifo_path)
|
self.set_source(fifo_path)
|
||||||
self.set_output(output)
|
self.set_output(output)
|
||||||
self._need_stop = threading.Event()
|
self._need_stop = threading.Event()
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
self._running_thread = None
|
||||||
|
|
||||||
def set_source(self, fifo_path: str):
|
def set_source(self, fifo_path: str):
|
||||||
if not os.access(fifo_path, os.F_OK):
|
if not os.access(fifo_path, os.F_OK):
|
||||||
@ -66,9 +105,9 @@ class FifoReceiver(Receiver):
|
|||||||
self._output_queue = output
|
self._output_queue = output
|
||||||
|
|
||||||
def start(self, post_process_func=None, name='fifo_receiver'):
|
def start(self, post_process_func=None, name='fifo_receiver'):
|
||||||
x = threading.Thread(target=self._receive_thread_func,
|
self._running_thread = threading.Thread(target=self._receive_thread_func,
|
||||||
name=name, args=(post_process_func, ))
|
name=name, args=(post_process_func, ))
|
||||||
x.start()
|
self._running_thread.start()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
self._need_stop.set()
|
self._need_stop.set()
|
||||||
@ -85,27 +124,170 @@ class FifoReceiver(Receiver):
|
|||||||
data = os.read(input_fifo, self._max_len)
|
data = os.read(input_fifo, self._max_len)
|
||||||
if post_process_func is not None:
|
if post_process_func is not None:
|
||||||
data = post_process_func(data)
|
data = post_process_func(data)
|
||||||
self._output_queue.put(data)
|
self._output_queue.safe_put(data)
|
||||||
os.close(input_fifo)
|
os.close(input_fifo)
|
||||||
self._need_stop.clear()
|
self._need_stop.clear()
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def spec_data_post_process(data):
|
|
||||||
if len(data) < 3:
|
|
||||||
threshold = int(float(data))
|
|
||||||
print("[INFO] Get Spec threshold: ", threshold)
|
|
||||||
return threshold
|
|
||||||
else:
|
|
||||||
spec_img = np.frombuffer(data, dtype=np.float32).\
|
|
||||||
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
|
|
||||||
return spec_img
|
|
||||||
|
|
||||||
@staticmethod
|
class FifoSender(Transmitter):
|
||||||
def rgb_data_post_process(data):
|
def __init__(self, output_fifo_path: str, source: Queue):
|
||||||
if len(data) < 3:
|
super().__init__()
|
||||||
threshold = int(float(data))
|
self._input_source = None
|
||||||
print("[INFO] Get RGB threshold: ", threshold)
|
self._output_fifo_path = None
|
||||||
return threshold
|
self.set_source(source)
|
||||||
else:
|
self.set_output(output_fifo_path)
|
||||||
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
self._need_stop = threading.Event()
|
||||||
return rgb_img
|
self._need_stop.clear()
|
||||||
|
self._running_thread = None
|
||||||
|
|
||||||
|
def set_source(self, source: Queue):
|
||||||
|
self._input_source = source
|
||||||
|
|
||||||
|
def set_output(self, output_fifo_path: str):
|
||||||
|
if not os.access(output_fifo_path, os.F_OK):
|
||||||
|
os.mkfifo(output_fifo_path, 0o777)
|
||||||
|
self._output_fifo_path = output_fifo_path
|
||||||
|
|
||||||
|
def start(self, pre_process=None, name='fifo_receiver'):
|
||||||
|
self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
|
||||||
|
args=(pre_process, ))
|
||||||
|
self._running_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self._need_stop.set()
|
||||||
|
|
||||||
|
def _send_thread_func(self, pre_process=None):
|
||||||
|
"""
|
||||||
|
接收线程
|
||||||
|
|
||||||
|
:param pre_process:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
while not self._need_stop.is_set():
|
||||||
|
if self._input_source.empty():
|
||||||
|
continue
|
||||||
|
data = self._input_source.get()
|
||||||
|
if pre_process is not None:
|
||||||
|
data = pre_process(data)
|
||||||
|
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
|
||||||
|
os.write(output_fifo, data)
|
||||||
|
os.close(output_fifo)
|
||||||
|
self._need_stop.clear()
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
self.stop()
|
||||||
|
if self._running_thread is not None:
|
||||||
|
self._running_thread.join()
|
||||||
|
|
||||||
|
|
||||||
|
class CmdImgSplitMidware(Transmitter):
|
||||||
|
"""
|
||||||
|
用于控制命令和图像的中间件
|
||||||
|
"""
|
||||||
|
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
|
||||||
|
super().__init__()
|
||||||
|
self._rgb_queue = None
|
||||||
|
self._spec_queue = None
|
||||||
|
self._subscribers = None
|
||||||
|
self._server_thread = None
|
||||||
|
self.set_source(rgb_queue, spec_queue)
|
||||||
|
self.set_output(subscribers)
|
||||||
|
self.thread_stop = threading.Event()
|
||||||
|
|
||||||
|
def set_source(self, rgb_queue: Queue, spec_queue: Queue):
|
||||||
|
self._rgb_queue = rgb_queue
|
||||||
|
self._spec_queue = spec_queue
|
||||||
|
|
||||||
|
def set_output(self, output: typing.Dict[str, Queue]):
|
||||||
|
self._subscribers = output
|
||||||
|
|
||||||
|
def start(self, name='CMD_thread'):
|
||||||
|
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
|
||||||
|
self._server_thread.start()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
self.thread_stop.set()
|
||||||
|
|
||||||
|
def _cmd_control_service(self):
|
||||||
|
while not self.thread_stop.is_set():
|
||||||
|
# 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
|
||||||
|
if self._rgb_queue.empty() or self._spec_queue.empty():
|
||||||
|
continue
|
||||||
|
rgb_data = self._rgb_queue.get()
|
||||||
|
spec_data = self._spec_queue.get()
|
||||||
|
if isinstance(rgb_data, int) and isinstance(spec_data, int):
|
||||||
|
# 看是不是命令需要执行如果是命令,就执行
|
||||||
|
Config.rgb_size_threshold = rgb_data
|
||||||
|
Config.spec_size_threshold = spec_data
|
||||||
|
continue
|
||||||
|
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||||
|
# 如果是图片,交给预测的人
|
||||||
|
for name, subscriber in self._subscribers.items():
|
||||||
|
item = (spec_data, rgb_data)
|
||||||
|
subscriber.safe_put(item)
|
||||||
|
else:
|
||||||
|
# 否则程序出现毁灭性问题,立刻崩
|
||||||
|
raise Exception("两个相机传回的数据没有对上")
|
||||||
|
self.thread_stop.clear()
|
||||||
|
|
||||||
|
|
||||||
|
class ImageSaver(Transmitter):
|
||||||
|
"""
|
||||||
|
进行图片存储的中间件
|
||||||
|
"""
|
||||||
|
def set_source(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def start(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ThreadDetector(Transmitter):
|
||||||
|
def __init__(self, input_queue: Queue, output_queue: Queue):
|
||||||
|
super().__init__()
|
||||||
|
self._input_queue, self._output_queue = input_queue, output_queue
|
||||||
|
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
|
||||||
|
pixel_model_path=Config.pixel_model_path)
|
||||||
|
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
|
||||||
|
background_model_path=Config.rgb_background_model_path)
|
||||||
|
self._predict_thread = None
|
||||||
|
self._thread_exit = threading.Event()
|
||||||
|
|
||||||
|
def set_source(self, img_queue: Queue):
|
||||||
|
self._input_queue = img_queue
|
||||||
|
|
||||||
|
def stop(self, *args, **kwargs):
|
||||||
|
self._thread_exit.set()
|
||||||
|
|
||||||
|
def start(self, name='predict_thread'):
|
||||||
|
self._predict_thread = threading.Thread(target=self._predict_server, name=name)
|
||||||
|
self._predict_thread.start()
|
||||||
|
|
||||||
|
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||||
|
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||||
|
t1 = time.time()
|
||||||
|
mask = self._spec_detector.predict(spec)
|
||||||
|
t2 = time.time()
|
||||||
|
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||||
|
# rgb识别
|
||||||
|
mask_rgb = self._rgb_detector.predict(rgb)
|
||||||
|
t3 = time.time()
|
||||||
|
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||||
|
# 结果合并
|
||||||
|
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||||
|
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||||
|
t4 = time.time()
|
||||||
|
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||||
|
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||||
|
return mask_result
|
||||||
|
|
||||||
|
def _predict_server(self):
|
||||||
|
while not self._thread_exit.is_set():
|
||||||
|
if not self._input_queue.empty():
|
||||||
|
spec, rgb = self._input_queue.get()
|
||||||
|
mask = self.predict(spec, rgb)
|
||||||
|
self._output_queue.safe_put(mask)
|
||||||
|
self._thread_exit.clear()
|
||||||
|
|||||||
29
utils.py
29
utils.py
@ -5,6 +5,7 @@
|
|||||||
# @Software:PyCharm
|
# @Software:PyCharm
|
||||||
import glob
|
import glob
|
||||||
import os
|
import os
|
||||||
|
from queue import Queue
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -26,6 +27,34 @@ class MergeDict(dict):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class ImgQueue(Queue):
|
||||||
|
"""
|
||||||
|
A custom queue subclass that provides a :meth:`clear` method.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def clear(self):
|
||||||
|
"""
|
||||||
|
Clears all items from the queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with self.mutex:
|
||||||
|
unfinished = self.unfinished_tasks - len(self.queue)
|
||||||
|
if unfinished <= 0:
|
||||||
|
if unfinished < 0:
|
||||||
|
raise ValueError('task_done() called too many times')
|
||||||
|
self.all_tasks_done.notify_all()
|
||||||
|
self.unfinished_tasks = unfinished
|
||||||
|
self.queue.clear()
|
||||||
|
self.not_full.notify_all()
|
||||||
|
|
||||||
|
def safe_put(self, item):
|
||||||
|
if self.full():
|
||||||
|
_ = self.get()
|
||||||
|
return False
|
||||||
|
self.put(item)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
||||||
"""
|
"""
|
||||||
根据dataset_dir下的文件创建数据集
|
根据dataset_dir下的文件创建数据集
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user