This commit is contained in:
FEIJINTI 2022-07-28 14:58:23 +08:00
commit cdbe341591
6 changed files with 331 additions and 134 deletions

View File

@ -20,7 +20,7 @@ class Config:
# 光谱模型参数
blk_size = 4
pixel_model_path = r"./models/dt.p"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
spec_size_threshold = 3
# rgb模型参数

51
efficient_ui.py Normal file
View File

@ -0,0 +1,51 @@
import cv2
import numpy as np
import transmit
from config import Config
from utils import ImgQueue as Queue
class EfficientUI(object):
    """
    Build the processing pipeline and show a placeholder OpenCV window.

    The constructor creates the queues, the two FIFO receivers, the
    command/image router, the detector and the mask sender, then starts
    one worker thread for each of them. :meth:`start` runs the display loop.
    """

    def __init__(self):
        # Named pipes for the spectral image, the output mask and the RGB image.
        img_fifo_path = "/tmp/dkimg.fifo"
        mask_fifo_path = "/tmp/dkmask.fifo"
        rgb_fifo_path = "/tmp/dkrgb.fifo"
        # Queues that link the worker threads together.
        rgb_img_queue, spec_img_queue = Queue(), Queue()
        detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
        mask_queue = Queue()
        # Two receivers: one for the spectral image, one for the RGB image.
        # The read sizes are the full frame sizes in bytes.
        spec_len = Config.nRows * Config.nCols * Config.nBands * 4  # float values, 4 bytes each
        rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1  # int values, 1 byte each
        spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue,
                                              read_max_num=spec_len)
        rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue,
                                             read_max_num=rgb_len)
        # Command execution and image routing: every subscriber queue
        # receives the paired (spec, rgb) images.
        subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
        cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
                                                         subscribers=subscribers)
        # Detector: consumes image pairs and produces masks.
        detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
        # Sender: writes masks to the mask FIFO.
        sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
        # Start all worker threads.
        spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process,
                            name='spce_thread')
        rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process,
                           name='rgb_thread')
        cmd_img_controller.start(name='control_thread')
        detector.start(name='detector_thread')
        sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')

    def start(self):
        """
        Run the display loop; this method never returns.

        The window currently shows a constant all-ones 256x1024 frame, and
        the "s" key is recognised but has no action attached yet.
        """
        # The frame never changes, so allocate it once instead of on every
        # iteration of the display loop.
        placeholder_frame = np.ones((256, 1024))
        while True:
            cv2.imshow('image_show', mat=placeholder_frame)
            key_code = cv2.waitKey(30)
            if key_code == ord("s"):
                pass
if __name__ == '__main__':
    # Script entry point: build the pipeline, then hand control to the
    # display loop.
    EfficientUI().start()

142
main.py
View File

