This commit is contained in:
FEIJINTI 2022-07-28 14:58:23 +08:00
commit cdbe341591
6 changed files with 331 additions and 134 deletions

View File

@ -20,7 +20,7 @@ class Config:
# 光谱模型参数 # 光谱模型参数
blk_size = 4 blk_size = 4
pixel_model_path = r"./models/dt.p" pixel_model_path = r"./models/dt.p"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model" blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
spec_size_threshold = 3 spec_size_threshold = 3
# rgb模型参数 # rgb模型参数

51
efficient_ui.py Normal file
View File

@ -0,0 +1,51 @@
import cv2
import numpy as np
import transmit
from config import Config
from utils import ImgQueue as Queue
class EfficientUI(object):
def __init__(self):
# 相关参数
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
rgb_fifo_path = "/tmp/dkrgb.fifo"
# 创建队列用于链接各个线程
rgb_img_queue, spec_img_queue = Queue(), Queue()
detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
mask_queue = Queue()
# 两个接收者,接收光谱和rgb图像
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
# 指令执行与图像流向控制
subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
subscribers=subscribers)
# 探测器
detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
# 发送
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
# 启动所有线程
spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
cmd_img_controller.start(name='control_thread')
detector.start(name='detector_thread')
sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
def start(self):
# 启动图形化
while True:
cv2.imshow('image_show', mat=np.ones((256, 1024)))
key_code = cv2.waitKey(30)
if key_code == ord("s"):
pass
if __name__ == '__main__':
app = EfficientUI()
app.start()

142
main.py
View File

@ -1,16 +1,18 @@
import os import os
import time import time
from queue import Queue
import numpy as np import numpy as np
import scipy.io from matplotlib import pyplot as plt
import models
import transmit
from config import Config from config import Config
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector from models import RgbDetector, SpecDetector
import cv2 import cv2
SAVE_IMG, SAVE_NUM = False, 30
def main(): def main():
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
@ -23,8 +25,6 @@ def main():
os.mkfifo(mask_fifo_path, 0o777) os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK): if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777) os.mkfifo(rgb_fifo_path, 0o777)
if SAVE_IMG:
img_list = []
while True: while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
@ -55,19 +55,12 @@ def main():
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
.transpose(0, 2, 1) .transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
if SAVE_IMG:
SAVE_NUM -= 1
img_list.append((rgb_data, img_data))
if SAVE_NUM <= 0:
break
# 光谱识别 # 光谱识别
mask = spec_detector.predict(img_data) mask = spec_detector.predict(img_data)
# rgb识别 # rgb识别
mask_rgb = rgb_detector.predict(rgb_data) mask_rgb = rgb_detector.predict(rgb_data)
# 结果合并 # 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8) mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = mask_rgb.astype(np.uint8) # mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8) mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time() t2 = time.time()
@ -79,100 +72,37 @@ def main():
os.close(fd_mask) os.close(fd_mask)
t3 = time.time() t3 = time.time()
print(f'total time is:{t3 - t1}') print(f'total time is:{t3 - t1}')
for i, img in enumerate(img_list):
print(f"writing img {i}...")
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
np.save(f'./{i}.npy', img[1])
i += 1
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
def save_main(): if os.path.isdir(buffer_path):
threshold = Config.spec_size_threshold buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
rgb_threshold = Config.rgb_size_threshold else:
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) buffer_names = [buffer_path, ]
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path) for buffer_name in buffer_names:
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path) with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节 data = f.read()
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量 img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
if not os.access(img_fifo_path, os.F_OK): .transpose(0, 2, 1)
os.mkfifo(img_fifo_path, 0o777) if selected_bands is not None:
if not os.access(mask_fifo_path, os.F_OK): img = img[..., selected_bands]
os.mkfifo(mask_fifo_path, 0o777) if img.shape[0] == 1:
if not os.access(rgb_fifo_path, os.F_OK): img = img[0, ...]
os.mkfifo(rgb_fifo_path, 0o777) if not no_mask:
img_list = [] mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
idx = 0 with open(os.path.join(buffer_path, mask_name), 'rb') as f:
while idx <= 30: data = f.read()
idx += 1 mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
data = os.read(fd_img, total_len)
# 读取(开启一个管道)
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get threshold: ", threshold)
continue
else: else:
data_total = data mask_name = "no mask"
rgb_data = os.read(fd_rgb, total_rgb) mask = np.zeros_like(img)
if len(rgb_data) < 3: # mask = cv2.resize(mask, (1024, 256))
rgb_threshold = int(float(rgb_data)) fig, axs = plt.subplots(2, 1)
print(rgb_threshold) axs[0].matshow(img)
continue axs[0].set_title(buffer_name)
else: axs[1].imshow(mask)
rgb_data_total = rgb_data axs[1].set_title(mask_name)
os.close(fd_img) plt.show()
os.close(fd_rgb)
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
img_list.append((rgb_data.copy(), img_data.copy()))
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manual_tree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data)
# print(rgb_data.shape)
rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high) |
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high)))
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
mask_rgb[mask_rgb <= rgb_threshold] = 0
mask_rgb[mask_rgb > rgb_threshold] = 1
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = mask_rgb
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
i = 0
print("Stop Serving")
for img in img_list:
print(f"writing img {i}...")
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
np.save(f'./{i}.npy', img[1])
i += 1
print("save success")
if __name__ == '__main__': if __name__ == '__main__':
@ -183,3 +113,5 @@ if __name__ == '__main__':
# 主函数 # 主函数
main() main()
# read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
# selected_bands=[380, 300, 200])

