From 9e752003aef00fa55970c54fa0d2c6553ee2d3e6 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Thu, 4 Aug 2022 10:43:54 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=8D=95=E9=80=9A?=
=?UTF-8?q?=E9=81=93=E9=A2=84=E6=B5=8B=E7=9A=84=E5=8A=9F=E8=83=BD?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
main.py | 77 +++++++++++++++++++++++----------------------------------
1 file changed, 31 insertions(+), 46 deletions(-)
diff --git a/main.py b/main.py
index 8458778..a9c3b45 100755
--- a/main.py
+++ b/main.py
@@ -1,19 +1,13 @@
import os
+import cv2
import time
-from queue import Queue
-
import numpy as np
-from matplotlib import pyplot as plt
-
-import models
-import transmit
from config import Config
from models import RgbDetector, SpecDetector
-import cv2
-def main():
+def main(only_spec=False, only_color=False):
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
@@ -25,17 +19,12 @@ def main():
os.mkfifo(rgb_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
-
- # 进行补偿buffer的开启
- if Config.offset_vertical < 0:
- # 纵向的补偿小于0,那就意味着光谱图要上移才能补上,那么我们应该补偿SPEC相机的全 0 图像
- conserve_part = np.zeros((abs(Config.offset_vertical) // 4, Config.nRows, Config.nBands))
- elif Config.offset_vertical > 0:
- # 纵向的补偿小于0,说明光谱图下移才能补上去,那么我们就需要补偿RGB相机的全 0 图像
- conserve_part = np.zeros(abs(Config.offset_vertical), Config.nRgbRows, Config.nRgbBands)
+ if not os.access(rgb_mask_fifo_path, os.F_OK):
+ os.mkfifo(rgb_mask_fifo_path, 0o777)
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
+
# spec data read
data = os.read(fd_img, total_len)
if len(data) < 3:
@@ -45,6 +34,7 @@ def main():
else:
data_total = data
os.close(fd_img)
+
# rgb data read
rgb_data = os.read(fd_rgb, total_rgb)
if len(rgb_data) < 3:
@@ -55,36 +45,28 @@ def main():
else:
rgb_data_total = rgb_data
os.close(fd_rgb)
+
# 识别
t1 = time.time()
- img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1))\
- .transpose(0, 2, 1)
+ img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
+ .transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
-
- # OFFSET compensate
- if Config.offset_vertical < 0:
- # 纵向的补偿小于0,那就意味着光谱图要上移才能补上,那么我们应该补偿SPEC相机的全 0 图像
- new_conserve_part, real_part = img_data[:abs(Config.offset_vertical) // 4, ...],\
- img_data[abs(Config.offset_vertical) // 4:, ...]
- img_data = np.concatenate([real_part, conserve_part], axis=0)
- conserve_part = new_conserve_part
- elif Config.offset_vertical > 0:
- # 纵向的补偿小于0,说明光谱图下移才能补上去,那么我们就需要补偿RGB相机的全 0 图像
- new_conserve_part, real_part = rgb_data[:abs(Config.offset_vertical), ...],\
- rgb_data[abs(Config.offset_vertical):, ...]
- rgb_data = np.concatenate([real_part, conserve_part], axis=0)
- conserve_part = new_conserve_part
- # 光谱识别
- mask_spec = spec_detector.predict(img_data)
- # rgb识别
- mask_rgb = rgb_detector.predict(rgb_data)
- # 结果合并
- mask_result = (mask_spec | mask_rgb).astype(np.uint8)
+ if only_spec:
+ # 光谱识别
+ mask_spec = spec_detector.predict(img_data)
+ mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8)
+ elif only_color:
+ # rgb识别
+ mask_rgb = rgb_detector.predict(rgb_data)
+ mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8)
+ else:
+ mask_spec = spec_detector.predict(img_data)
+ mask_rgb = rgb_detector.predict(rgb_data)
# control the size of the output masks
- masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_result, ]]
+ masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_spec, mask_rgb]]
# 写出
- output_fifos = [mask_fifo_path, ]
+ output_fifos = [mask_fifo_path, rgb_mask_fifo_path]
for fifo, mask in zip(output_fifos, masks):
fd_mask = os.open(fifo, os.O_WRONLY)
os.write(fd_mask, mask.tobytes())
@@ -94,12 +76,15 @@ def main():
if __name__ == '__main__':
- # 相关参数
+ import argparse
+ parser = argparse.ArgumentParser(description='主程序')
+ parser.add_argument('-oc', default=False, action='store_true', help='只进行RGB彩色预测 only rgb', required=False)
+ parser.add_argument('-os', default=False, action='store_true', help='只进行光谱预测 only spec', required=False)
+ args = parser.parse_args()
+ # fifo 参数
img_fifo_path = "/tmp/dkimg.fifo"
rgb_fifo_path = "/tmp/dkrgb.fifo"
-
+ # mask fifo
mask_fifo_path = "/tmp/dkmask.fifo"
- # 主函数
- main()
- # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
- # selected_bands=[380, 300, 200])
+ rgb_mask_fifo_path = "/tmp/dkmask_rgb.fifo"
+ main(only_spec=args.os, only_color=args.oc)