@ -1,16 +1,18 @@
import os
import time
from queue import Queue
import numpy as np
import scipy.io
from matplotlib import pyplot as plt
import models
import transmit
from config import Config
from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
from models import RgbDetector, SpecDetector
import cv2
SAVE_IMG, SAVE_NUM = False, 30
def main():
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
@ -23,8 +25,6 @@ def main():
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
if SAVE_IMG:
img_list = []
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
@ -55,19 +55,12 @@ def main():
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
.transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
if SAVE_IMG:
SAVE_NUM -= 1
img_list.append((rgb_data, img_data))
if SAVE_NUM <= 0:
break
# 光谱识别
mask = spec_detector.predict(img_data)
# rgb识别
mask_rgb = rgb_detector.predict(rgb_data)
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
@ -79,100 +72,37 @@ def main():
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
for i, img in enumerate(img_list):
print(f"writing img {i}...")
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
np.save(f'./{i}.npy', img[1])
i += 1
def save_main():
threshold = Config.spec_size_threshold
rgb_threshold = Config.rgb_size_threshold
manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
img_list = []
idx = 0
while idx <= 30:
idx += 1
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
data = os.read(fd_img, total_len)
# 读取(开启一个管道)
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get threshold: ", threshold)
continue
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
if os.path.isdir(buffer_path):
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
else:
buffer_names = [buffer_path, ]
for buffer_name in buffer_names:
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
data = f.read()
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
.transpose(0, 2, 1)
if selected_bands is not None:
img = img[..., selected_bands]
if img.shape[0] == 1:
img = img[0, ...]
if not no_mask:
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
data = f.read()
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
else:
data_total = data
rgb_data = os.read(fd_rgb, total_rgb)
if len(rgb_data) < 3:
rgb_threshold = int(float(rgb_data))
print(rgb_threshold)
continue
else:
rgb_data_total = rgb_data
os.close(fd_img)
os.close(fd_rgb)
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
img_list.append((rgb_data.copy(), img_data.copy()))
pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = manual_tree.blk_predict(data=img_data)
rgb_data = tobacco_detector.pretreatment(rgb_data)
# print(rgb_data.shape)
rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high) |
tobacco_detector.swell(tobacco_detector.predict(rgb_data,
threshold_low=Config.threshold_low,
threshold_high=Config.threshold_high)))
mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
mask_rgb[mask_rgb <= rgb_threshold] = 0
mask_rgb[mask_rgb > rgb_threshold] = 1
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
.sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
.sum(axis=1)
mask[mask <= threshold] = 0
mask[mask > threshold] = 1
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = mask_rgb
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
i = 0
print("Stop Serving")
for img in img_list:
print(f"writing img {i}...")
cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
np.save(f'./{i}.npy', img[1])
i += 1
print("save success")
mask_name = "no mask"
mask = np.zeros_like(img)
# mask = cv2.resize(mask, (1024, 256))
fig, axs = plt.subplots(2, 1)
axs[0].matshow(img)
axs[0].set_title(buffer_name)
axs[1].imshow(mask)
axs[1].set_title(mask_name)
plt.show()
if __name__ == '__main__':
@ -183,3 +113,5 @@ if __name__ == '__main__':
# 主函数
main()
# read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
# selected_bands=[380, 300, 200])

View File

@ -5,10 +5,12 @@
# @Software:PyCharm、
import datetime
import pickle
from queue import Queue
import cv2
import numpy as np
import scipy.io
import threading
from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
@ -17,6 +19,7 @@ from sklearn.model_selection import train_test_split
from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold
deploy = True
if not deploy:
print("Training env")
@ -415,7 +418,7 @@ class SpecDetector(Detector):
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
yellow_things[yellow_things] = tobacco
yellow_things = yellow_things + 0
# yellow_things = binary_dilation(yellow_things, iterations=iteration)
yellow_things = binary_dilation(yellow_things, iterations=iteration)
yellow_things = yellow_things + 0
yellow_things[yellow_things == 1] = 2

View File

@ -1,11 +1,18 @@
import os
import threading
from queue import Queue
import time
from utils import ImgQueue as Queue
import numpy as np
from config import Config
from models import SpecDetector, RgbDetector
import typing
import logging
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
level=logging.DEBUG)
class Receiver(object):
class Transmitter(object):
def __init__(self):
self.output = None
@ -45,17 +52,49 @@ class Receiver(object):
raise NotImplementedError
class FifoReceiver(Receiver):
def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
class BeforeAfterMethods:
    """
    Conversion hooks applied just before sending or just after receiving.

    Each hook logs what it handled and returns the converted payload.
    """

    @classmethod
    def mask_preprocess(cls, mask: np.ndarray):
        """Log the mask shape and return the mask serialised as raw bytes."""
        logging.info(f"Send mask with size {mask.shape}")
        payload = mask.tobytes()
        return payload

    @classmethod
    def spec_data_post_process(cls, data):
        """
        Decode a payload read from the spectral FIFO.

        A payload of three bytes or more is interpreted as float32 values,
        reshaped to (Config.nRows, Config.nBands, -1) and transposed with
        axes (0, 2, 1). A shorter payload is parsed as a number and returned
        as an int threshold.
        """
        if len(data) >= 3:
            flat = np.frombuffer(data, dtype=np.float32)
            cube = flat.reshape((Config.nRows, Config.nBands, -1))
            cube = cube.transpose(0, 2, 1)
            logging.info(f"Get SPEC image with size {cube.shape}")
            return cube
        value = int(float(data))
        logging.info(f"Get Spec threshold: {value}")
        return value

    @classmethod
    def rgb_data_post_process(cls, data):
        """
        Decode a payload read from the RGB FIFO.

        A payload of three bytes or more is interpreted as uint8 values and
        reshaped to (Config.nRgbRows, Config.nRgbCols, -1). A shorter payload
        is parsed as a number and returned as an int threshold.
        """
        if len(data) >= 3:
            flat = np.frombuffer(data, dtype=np.uint8)
            picture = flat.reshape((Config.nRgbRows, Config.nRgbCols, -1))
            logging.info(f"Get RGB img with size {picture.shape}")
            return picture
        value = int(float(data))
        logging.info(f"Get RGB threshold: {value}")
        return value
