supermachine-tobacco/main.py
2022-07-28 12:39:54 +08:00

118 lines
4.3 KiB
Python
Executable File

import os
import time
from queue import Queue
import numpy as np
from matplotlib import pyplot as plt
import models
import transmit
from config import Config
from models import RgbDetector, SpecDetector
import cv2
def main():
spec_detector = SpecDetector(blk_model_path=Config.blk_model_path, pixel_model_path=Config.pixel_model_path)
rgb_detector = RgbDetector(tobacco_model_path=Config.rgb_tobacco_model_path,
background_model_path=Config.rgb_background_model_path)
total_len = Config.nRows * Config.nCols * Config.nBands * 4 # float型变量, 4个字节
total_rgb = Config.nRgbRows * Config.nRgbCols * Config.nRgbBands * 1 # int型变量
if not os.access(img_fifo_path, os.F_OK):
os.mkfifo(img_fifo_path, 0o777)
if not os.access(mask_fifo_path, os.F_OK):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_fifo_path, os.F_OK):
os.mkfifo(rgb_fifo_path, 0o777)
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
# spec data read
data = os.read(fd_img, total_len)
if len(data) < 3:
threshold = int(float(data))
Config.spec_size_threshold = threshold
print("[INFO] Get spec threshold: ", threshold)
else:
data_total = data
os.close(fd_img)
# rgb data read
rgb_data = os.read(fd_rgb, total_rgb)
if len(rgb_data) < 3:
rgb_threshold = int(float(rgb_data))
Config.rgb_size_threshold = rgb_threshold
print("[INFO] Get rgb threshold", rgb_threshold)
continue
else:
rgb_data_total = rgb_data
os.close(fd_rgb)
# 识别
t1 = time.time()
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
.transpose(0, 2, 1)
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
# 光谱识别
mask = spec_detector.predict(img_data)
# rgb识别
mask_rgb = rgb_detector.predict(rgb_data)
# 结果合并
mask_result = (mask | mask_rgb).astype(np.uint8)
# mask_result = mask_rgb.astype(np.uint8)
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
t2 = time.time()
print(f'rgb len = {len(rgb_data)}')
# 写出
fd_mask = os.open(mask_fifo_path, os.O_WRONLY)
os.write(fd_mask, mask_result.tobytes())
os.close(fd_mask)
t3 = time.time()
print(f'total time is:{t3 - t1}')
def read_c_captures(buffer_path, no_mask=True, nrows=256, ncols=1024, selected_bands=None):
if os.path.isdir(buffer_path):
buffer_names = [buffer_name for buffer_name in os.listdir(buffer_path) if buffer_name.endswith('.raw')]
else:
buffer_names = [buffer_path, ]
for buffer_name in buffer_names:
with open(os.path.join(buffer_path, buffer_name), 'rb') as f:
data = f.read()
img = np.frombuffer(data, dtype=np.float32).reshape((nrows, -1, ncols)) \
.transpose(0, 2, 1)
if selected_bands is not None:
img = img[..., selected_bands]
if img.shape[0] == 1:
img = img[0, ...]
if not no_mask:
mask_name = buffer_name.replace('buf', 'mask').replace('.raw', '')
with open(os.path.join(buffer_path, mask_name), 'rb') as f:
data = f.read()
mask = np.frombuffer(data, dtype=np.uint8).reshape((nrows, ncols, -1))
else:
mask_name = "no mask"
mask = np.zeros_like(img)
# mask = cv2.resize(mask, (1024, 256))
fig, axs = plt.subplots(2, 1)
axs[0].matshow(img)
axs[0].set_title(buffer_name)
axs[1].imshow(mask)
axs[1].set_title(mask_name)
plt.show()
if __name__ == '__main__':
# 相关参数
img_fifo_path = "/tmp/dkimg.fifo"
mask_fifo_path = "/tmp/dkmask.fifo"
rgb_fifo_path = "/tmp/dkrgb.fifo"
# 主函数
main()
# read_c_captures('/home/lzy/2022.7.15/tobacco_v1_0/', no_mask=True, nrows=256, ncols=1024,
# selected_bands=[380, 300, 200])