From 9e752003aef00fa55970c54fa0d2c6553ee2d3e6 Mon Sep 17 00:00:00 2001 From: "li.zhenye" Date: Thu, 4 Aug 2022 10:43:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=8D=95=E9=80=9A?= =?UTF-8?q?=E9=81=93=E9=A2=84=E6=B5=8B=E7=9A=84=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.py | 77 +++++++++++++++++++++++---------------------------------- 1 file changed, 31 insertions(+), 46 deletions(-) diff --git a/main.py b/main.py index 8458778..a9c3b45 100755 --- a/main.py +++ b/main.py @@ -1,19 +1,13 @@ import os +import cv2 import time -from queue import Queue - import numpy as np -from matplotlib import pyplot as plt - -import models -import transmit from config import Config from models import RgbDetector, SpecDetector -import cv2 -def main(): +def main(only_spec=False, only_color=False): spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path) rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path, background_model_path=Config.rgb_background_model_path) @@ -25,17 +19,12 @@ def main(): os.mkfifo(rgb_fifo_path, 0o777) if not os.access(mask_fifo_path, os.F_OK): os.mkfifo(mask_fifo_path, 0o777) - - # 进行补偿buffer的开启 - if Config.offset_vertical < 0: - # 纵向的补偿小于0,那就意味着光谱图要上移才能补上,那么我们应该补偿SPEC相机的全 0 图像 - conserve_part = np.zeros((abs(Config.offset_vertical) // 4, Config.nRows, Config.nBands)) - elif Config.offset_vertical > 0: - # 纵向的补偿小于0,说明光谱图下移才能补上去,那么我们就需要补偿RGB相机的全 0 图像 - conserve_part = np.zeros(abs(Config.offset_vertical), Config.nRgbRows, Config.nRgbBands) + if not os.access(rgb_mask_fifo_path, os.F_OK): + os.mkfifo(rgb_mask_fifo_path, 0o777) while True: fd_img = os.open(img_fifo_path, os.O_RDONLY) fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY) + # spec data read data = os.read(fd_img, total_len) if len(data) < 3: @@ -45,6 +34,7 @@ def main(): else: data_total = data os.close(fd_img) + # rgb data read rgb_data = os.read(fd_rgb, total_rgb) if len(rgb_data) < 3: @@ -55,36 +45,28 @@ def main(): else: rgb_data_total = rgb_data os.close(fd_rgb) + # 识别 t1 = time.time() - img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1))\ - .transpose(0, 2, 1) + img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \ + .transpose(0, 2, 1) rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1)) - - # OFFSET compensate - if Config.offset_vertical < 0: - # 纵向的补偿小于0,那就意味着光谱图要上移才能补上,那么我们应该补偿SPEC相机的全 0 图像 - new_conserve_part, real_part = img_data[:abs(Config.offset_vertical) // 4, ...],\ - img_data[abs(Config.offset_vertical) // 4:, ...] - img_data = np.concatenate([real_part, conserve_part], axis=0) - conserve_part = new_conserve_part - elif Config.offset_vertical > 0: - # 纵向的补偿小于0,说明光谱图下移才能补上去,那么我们就需要补偿RGB相机的全 0 图像 - new_conserve_part, real_part = rgb_data[:abs(Config.offset_vertical), ...],\ - rgb_data[abs(Config.offset_vertical):, ...] - rgb_data = np.concatenate([real_part, conserve_part], axis=0) - conserve_part = new_conserve_part - # 光谱识别 - mask_spec = spec_detector.predict(img_data) - # rgb识别 - mask_rgb = rgb_detector.predict(rgb_data) - # 结果合并 - mask_result = (mask_spec | mask_rgb).astype(np.uint8) + if only_spec: + # 光谱识别 + mask_spec = spec_detector.predict(img_data) + mask_rgb = np.zeros_like(mask_spec, dtype=np.uint8) + elif only_color: + # rgb识别 + mask_rgb = rgb_detector.predict(rgb_data) + mask_spec = np.zeros_like(mask_rgb, dtype=np.uint8) + else: + mask_spec = spec_detector.predict(img_data) + mask_rgb = rgb_detector.predict(rgb_data) # control the size of the output masks - masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_result, ]] + masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in [mask_spec, mask_rgb]] # 写出 - output_fifos = [mask_fifo_path, ] + output_fifos = [mask_fifo_path, rgb_mask_fifo_path] for fifo, mask in zip(output_fifos, masks): fd_mask = os.open(fifo, os.O_WRONLY) os.write(fd_mask, mask.tobytes()) @@ -94,12 +76,15 @@ def main(): if __name__ == '__main__': - # 相关参数 + import argparse + parser = argparse.ArgumentParser(description='主程序') + parser.add_argument('-oc', default=False, action='store_true', help='只进行RGB彩色预测 only rgb', required=False) + parser.add_argument('-os', default=False, action='store_true', help='只进行光谱预测 only spec', required=False) + args = parser.parse_args() + # fifo 参数 img_fifo_path = "/tmp/dkimg.fifo" rgb_fifo_path = "/tmp/dkrgb.fifo" - + # mask fifo mask_fifo_path = "/tmp/dkmask.fifo" - # 主函数 - main() - # read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024, - # selected_bands=[380, 300, 200]) + rgb_mask_fifo_path = "/tmp/dkmask_rgb.fifo" + main(only_spec=args.os, only_color=args.oc)