class FifoReceiver(Transmitter):
def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
super().__init__()
self._input_fifo_path = None
self._output_queue = None
self._msg_queue = msg_queue
self._max_len = read_max_num
self.set_source(fifo_path)
self.set_output(output)
self._need_stop = threading.Event()
self._need_stop.clear()
self._running_thread = None
def set_source(self, fifo_path: str):
if not os.access(fifo_path, os.F_OK):
@ -66,9 +105,9 @@ class FifoReceiver(Receiver):
self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'):
x = threading.Thread(target=self._receive_thread_func,
name=name, args=(post_process_func, ))
x.start()
self._running_thread = threading.Thread(target=self._receive_thread_func,
name=name, args=(post_process_func, ))
self._running_thread.start()
def stop(self):
self._need_stop.set()
@ -85,27 +124,170 @@ class FifoReceiver(Receiver):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
self._output_queue.put(data)
self._output_queue.safe_put(data)
os.close(input_fifo)
self._need_stop.clear()
@staticmethod
def spec_data_post_process(data):
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get Spec threshold: ", threshold)
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
return spec_img
@staticmethod
def rgb_data_post_process(data):
if len(data) < 3:
threshold = int(float(data))
print("[INFO] Get RGB threshold: ", threshold)
return threshold
else:
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
return rgb_img
class FifoSender(Transmitter):
    """
    Take items from a queue and write each one to a named pipe (FIFO)
    from a background thread.
    """

    # Seconds to block on the input queue before re-checking the stop flag.
    _POLL_INTERVAL = 0.1

    def __init__(self, output_fifo_path: str, source: Queue):
        """
        :param output_fifo_path: path of the FIFO to write to; created if missing
        :param source: queue providing the items to send
        """
        super().__init__()
        self._input_source = None
        self._output_fifo_path = None
        # Create the thread bookkeeping before anything that can fail, so
        # that __del__ always finds these attributes.
        self._need_stop = threading.Event()
        self._running_thread = None
        self.set_source(source)
        self.set_output(output_fifo_path)

    def set_source(self, source: Queue):
        """Set the queue that items are taken from."""
        self._input_source = source

    def set_output(self, output_fifo_path: str):
        """Remember the output FIFO path, creating the FIFO if it does not exist."""
        if not os.access(output_fifo_path, os.F_OK):
            os.mkfifo(output_fifo_path, 0o777)
        self._output_fifo_path = output_fifo_path

    def start(self, pre_process=None, name='fifo_receiver'):
        """
        Start the sending thread.

        :param pre_process: optional callable applied to each item before it is written
        :param name: thread name (the default value is kept for backward compatibility)
        """
        self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
                                                args=(pre_process, ))
        self._running_thread.start()

    def stop(self):
        """Ask the sending thread to exit after its current iteration."""
        self._need_stop.set()

    def _send_thread_func(self, pre_process=None):
        """
        Sending thread body.

        :param pre_process: optional callable applied to each item before it is written
        :return: None
        """
        from queue import Empty

        while not self._need_stop.is_set():
            # Block on the queue instead of spinning on empty(); the timeout
            # only bounds how long a stop request can go unnoticed.
            try:
                data = self._input_source.get(timeout=self._POLL_INTERVAL)
            except Empty:
                continue
            if pre_process is not None:
                data = pre_process(data)
            output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
            try:
                os.write(output_fifo, data)
            finally:
                # Release the descriptor even when the write fails.
                os.close(output_fifo)
        # Reset the flag so the sender can be started again.
        self._need_stop.clear()

    def __del__(self):
        # Best-effort shutdown; tolerate a partially constructed instance.
        stop_flag = getattr(self, '_need_stop', None)
        if stop_flag is not None:
            stop_flag.set()
        worker = getattr(self, '_running_thread', None)
        # A thread cannot join itself, which would happen if the last
        # reference to this object is dropped inside the worker thread.
        if worker is not None and worker is not threading.current_thread():
            worker.join()