View File

@ -5,10 +5,12 @@
# @Software:PyCharm、 # @Software:PyCharm、
import datetime import datetime
import pickle import pickle
from queue import Queue
import cv2 import cv2
import numpy as np import numpy as np
import scipy.io import scipy.io
import threading
from scipy.ndimage import binary_dilation from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report from sklearn.metrics import classification_report
@ -17,6 +19,7 @@ from sklearn.model_selection import train_test_split
from config import Config from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold from utils import lab_scatter, read_labeled_img, size_threshold
deploy = True deploy = True
if not deploy: if not deploy:
print("Training env") print("Training env")
@ -415,7 +418,7 @@ class SpecDetector(Detector):
# 烟梗mask中将背景赋值为0,将烟梗赋值为2 # 烟梗mask中将背景赋值为0,将烟梗赋值为2
yellow_things[yellow_things] = tobacco yellow_things[yellow_things] = tobacco
yellow_things = yellow_things + 0 yellow_things = yellow_things + 0
# yellow_things = binary_dilation(yellow_things, iterations=iteration) yellow_things = binary_dilation(yellow_things, iterations=iteration)
yellow_things = yellow_things + 0 yellow_things = yellow_things + 0
yellow_things[yellow_things == 1] = 2 yellow_things[yellow_things == 1] = 2

View File

