From 597284c3e1ef3f9062dee5600fd9254a713ce838 Mon Sep 17 00:00:00 2001
From: "li.zhenye" <李>
Date: Wed, 27 Jul 2022 13:40:35 +0800
Subject: [PATCH 1/5] remove the saving function
---
main.py | 107 +-------------------------------------------------------
1 file changed, 1 insertion(+), 106 deletions(-)
diff --git a/main.py b/main.py
index 32602f9..7ac3b21 100755
--- a/main.py
+++ b/main.py
@@ -8,9 +8,6 @@ from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
import cv2
-SAVE_IMG, SAVE_NUM = False, 30
-
-
def main():
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
@@ -23,8 +20,6 @@ def main():
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
- if SAVE_IMG:
- img_list = []
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
@@ -55,11 +50,6 @@ def main():
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
.transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
- if SAVE_IMG:
- SAVE_NUM -= 1
- img_list.append((rgb_data, img_data))
- if SAVE_NUM <= 0:
- break
# 光谱识别
mask = spec_detector.predict(img_data)
# rgb识别
@@ -71,108 +61,13 @@ def main():
# mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
- print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
- print(f'total time is:{t3 - t1}')
- for i, img in enumerate(img_list):
- print(f"writing img {i}...")
- cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
- np.save(f'./{i}.npy', img[1])
- i += 1
-
-
-
-def save_main():
- threshold = Config.spec_size_threshold
- rgb_threshold = Config.rgb_size_threshold
- manual_tree = ManualTree(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
- tobacco_detector = AnonymousColorDetector(file_path=Config.rgb_tobacco_model_path)
- background_detector = AnonymousColorDetector(file_path=Config.rgb_background_model_path)
- total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
- total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
- if not os.access(img_fifo_path, os.F_OK):
- os.mkfifo(img_fifo_path, 0o777)
- if not os.access(mask_fifo_path, os.F_OK):
- os.mkfifo(mask_fifo_path, 0o777)
- if not os.access(rgb_fifo_path, os.F_OK):
- os.mkfifo(rgb_fifo_path, 0o777)
- img_list = []
- idx = 0
- while idx <= 30:
- idx += 1
- fd_img = os.open(img_fifo_path, os.O_RDONLY)
- fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
- data = os.read(fd_img, total_len)
-
- # 读取(开启一个管道)
- if len(data) < 3:
- threshold = int(float(data))
- print("[INFO] Get threshold: ", threshold)
- continue
- else:
- data_total = data
- rgb_data = os.read(fd_rgb, total_rgb)
- if len(rgb_data) < 3:
- rgb_threshold = int(float(rgb_data))
- print(rgb_threshold)
- continue
- else:
- rgb_data_total = rgb_data
- os.close(fd_img)
- os.close(fd_rgb)
-
- # 识别
- t1 = time.time()
- img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)). \
- transpose(0, 2, 1)
- rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
- img_list.append((rgb_data.copy(), img_data.copy()))
-
- pixel_predict_result = manual_tree.pixel_predict_ml_dilation(data=img_data, iteration=1)
- blk_predict_result = manual_tree.blk_predict(data=img_data)
- rgb_data = tobacco_detector.pretreatment(rgb_data)
- # print(rgb_data.shape)
- rgb_predict_result = 1 - (background_detector.predict(rgb_data, threshold_low=Config.threshold_low,
- threshold_high=Config.threshold_high) |
- tobacco_detector.swell(tobacco_detector.predict(rgb_data,
- threshold_low=Config.threshold_low,
- threshold_high=Config.threshold_high)))
- mask_rgb = rgb_predict_result.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
- .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
- .sum(axis=1)
- mask_rgb[mask_rgb <= rgb_threshold] = 0
- mask_rgb[mask_rgb > rgb_threshold] = 1
- mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
- mask = mask.reshape(Config.nRows, Config.nCols // Config.blk_size, Config.blk_size) \
- .sum(axis=2).reshape(Config.nRows // 4, Config.blk_size, Config.nCols // Config.blk_size) \
- .sum(axis=1)
- mask[mask <= threshold] = 0
- mask[mask > threshold] = 1
- mask_result = (mask | mask_rgb).astype(np.uint8)
- # mask_result = mask_rgb
- mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
- t2 = time.time()
- print(f'rgb len = {len(rgb_data)}')
-
- # 写出
- fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
- os.write(fd_mask, mask_result.tobytes())
- os.close(fd_mask)
- t3 = time.time()
- print(f'total time is:{t3 - t1}')
- i = 0
- print("Stop Serving")
- for img in img_list:
- print(f"writing img {i}...")
- cv2.imwrite(f"./{i}.png", img[0][..., ::-1])
- np.save(f'./{i}.npy', img[1])
- i += 1
- print("save success")
+ print(f'total time is:{t3 - t1}\n')
if __name__ == '__main__':
From 32010f35f00fec1e75d8f420979426dcc8792371 Mon Sep 17 00:00:00 2001
From: "li.zhenye" <李>
Date: Wed, 27 Jul 2022 17:44:21 +0800
Subject: [PATCH 2/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E8=AF=BB?=
=?UTF-8?q?=E5=8F=96c=E8=AF=AD=E8=A8=80buff=E7=9A=84=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
config.py | 2 +-
main.py | 23 +++++++++++++++++++++++
models.py | 2 +-
3 files changed, 25 insertions(+), 2 deletions(-)
diff --git a/config.py b/config.py
index 483c629..2e5266c 100755
--- a/config.py
+++ b/config.py
@@ -20,7 +20,7 @@ class Config:
# 光谱模型参数
blk_size = 4
pixel_model_path = r"./models/dt.p"
- blk_model_path = r"./models/rf_4x4_c22_20_sen8_8.model"
+ blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
spec_size_threshold = 3
# rgb模型参数
diff --git a/main.py b/main.py
index 7ac3b21..e368ea7 100755
--- a/main.py
+++ b/main.py
@@ -1,5 +1,7 @@
import os
import time
+
+import matplotlib.pyplot as plt
import numpy as np
import scipy.io
@@ -70,6 +72,26 @@ def main():
print(f'total time is:{t3 - t1}\n')
+def read_c_captures(buffer_path):
+ buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
+ for buffer_name in buffer_names:
+ with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
+ data = f.read()
+ img = np.frombuffer(data, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
+ .transpose(0, 2, 1)
+ mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
+ with open(os.path.join(buffer_path, mask_name), 'rb') as f:
+ data = f.read()
+ mask = np.frombuffer(data, dtype=np.uint8).reshape((256, 1024, -1))
+ # mask = cv2.resize(mask, (1024, 256))
+ fig, axs = plt.subplots(2, 1)
+ axs[0].imshow(img[..., [21, 3, 0]])
+ axs[0].set_title(buffer_name)
+ axs[1].imshow(mask)
+ axs[1].set_title(mask_name)
+ plt.show()
+
+
if __name__ == '__main__':
# 相关参数
img_fifo_path = "/tmp/dkimg.fifo"
@@ -78,3 +100,4 @@ if __name__ == '__main__':
# 主函数
main()
+ # read_c_captures('/home/lzy/2022.7.20/tobacco_v1_0/')
diff --git a/models.py b/models.py
index 95d7f4a..181651d 100755
--- a/models.py
+++ b/models.py
@@ -415,7 +415,7 @@ class SpecDetector(Detector):
# 烟梗mask中将背景赋值为0,将烟梗赋值为2
yellow_things[yellow_things] = tobacco
yellow_things = yellow_things + 0
- # yellow_things = binary_dilation(yellow_things, iterations=iteration)
+ yellow_things = binary_dilation(yellow_things, iterations=iteration)
yellow_things = yellow_things + 0
yellow_things[yellow_things == 1] = 2
From 1bd6e7bbe57ae7a20fa86102ae669a45637a333c Mon Sep 17 00:00:00 2001
From: "li.zhenye" <李>
Date: Thu, 28 Jul 2022 11:03:21 +0800
Subject: [PATCH 3/5] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E8=AF=BB?=
=?UTF-8?q?=E5=8F=96rawfile=20de=20function?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
main.py | 30 +++++++++++++++++++++---------
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/main.py b/main.py
index e368ea7..ee51cf4 100755
--- a/main.py
+++ b/main.py
@@ -72,20 +72,31 @@ def main():
print(f'total time is:{t3 - t1}\n')
-def read_c_captures(buffer_path):
- buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
+def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
+ if os.path.isdir(buffer_path):
+ buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
+ else:
+ buffer_names = [buffer_path, ]
for buffer_name in buffer_names:
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
data = f.read()
- img = np.frombuffer(data, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
+ img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
.transpose(0, 2, 1)
- mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
- with open(os.path.join(buffer_path, mask_name), 'rb') as f:
- data = f.read()
- mask = np.frombuffer(data, dtype=np.uint8).reshape((256, 1024, -1))
+ if selected_bands is not None:
+ img = img[..., selected_bands]
+ if img.shape[0] == 1:
+ img = img[0, ...]
+ if not no_mask:
+ mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
+ with open(os.path.join(buffer_path, mask_name), 'rb') as f:
+ data = f.read()
+ mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
+ else:
+ mask_name = "no mask"
+ mask = np.zeros_like(img)
# mask = cv2.resize(mask, (1024, 256))
fig, axs = plt.subplots(2, 1)
- axs[0].imshow(img[..., [21, 3, 0]])
+ axs[0].matshow(img)
axs[0].set_title(buffer_name)
axs[1].imshow(mask)
axs[1].set_title(mask_name)
@@ -100,4 +111,5 @@ if __name__ == '__main__':
# 主函数
main()
- # read_c_captures('/home/lzy/2022.7.20/tobacco_v1_0/')
+ # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
+ # selected_bands=[380, 300, 200])
From a0d14940e850472d5229a12b2765fd97f012a009 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Thu, 28 Jul 2022 12:39:54 +0800
Subject: [PATCH 4/5] =?UTF-8?q?=E5=A4=9A=E7=BA=BF=E7=A8=8B=E7=89=88?=
=?UTF-8?q?=E6=9C=AC?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
efficient_ui.py | 51 +++++++++++
main.py | 14 ++--
models.py | 5 +-
transmit.py | 218 ++++++++++++++++++++++++++++++++++++++++++------
utils.py | 29 +++++++
5 files changed, 284 insertions(+), 33 deletions(-)
create mode 100644 efficient_ui.py
diff --git a/efficient_ui.py b/efficient_ui.py
new file mode 100644
index 0000000..dd627b6
--- /dev/null
+++ b/efficient_ui.py
@@ -0,0 +1,51 @@
+import cv2
+import numpy as np
+
+import transmit
+from config import Config
+
+from utils import ImgQueue as Queue
+
+
+class EfficientUI(object):
+ def __init__(self):
+ # 相关参数
+ img_fifo_path = "/tmp/dkimg.fifo"
+ mask_fifo_path = "/tmp/dkmask.fifo"
+ rgb_fifo_path = "/tmp/dkrgb.fifo"
+ # 创建队列用于链接各个线程
+ rgb_img_queue, spec_img_queue = Queue(), Queue()
+ detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
+ mask_queue = Queue()
+ # 两个接收者,接收光谱和rgb图像
+ spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
+ rgb_len = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
+ spec_receiver = transmit.FifoReceiver(fifo_path=img_fifo_path, output=spec_img_queue, read_max_num=spec_len)
+ rgb_receiver = transmit.FifoReceiver(fifo_path=rgb_fifo_path, output=rgb_img_queue, read_max_num=rgb_len)
+ # 指令执行与图像流向控制
+ subscribers = {'detector': detector_queue, 'visualize': self.visual_queue, 'save': save_queue}
+ cmd_img_controller = transmit.CmdImgSplitMidware(rgb_queue=rgb_img_queue, spec_queue=spec_img_queue,
+ subscribers=subscribers)
+ # 探测器
+ detector = transmit.ThreadDetector(input_queue=detector_queue, output_queue=mask_queue)
+ # 发送
+ sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
+ # 启动所有线程
+ spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
+ rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
+ cmd_img_controller.start(name='control_thread')
+ detector.start(name='detector_thread')
+ sender.start(name='sender_thread')
+
+ def start(self):
+ # 启动图形化
+ while True:
+ cv2.imshow('image_show', mat=np.ones((256, 1024)))
+ key_code = cv2.waitKey(30)
+ if key_code == ord("s"):
+ pass
+
+
+if __name__ == '__main__':
+ app = EfficientUI()
+ app.start()
diff --git a/main.py b/main.py
index ee51cf4..9f4a8f6 100755
--- a/main.py
+++ b/main.py
@@ -1,12 +1,15 @@
import os
import time
+from queue import Queue
-import matplotlib.pyplot as plt
import numpy as np
-import scipy.io
+from matplotlib import pyplot as plt
+
+import models
+import transmit
from config import Config
-from models import RgbDetector, SpecDetector, ManualTree, AnonymousColorDetector
+from models import RgbDetector, SpecDetector
import cv2
@@ -56,20 +59,19 @@ def main():
mask = spec_detector.predict(img_data)
# rgb识别
mask_rgb = rgb_detector.predict(rgb_data)
-
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
-
# mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
+ print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
- print(f'total time is:{t3 - t1}\n')
+ print(f'total time is:{t3 - t1}')
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
diff --git a/models.py b/models.py
index 181651d..f503b96 100755
--- a/models.py
+++ b/models.py
@@ -5,10 +5,12 @@
# @Software:PyCharm、
import datetime
import pickle
+from queue import Queue
import cv2
import numpy as np
import scipy.io
+import threading
from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
@@ -17,7 +19,8 @@ from sklearn.model_selection import train_test_split
from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold
-deploy = True
+
+deploy = False
if not deploy:
print("Training env")
from tqdm import tqdm
diff --git a/transmit.py b/transmit.py
index b990718..c0fed63 100644
--- a/transmit.py
+++ b/transmit.py
@@ -1,11 +1,13 @@
import os
import threading
-from queue import Queue
+from utils import ImgQueue as Queue
import numpy as np
from config import Config
+from models import SpecDetector, RgbDetector
+import typing
-class Receiver(object):
+class Transmitter(object):
def __init__(self):
self.output = None
@@ -45,17 +47,42 @@ class Receiver(object):
raise NotImplementedError
-class FifoReceiver(Receiver):
- def __init__(self, fifo_path: str, output: Queue, read_max_num: int):
+class PostProcessMethods:
+ @classmethod
+ def spec_data_post_process(cls, data):
+ if len(data) < 3:
+ threshold = int(float(data))
+ print("[INFO] Get Spec threshold: ", threshold)
+ return threshold
+ else:
+ spec_img = np.frombuffer(data, dtype=np.float32).\
+ reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
+ return spec_img
+
+ @classmethod
+ def rgb_data_post_process(cls, data):
+ if len(data) < 3:
+ threshold = int(float(data))
+ print("[INFO] Get RGB threshold: ", threshold)
+ return threshold
+ else:
+ rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
+ return rgb_img
+
+
+class FifoReceiver(Transmitter):
+ def __init__(self, fifo_path: str, output: Queue, read_max_num: int, msg_queue=None):
super().__init__()
self._input_fifo_path = None
self._output_queue = None
+ self._msg_queue = msg_queue
self._max_len = read_max_num
self.set_source(fifo_path)
self.set_output(output)
self._need_stop = threading.Event()
self._need_stop.clear()
+ self._running_thread = None
def set_source(self, fifo_path: str):
if not os.access(fifo_path, os.F_OK):
@@ -66,9 +93,9 @@ class FifoReceiver(Receiver):
self._output_queue = output
def start(self, post_process_func=None, name='fifo_receiver'):
- x = threading.Thread(target=self._receive_thread_func,
- name=name, args=(post_process_func, ))
- x.start()
+ self._running_thread = threading.Thread(target=self._receive_thread_func,
+ name=name, args=(post_process_func, ))
+ self._running_thread.start()
def stop(self):
self._need_stop.set()
@@ -85,27 +112,166 @@ class FifoReceiver(Receiver):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
- self._output_queue.put(data)
+ if not self._output_queue.safe_put(data):
+ if self._msg_queue is not None:
+ self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
os.close(input_fifo)
self._need_stop.clear()
- @staticmethod
- def spec_data_post_process(data):
- if len(data) < 3:
- threshold = int(float(data))
- print("[INFO] Get Spec threshold: ", threshold)
- return threshold
- else:
- spec_img = np.frombuffer(data, dtype=np.float32).\
- reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
- return spec_img
+
+class FifoSender(Transmitter):
+ def __init__(self, output_fifo_path: str, source: Queue):
+ super().__init__()
+ self._input_source = None
+ self._output_fifo_path = None
+ self.set_source(source)
+ self.set_output(output_fifo_path)
+ self._need_stop = threading.Event()
+ self._need_stop.clear()
+ self._running_thread = None
+
+ def set_source(self, source: Queue):
+ self._input_source = source
+
+ def set_output(self, output_fifo_path: str):
+ if not os.access(output_fifo_path, os.F_OK):
+ os.mkfifo(output_fifo_path, 0o777)
+ self._output_fifo_path = output_fifo_path
+
+ def start(self, pre_process=None, name='fifo_receiver'):
+ self._running_thread = threading.Thread(target=self._send_thread_func, name=name,
+ args=(pre_process, ))
+ self._running_thread.start()
+
+ def stop(self):
+ self._need_stop.set()
+
+ def _send_thread_func(self, pre_process=None):
+ """
+ 接收线程
+
+ :param pre_process:
+ :return:
+ """
+ while not self._need_stop.is_set():
+ output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
+ if self._input_source.empty():
+ continue
+ data = self._input_source.get()
+ if pre_process is not None:
+ data = pre_process(data)
+ os.write(output_fifo, data)
+ os.close(output_fifo)
+ self._need_stop.clear()
@staticmethod
- def rgb_data_post_process(data):
- if len(data) < 3:
- threshold = int(float(data))
- print("[INFO] Get RGB threshold: ", threshold)
- return threshold
- else:
- rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
- return rgb_img
+ def mask_preprocess(mask: np.ndarray):
+ return mask.tobytes()
+
+ def __del__(self):
+ self.stop()
+ if self._running_thread is not None:
+ self._running_thread.join()
+
+
+class CmdImgSplitMidware(Transmitter):
+ """
+ 用于控制命令和图像的中间件
+ """
+ def __init__(self, subscribers: typing.Dict[str, Queue], rgb_queue: Queue, spec_queue: Queue):
+ super().__init__()
+ self._rgb_queue = None
+ self._spec_queue = None
+ self._subscribers = None
+ self._server_thread = None
+ self.set_source(rgb_queue, spec_queue)
+ self.set_output(subscribers)
+ self.thread_stop = threading.Event()
+
+ def set_source(self, rgb_queue: Queue, spec_queue: Queue):
+ self._rgb_queue = rgb_queue
+ self._spec_queue = spec_queue
+
+ def set_output(self, output: typing.Dict[str, Queue]):
+ self._subscribers = output
+
+ def start(self, name='CMD_thread'):
+ self._server_thread = threading.Thread(target=self._cmd_control_service)
+ self._server_thread.start()
+
+ def stop(self):
+ self.thread_stop.set()
+
+ def _cmd_control_service(self):
+ while not self.thread_stop.is_set():
+ # 判断是否有数据,如果没有数据那么就等下次吧,如果有数据来,必须保证同时
+ if self._rgb_queue.empty() or self._spec_queue.empty():
+ continue
+ rgb_data = self._rgb_queue.get()
+ spec_data = self._spec_queue.get()
+ if isinstance(rgb_data, int) and isinstance(spec_data, int):
+ # 看是不是命令需要执行如果是命令,就执行
+ Config.rgb_size_threshold = rgb_data
+ Config.spec_size_threshold = spec_data
+ continue
+ elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
+ # 如果是图片,交给预测的人
+ for _, subscriber in self._subscribers.items():
+ subscriber.put((spec_data, rgb_data))
+ else:
+ # 否则程序出现毁灭性问题,立刻崩
+ raise Exception("两个相机传回的数据没有对上")
+ self.thread_stop.clear()
+
+
+class ImageSaver(Transmitter):
+ """
+ 进行图片存储的中间件
+ """
+ def set_source(self, *args, **kwargs):
+ pass
+
+ def start(self, *args, **kwargs):
+ pass
+
+ def stop(self, *args, **kwargs):
+ pass
+
+
+class ThreadDetector(Transmitter):
+ def __init__(self, input_queue: Queue, output_queue: Queue):
+ super().__init__()
+ self._input_queue, self._output_queue = input_queue, output_queue
+ self._spec_detector = SpecDetector(blk_model_path=Config.blk_model_path,
+ pixel_model_path=Config.pixel_model_path)
+ self._rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
+ background_model_path=Config.rgb_background_model_path)
+ self._predict_thread = None
+ self._thread_exit = threading.Event()
+
+ def set_source(self, img_queue: Queue):
+ self._input_queue = img_queue
+
+ def stop(self, *args, **kwargs):
+ self._thread_exit.set()
+
+ def start(self, name='predict_thread'):
+ self._predict_thread = threading.Thread(target=self._predict_server, name=name)
+ self._predict_thread.start()
+
+ def predict(self, spec: np.ndarray, rgb: np.ndarray):
+ mask = self._spec_detector.predict(spec)
+ # rgb识别
+ mask_rgb = self._rgb_detector.predict(rgb)
+ # 结果合并
+ mask_result = (mask | mask_rgb).astype(np.uint8)
+ mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
+ return mask_result
+
+ def _predict_server(self):
+ while not self._thread_exit.is_set():
+ if not self._input_queue.empty():
+ spec, rgb = self._input_queue.get()
+ mask = self.predict(spec, rgb)
+ self._output_queue.put(mask)
+ self._thread_exit.clear()
diff --git a/utils.py b/utils.py
index 25607e1..f004242 100755
--- a/utils.py
+++ b/utils.py
@@ -5,6 +5,7 @@
# @Software:PyCharm
import glob
import os
+from queue import Queue
import cv2
import numpy as np
@@ -26,6 +27,34 @@ class MergeDict(dict):
return self
+class ImgQueue(Queue):
+ """
+ A custom queue subclass that provides a :meth:`clear` method.
+ """
+
+ def clear(self):
+ """
+ Clears all items from the queue.
+ """
+
+ with self.mutex:
+ unfinished = self.unfinished_tasks - len(self.queue)
+ if unfinished <= 0:
+ if unfinished < 0:
+ raise ValueError('task_done() called too many times')
+ self.all_tasks_done.notify_all()
+ self.unfinished_tasks = unfinished
+ self.queue.clear()
+ self.not_full.notify_all()
+
+ def safe_put(self, *args, **kwargs):
+ if self.full():
+ _ = self.get()
+ return False
+ self.put(*args, **kwargs)
+ return True
+
+
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
"""
根据dataset_dir下的文件创建数据集
From d697855abc79f7ce12f76d8655a9ba67c9576c41 Mon Sep 17 00:00:00 2001
From: "li.zhenye" <李>
Date: Thu, 28 Jul 2022 13:47:32 +0800
Subject: [PATCH 5/5] Slow MultiThread Version
---
efficient_ui.py | 8 ++++----
models.py | 2 +-
transmit.py | 46 +++++++++++++++++++++++++++++++---------------
utils.py | 4 ++--
4 files changed, 38 insertions(+), 22 deletions(-)
diff --git a/efficient_ui.py b/efficient_ui.py
index dd627b6..0c74260 100644
--- a/efficient_ui.py
+++ b/efficient_ui.py
@@ -15,7 +15,7 @@ class EfficientUI(object):
rgb_fifo_path = "/tmp/dkrgb.fifo"
# 创建队列用于链接各个线程
rgb_img_queue, spec_img_queue = Queue(), Queue()
- detector_queue, save_queue, self.visual_queue = Queue(), Queue, Queue()
+ detector_queue, save_queue, self.visual_queue = Queue(), Queue(), Queue()
mask_queue = Queue()
# 两个接收者,接收光谱和rgb图像
spec_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
@@ -31,11 +31,11 @@ class EfficientUI(object):
# 发送
sender = transmit.FifoSender(output_fifo_path=mask_fifo_path, source=mask_queue)
# 启动所有线程
- spec_receiver.start(post_process_func=transmit.PostProcessMethods.spec_data_post_process, name='spce_thread')
- rgb_receiver.start(post_process_func=transmit.PostProcessMethods.rgb_data_post_process, name='rgb_thread')
+ spec_receiver.start(post_process_func=transmit.BeforeAfterMethods.spec_data_post_process, name='spce_thread')
+ rgb_receiver.start(post_process_func=transmit.BeforeAfterMethods.rgb_data_post_process, name='rgb_thread')
cmd_img_controller.start(name='control_thread')
detector.start(name='detector_thread')
- sender.start(name='sender_thread')
+ sender.start(pre_process=transmit.BeforeAfterMethods.mask_preprocess, name='sender_thread')
def start(self):
# 启动图形化
diff --git a/models.py b/models.py
index f503b96..abb1a74 100755
--- a/models.py
+++ b/models.py
@@ -20,7 +20,7 @@ from config import Config
from utils import lab_scatter, read_labeled_img, size_threshold
-deploy = False
+deploy = True
if not deploy:
print("Training env")
from tqdm import tqdm
diff --git a/transmit.py b/transmit.py
index c0fed63..2632a12 100644
--- a/transmit.py
+++ b/transmit.py
@@ -1,10 +1,15 @@
import os
import threading
+import time
+
from utils import ImgQueue as Queue
import numpy as np
from config import Config
from models import SpecDetector, RgbDetector
import typing
+import logging
+logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
+ level=logging.DEBUG)
class Transmitter(object):
@@ -47,26 +52,33 @@ class Transmitter(object):
raise NotImplementedError
-class PostProcessMethods:
+class BeforeAfterMethods:
+ @classmethod
+ def mask_preprocess(cls, mask: np.ndarray):
+ logging.info(f"Send mask with size {mask.shape}")
+ return mask.tobytes()
+
@classmethod
def spec_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
- print("[INFO] Get Spec threshold: ", threshold)
+ logging.info(f"Get Spec threshold: {threshold}")
return threshold
else:
spec_img = np.frombuffer(data, dtype=np.float32).\
reshape((Config.nRows, Config.nBands, -1)).transpose(0, 2, 1)
+ logging.info(f"Get SPEC image with size {spec_img.shape}")
return spec_img
@classmethod
def rgb_data_post_process(cls, data):
if len(data) < 3:
threshold = int(float(data))
- print("[INFO] Get RGB threshold: ", threshold)
+ logging.info(f"Get RGB threshold: {threshold}")
return threshold
else:
rgb_img = np.frombuffer(data, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
+ logging.info(f"Get RGB img with size {rgb_img.shape}")
return rgb_img
@@ -112,9 +124,7 @@ class FifoReceiver(Transmitter):
data = os.read(input_fifo, self._max_len)
if post_process_func is not None:
data = post_process_func(data)
- if not self._output_queue.safe_put(data):
- if self._msg_queue is not None:
- self._msg_queue.put('Fifo Receiver的接收者未来得及接收')
+ self._output_queue.safe_put(data)
os.close(input_fifo)
self._need_stop.clear()
@@ -154,20 +164,16 @@ class FifoSender(Transmitter):
:return:
"""
while not self._need_stop.is_set():
- output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
if self._input_source.empty():
continue
data = self._input_source.get()
if pre_process is not None:
data = pre_process(data)
+ output_fifo = os.open(self._output_fifo_path, os.O_WRONLY)
os.write(output_fifo, data)
os.close(output_fifo)
self._need_stop.clear()
- @staticmethod
- def mask_preprocess(mask: np.ndarray):
- return mask.tobytes()
-
def __del__(self):
self.stop()
if self._running_thread is not None:
@@ -196,7 +202,7 @@ class CmdImgSplitMidware(Transmitter):
self._subscribers = output
def start(self, name='CMD_thread'):
- self._server_thread = threading.Thread(target=self._cmd_control_service)
+ self._server_thread = threading.Thread(target=self._cmd_control_service, name=name)
self._server_thread.start()
def stop(self):
@@ -216,8 +222,9 @@ class CmdImgSplitMidware(Transmitter):
continue
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
# 如果是图片,交给预测的人
- for _, subscriber in self._subscribers.items():
- subscriber.put((spec_data, rgb_data))
+ for name, subscriber in self._subscribers.items():
+ item = (spec_data, rgb_data)
+ subscriber.safe_put(item)
else:
# 否则程序出现毁灭性问题,立刻崩
raise Exception("两个相机传回的数据没有对上")
@@ -260,12 +267,21 @@ class ThreadDetector(Transmitter):
self._predict_thread.start()
def predict(self, spec: np.ndarray, rgb: np.ndarray):
+ logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
+ t1 = time.time()
mask = self._spec_detector.predict(spec)
+ t2 = time.time()
+ logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
# rgb识别
mask_rgb = self._rgb_detector.predict(rgb)
+ t3 = time.time()
+ logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
+ t4 = time.time()
+ logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
+ logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
return mask_result
def _predict_server(self):
@@ -273,5 +289,5 @@ class ThreadDetector(Transmitter):
if not self._input_queue.empty():
spec, rgb = self._input_queue.get()
mask = self.predict(spec, rgb)
- self._output_queue.put(mask)
+ self._output_queue.safe_put(mask)
self._thread_exit.clear()
diff --git a/utils.py b/utils.py
index f004242..bf466f8 100755
--- a/utils.py
+++ b/utils.py
@@ -47,11 +47,11 @@ class ImgQueue(Queue):
self.queue.clear()
self.not_full.notify_all()
- def safe_put(self, *args, **kwargs):
+ def safe_put(self, item):
if self.full():
_ = self.get()
return False
- self.put(*args, **kwargs)
+ self.put(item)
return True