class CmdImgSplitMidware(Transmitter):
    """
    Middleware that separates commands from images.

    It takes one item from the RGB queue and one from the spectral queue at
    a time. A pair of ints is applied as new size thresholds on Config; a
    pair of arrays is forwarded to every subscriber queue as (spec, rgb).
    """

    # Seconds to wait between polls while either input queue is empty.
    _POLL_INTERVAL = 0.001

    def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
        """
        :param subscribers: mapping of subscriber name to the queue that receives image pairs
        :param rgb_queue: queue of post-processed RGB payloads
        :param spec_queue: queue of post-processed spectral payloads
        """
        super().__init__()
        self._rgb_queue = None
        self._spec_queue = None
        self._subscribers = None
        self._server_thread = None
        self.set_source(rgb_queue, spec_queue)
        self.set_output(subscribers)
        self.thread_stop = threading.Event()

    def set_source(self, rgb_queue: Queue, spec_queue: Queue):
        """Set the two input queues."""
        self._rgb_queue = rgb_queue
        self._spec_queue = spec_queue

    def set_output(self, output: typing.Dict[str, Queue]):
        """Set the subscriber queues that receive image pairs."""
        self._subscribers = output

    def start(self, name='CMD_thread'):
        """Start the routing thread under the given thread name."""
        self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
        self._server_thread.start()

    def stop(self):
        """Ask the routing thread to exit."""
        self.thread_stop.set()

    def _cmd_control_service(self):
        """
        Routing thread body.

        :raises Exception: when one payload is a command and the other an image
        """
        while not self.thread_stop.is_set():
            # Both queues must have data so the two streams are consumed in
            # lockstep. Wait on the stop event rather than spinning, so an
            # idle router does not burn a CPU core and still reacts to stop().
            if self._rgb_queue.empty() or self._spec_queue.empty():
                self.thread_stop.wait(self._POLL_INTERVAL)
                continue
            rgb_data = self._rgb_queue.get()
            spec_data = self._spec_queue.get()
            if isinstance(rgb_data, int) and isinstance(spec_data, int):
                # Both payloads are commands: apply them as the new thresholds.
                Config.rgb_size_threshold = rgb_data
                Config.spec_size_threshold = spec_data
                continue
            elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
                # Both payloads are images: hand the pair to every subscriber.
                item = (spec_data, rgb_data)
                for subscriber in self._subscribers.values():
                    subscriber.safe_put(item)
            else:
                # The two streams are out of step; fail immediately.
                raise Exception("两个相机传回的数据没有对上")
        # Reset the flag so the router can be started again.
        self.thread_stop.clear()
class ImageSaver(Transmitter):
    """
    Middleware for storing images.

    Every method below is currently a placeholder that accepts any
    arguments and does nothing.
    """

    def set_source(self, *args, **kwargs):
        """Placeholder: accept any arguments and do nothing."""
        return None

    def start(self, *args, **kwargs):
        """Placeholder: accept any arguments and do nothing."""
        return None

    def stop(self, *args, **kwargs):
        """Placeholder: accept any arguments and do nothing."""
        return None