@ -1,11 +1,18 @@
import os import os
import threading import threading
from queue import Queue import time
from utils import ImgQueue as Queue
import numpy as np import numpy as np
from config import Config from config import Config
from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG)
class Receiver(object): class Transmitter(object):
def __init__(self): def __init__(self):
self.output = None self.output = None
@ -45,17 +52,49 @@ class Receiver(object):
raise NotImplementedError raise NotImplementedError
class FifoReceiver(Receiver): class BeforeAfterMethods:
def __init__(self, fifo_path: str, output: Queue, read_max_num: int): @classmethod
def mask_preprocess(cls, mask: np.ndarray):
logging.info(f"Send mask with size {mask.shape}")
return mask.tobytes()
@classmethod
def spec_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
logging.info(f"Get Spec threshold: {threshold}")
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img
@classmethod
def rgb_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
logging.info(f"Get RGB threshold: {threshold}")
return threshold
else:
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
logging.info(f"Get RGB img with size {rgb_img.shape}")
return rgb_img
class FifoReceiver(Transmitter):
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
super().__init__() super().__init__()
self._input_fifo_path = None self._input_fifo_path = None
self._output_queue = None self._output_queue = None
self._msg_queue = msg_queue
self._max_len = read_max_num self._max_len = read_max_num
self.set_source(fifo_path) self.set_source(fifo_path)
self.set_output(output) self.set_output(output)
self._need_stop = threading.Event() self._need_stop = threading.Event()
self._need_stop.clear() self._need_stop.clear()
self._running_thread = None
def set_source(self, fifo_path: str): def set_source(self, fifo_path: str):
if not os.access(fifo_path, os.F_OK): if not os.access(fifo_path, os.F_OK):
@ -66,9 +105,9 @@ class FifoReceiver(Receiver):
self._output_queue = output self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'): def start(self, post_process_func=None, name='fifo_receiver'):
x = threading.Thread(target=self._receive_thread_func, self._running_thread = threading.Thread(target=self._receive_thread_func,
name=name, args=(post_process_func, )) name=name, args=(post_process_func, ))
x.start() self._running_thread.start()
def stop(self): def stop(self):
self._need_stop.set() self._need_stop.set()
@ -85,27 +124,170 @@ class FifoReceiver(Receiver):
data = os.read(input_fifo, self._max_len) data = os.read(input_fifo, self._max_len)
if post_process_func is not None: if post_process_func is not None:
data = post_process_func(data) data = post_process_func(data)
self._output_queue.put(data) self._output_queue.safe_put(data)
os.close(input_fifo) os.close(input_fifo)
self._need_stop.clear() self._need_stop.clear()
@staticmethod
def spec_data_post_process(data):
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get Spec threshold: ", threshold)
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
return spec_img
@staticmethod class FifoSender(Transmitter):
def rgb_data_post_process(data): def __init__(self, output_fifo_path: str, source: Queue):
if len(data) < 3: super().__init__()
threshold = int(float(data)) self._input_source = None
print("[INFO] Get RGB threshold: ", threshold) self._output_fifo_path = None
return threshold self.set_source(source)
else: self.set_output(output_fifo_path)
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) self._need_stop = threading.Event()
return rgb_img self._need_stop.clear()
self._running_thread = None
def set_source(self, source: Queue):
self._input_source = source
def set_output(self, output_fifo_path: str):
if not os.access(output_fifo_path, os.F_OK):
os.mkfifo(output_fifo_path, 0o777)
self._output_fifo_path = output_fifo_path
def start(self, pre_process=None, name='fifo_receiver'):
self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
args=(pre_process, ))
self._running_thread.start()
def stop(self):
self._need_stop.set()
def _send_thread_func(self, pre_process=None):
"""
接收线程
:param pre_process:
:return:
"""
while not self._need_stop.is_set():
if self._input_source.empty():
continue
data = self._input_source.get()
if pre_process is not None:
data = pre_process(data)
output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data)
os.close(output_fifo)
self._need_stop.clear()
def __del__(self):
self.stop()
if self._running_thread is not None:
self._running_thread.join()
class CmdImgSplitMidware(Transmitter):
"""
用于控制命令和图像的中间件
"""
def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
super().__init__()
self._rgb_queue = None
self._spec_queue = None
self._subscribers = None
self._server_thread = None
self.set_source(rgb_queue, spec_queue)
self.set_output(subscribers)
self.thread_stop = threading.Event()
def set_source(self, rgb_queue: Queue, spec_queue: Queue):
self._rgb_queue = rgb_queue
self._spec_queue = spec_queue
def set_output(self, output: typing.Dict[str, Queue]):
self._subscribers = output
def start(self, name='CMD_thread'):
self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
self._server_thread.start()
def stop(self):
self.thread_stop.set()
def _cmd_control_service(self):
while not self.thread_stop.is_set():
# 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
if self._rgb_queue.empty() or self._spec_queue.empty():
continue
rgb_data = self._rgb_queue.get()
spec_data = self._spec_queue.get()
if isinstance(rgb_data, int) and isinstance(spec_data, int):
# 看是不是命令需要执行如果是命令,就执行
Config.rgb_size_threshold = rgb_data
Config.spec_size_threshold = spec_data
continue
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
# 如果是图片,交给预测的人
for name, subscriber in self._subscribers.items():
item = (spec_data, rgb_data)
subscriber.safe_put(item)
else:
# 否则程序出现毁灭性问题,立刻崩
raise Exception("两个相机传回的数据没有对上")
self.thread_stop.clear()
class ImageSaver(Transmitter):
"""
进行图片存储的中间件
"""
def set_source(self, *args, **kwargs):
pass
def start(self, *args, **kwargs):
pass
def stop(self, *args, **kwargs):
pass
class ThreadDetector(Transmitter):
def __init__(self, input_queue: Queue, output_queue: Queue):
super().__init__()
self._input_queue, self._output_queue = input_queue, output_queue
self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
pixel_model_path=Config.pixel_model_path)
self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
self._predict_thread = None
self._thread_exit = threading.Event()
def set_source(self, img_queue: Queue):
self._input_queue = img_queue
def stop(self, *args, **kwargs):
self._thread_exit.set()
def start(self, name='predict_thread'):
self._predict_thread = threading.Thread(target=self._predict_server, name=name)
self._predict_thread.start()
def predict(self, spec: np.ndarray, rgb: np.ndarray):
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
t1 = time.time()
mask = self._spec_detector.predict(spec)
t2 = time.time()
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
# rgb识别
mask_rgb = self._rgb_detector.predict(rgb)
t3 = time.time()
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t4 = time.time()
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
return mask_result
def _predict_server(self):
while not self._thread_exit.is_set():
if not self._input_queue.empty():
spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb)
self._output_queue.safe_put(mask)
self._thread_exit.clear()

View File

@ -5,6 +5,7 @@
# @Software:PyCharm # @Software:PyCharm
import glob import glob
import os import os
from queue import Queue
import cv2 import cv2
import numpy as np import numpy as np
@ -26,6 +27,34 @@ class MergeDict(dict):
return self return self
class ImgQueue(Queue):
"""
A custom queue subclass that provides a :meth:`clear` method.
"""
def clear(self):
"""
Clears all items from the queue.
"""
with self.mutex:
unfinished = self.unfinished_tasks - len(self.queue)
if unfinished <= 0:
if unfinished < 0:
raise ValueError('task_done() called too many times')
self.all_tasks_done.notify_all()
self.unfinished_tasks = unfinished
self.queue.clear()
self.not_full.notify_all()
def safe_put(self, item):
if self.full():
_ = self.get()
return False
self.put(item)
return True
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict: def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
""" """
根据dataset_dir下的文件创建数据集 根据dataset_dir下的文件创建数据集