class ThreadDetector(Transmitter):
    """
    Run the spectral and RGB detectors on queued image pairs in a
    background thread and publish the merged mask.
    """

    # Seconds to block on the input queue before re-checking the exit flag.
    _POLL_INTERVAL = 0.1

    def __init__(self, input_queue: Queue, output_queue: Queue):
        """
        :param input_queue: queue of (spec, rgb) image pairs
        :param output_queue: queue that receives the merged masks
        """
        super().__init__()
        self._input_queue, self._output_queue = input_queue, output_queue
        self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
                                           pixel_model_path=Config.pixel_model_path)
        self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
                                         background_model_path=Config.rgb_background_model_path)
        self._predict_thread = None
        self._thread_exit = threading.Event()

    def set_source(self, img_queue: Queue):
        """Replace the queue that image pairs are taken from."""
        self._input_queue = img_queue

    def stop(self, *args, **kwargs):
        """Ask the prediction thread to exit after its current iteration."""
        self._thread_exit.set()

    def start(self, name='predict_thread'):
        """Start the prediction thread under the given thread name."""
        self._predict_thread = threading.Thread(target=self._predict_server, name=name)
        self._predict_thread.start()

    def predict(self, spec: np.ndarray, rgb: np.ndarray):
        """
        Run both detectors on one image pair and merge the results.

        :param spec: spectral image passed to the spectral detector
        :param rgb: RGB image passed to the RGB detector
        :return: uint8 mask, the OR of both detector masks with every element
                 repeated Config.blk_size times along axis 0 and axis 1
        """
        logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
        t1 = time.time()
        mask = self._spec_detector.predict(spec)
        t2 = time.time()
        logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
        # RGB detection
        mask_rgb = self._rgb_detector.predict(rgb)
        t3 = time.time()
        logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
        # Merge the two results
        mask_result = (mask | mask_rgb).astype(np.uint8)
        mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
        t4 = time.time()
        logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
        logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
        return mask_result

    def _predict_server(self):
        """Prediction thread body: take a pair, predict, publish the mask."""
        from queue import Empty

        while not self._thread_exit.is_set():
            # Block on the queue instead of spinning on empty(); the timeout
            # only bounds how long a stop request can go unnoticed.
            try:
                item = self._input_queue.get(timeout=self._POLL_INTERVAL)
            except Empty:
                continue
            spec, rgb = item
            mask = self.predict(spec, rgb)
            self._output_queue.safe_put(mask)
        # Reset the flag so the detector can be started again.
        self._thread_exit.clear()

View File

@ -5,6 +5,7 @@
# @Software:PyCharm
import glob
import os
from queue import Queue
import cv2
import numpy as np
@ -26,6 +27,34 @@ class MergeDict(dict):
return self
class ImgQueue(Queue):
    """
    A custom queue subclass that provides a :meth:`clear` method and a
    non-blocking :meth:`safe_put`.
    """

    def clear(self):
        """
        Clears all items from the queue.

        The removed items are subtracted from the unfinished-task count, and
        threads waiting for completion or for free space are notified.

        :raises ValueError: if the unfinished-task count would become negative
        """
        with self.mutex:
            unfinished = self.unfinished_tasks - len(self.queue)
            if unfinished <= 0:
                if unfinished < 0:
                    raise ValueError('task_done() called too many times')
                self.all_tasks_done.notify_all()
            self.unfinished_tasks = unfinished
            self.queue.clear()
            self.not_full.notify_all()

    def safe_put(self, item):
        """
        Put ``item`` on the queue without ever blocking.

        If the queue is full, the oldest item is discarded to make room and
        ``item`` is then enqueued, so the newest data is always kept.

        :param item: the object to enqueue
        :return: True if ``item`` was added without discarding anything,
                 False if an older item had to be dropped to make room
        """
        from queue import Empty, Full

        nothing_dropped = True
        while True:
            try:
                self.put_nowait(item)
                return nothing_dropped
            except Full:
                nothing_dropped = False
            try:
                self.get_nowait()
            except Empty:
                # A consumer emptied the queue in the meantime; just retry.
                continue
            # The dropped item will never be processed, so mark it done to
            # keep join() accounting consistent, as clear() does.
            self.task_done()
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
"""
根据dataset_dir下的文件创建数据集