mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 22:33:54 +00:00
[ext] 修改transmit到一半
This commit is contained in:
parent
1f62ba2e4e
commit
e7e7de43e6
2
main.py
2
main.py
@ -38,7 +38,7 @@ def main(only_spec=False, only_color=False):
|
||||
try:
|
||||
threshold = int(float(data))
|
||||
Config.spec_size_threshold = threshold
|
||||
logging.info(f'[INFO] Get spec threshold: {threshold}')
|
||||
logging.info(f'Get spec threshold: {threshold}')
|
||||
except Exception as e:
|
||||
logging.error(f'毁灭性错误:收到长度小于3却无法转化为整数spec_size_threshold的网络报文,报文内容为 {data},'
|
||||
f' 错误为 {e}.')
|
||||
|
||||
117
transmit.py
117
transmit.py
@ -3,19 +3,28 @@ import threading
|
||||
from multiprocessing import Process, Queue
|
||||
import time
|
||||
|
||||
import cv2
|
||||
|
||||
import utils
|
||||
from utils import ImgQueue as ImgQueue
|
||||
import functools
|
||||
import numpy as np
|
||||
from config import Config
|
||||
from models import SpecDetector, RgbDetector
|
||||
import typing
|
||||
from typing import Any
|
||||
import logging
|
||||
logging.basicConfig(format='%(asctime)s %(levelname)s %(name)s %(message)s',
|
||||
level=logging.WARNING)
|
||||
|
||||
|
||||
class Transmitter(object):
|
||||
def __init__(self):
|
||||
def __init__(self, job_name:str, run_process:bool = False):
|
||||
self.output = None
|
||||
self.job_name = job_name
|
||||
self.run_process = run_process # If true, run process when started else run thread.
|
||||
self._thread_stop = threading.Event()
|
||||
self._thread_stop.clear()
|
||||
self._running_handler = None
|
||||
|
||||
def set_source(self, *args, **kwargs):
|
||||
"""
|
||||
@ -36,20 +45,41 @@ class Transmitter(object):
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
"""
|
||||
启动接收线程或进程
|
||||
启动线程或进程
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
name = kwargs.get('name', default='base thread')
|
||||
if not self.run_process:
|
||||
self._running_handler = threading.Thread(target=self.job_func, name=name, args=args)
|
||||
else:
|
||||
self._running_handler = Process(target=self.job_func, name=name, args=args, daemon=True)
|
||||
self._running_handler.start()
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
"""
|
||||
停止接收线程或进程
|
||||
停止线程或进程
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
if self._running_handler is not None:
|
||||
self._thread_stop.set()
|
||||
self._running_handler = None
|
||||
|
||||
@staticmethod
|
||||
def job_decorator(func):
|
||||
functools.wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} start.')
|
||||
while not self._thread_stop.is_set():
|
||||
self.job_func(*args, **kwargs)
|
||||
logging.info(f'{self.job_name} {"process" if self.run_process else "thread"} stop.')
|
||||
self._need_stop.clear()
|
||||
return wrapper
|
||||
|
||||
def job_func(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -84,8 +114,9 @@ class BeforeAfterMethods:
|
||||
|
||||
|
||||
class FifoReceiver(Transmitter):
|
||||
def __init__(self, fifo_path: str, output: ImgQueue, read_max_num: int, msg_queue=None):
|
||||
super().__init__()
|
||||
def __init__(self, job_name:str, fifo_path: str, output: ImgQueue,
|
||||
read_max_num: int, msg_queue=None):
|
||||
super().__init__(job_name=job_name)
|
||||
self._input_fifo_path = None
|
||||
self._output_queue = None
|
||||
self._msg_queue = msg_queue
|
||||
@ -93,9 +124,6 @@ class FifoReceiver(Transmitter):
|
||||
|
||||
self.set_source(fifo_path)
|
||||
self.set_output(output)
|
||||
self._need_stop = threading.Event()
|
||||
self._need_stop.clear()
|
||||
self._running_thread = None
|
||||
|
||||
def set_source(self, fifo_path: str):
|
||||
if not os.access(fifo_path, os.F_OK):
|
||||
@ -105,29 +133,21 @@ class FifoReceiver(Transmitter):
|
||||
def set_output(self, output: ImgQueue):
|
||||
self._output_queue = output
|
||||
|
||||
def start(self, post_process_func=None, name='fifo_receiver'):
|
||||
self._running_thread = threading.Thread(target=self._receive_thread_func,
|
||||
name=name, args=(post_process_func, ))
|
||||
self._running_thread.start()
|
||||
|
||||
def stop(self):
|
||||
self._need_stop.set()
|
||||
|
||||
def _receive_thread_func(self, post_process_func=None):
|
||||
@Transmitter.job_decorator
|
||||
def job_func(self, post_process_func=None):
|
||||
"""
|
||||
接收线程
|
||||
|
||||
:param post_process_func:
|
||||
:return:
|
||||
"""
|
||||
while not self._need_stop.is_set():
|
||||
input_fifo = os.open(self._input_fifo_path, os.O_RDONLY)
|
||||
data = os.read(input_fifo, self._max_len)
|
||||
if post_process_func is not None:
|
||||
data = post_process_func(data)
|
||||
self._output_queue.safe_put(data)
|
||||
os.close(input_fifo)
|
||||
self._need_stop.clear()
|
||||
input_fifo = os.open(self._input_fifo_path, os.O_RDONLY)
|
||||
data = os.read(input_fifo, self._max_len)
|
||||
if post_process_func is not None:
|
||||
data = post_process_func(data)
|
||||
self._output_queue.safe_put(data)
|
||||
os.close(input_fifo)
|
||||
|
||||
|
||||
|
||||
class FifoSender(Transmitter):
|
||||
@ -220,6 +240,7 @@ class CmdImgSplitMidware(Transmitter):
|
||||
# 看是不是命令需要执行如果是命令,就执行
|
||||
Config.rgb_size_threshold = rgb_data
|
||||
Config.spec_size_threshold = spec_data
|
||||
logging.info("获取到指令")
|
||||
continue
|
||||
elif isinstance(spec_data, np.ndarray) and isinstance(rgb_data, np.ndarray):
|
||||
# 如果是图片,交给预测的人
|
||||
@ -228,6 +249,7 @@ class CmdImgSplitMidware(Transmitter):
|
||||
subscriber.safe_put(item)
|
||||
else:
|
||||
# 否则程序出现毁灭性问题,立刻崩
|
||||
logging.critical('两个相机传回的数据没有对上')
|
||||
raise Exception("两个相机传回的数据没有对上")
|
||||
self.thread_stop.clear()
|
||||
|
||||
@ -316,27 +338,44 @@ class ProcessDetector(Transmitter):
|
||||
self._predict_thread.start()
|
||||
|
||||
def predict(self, spec: np.ndarray, rgb: np.ndarray):
|
||||
logging.info(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||
logging.debug(f'Detector get image with shape {spec.shape} and {rgb.shape}')
|
||||
t1 = time.time()
|
||||
mask = self._spec_detector.predict(spec)
|
||||
mask_spec = self._spec_detector.predict(spec)
|
||||
t2 = time.time()
|
||||
logging.info(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||
logging.debug(f'Detector finish spec predict within {(t2 - t1) * 1000:.2f}ms')
|
||||
# rgb识别
|
||||
mask_rgb = self._rgb_detector.predict(rgb)
|
||||
t3 = time.time()
|
||||
logging.info(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||
logging.debug(f'Detector finish rgb predict within {(t3 - t2) * 1000:.2f}ms')
|
||||
# 结果合并
|
||||
mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
# mask_result = (mask | mask_rgb).astype(np.uint8)
|
||||
# mask_result = mask_result.repeat(Config.blk_size, axis=0).repeat(Config.blk_size, axis=1).astype(np.uint8)
|
||||
# 进行多个喷阀的合并
|
||||
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
|
||||
# 进行喷阀同时开启限制
|
||||
masks = [utils.valve_limit(mask, Config.max_open_valve_limit) for mask in masks]
|
||||
# control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的
|
||||
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
|
||||
t4 = time.time()
|
||||
logging.info(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||
logging.info(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||
return mask_result
|
||||
logging.debug(f'Detector finish merge within {(t4 - t3) * 1000: .2f}ms')
|
||||
logging.debug(f'Detector finish predict within {(time.time() -t1)*1000:.2f}ms')
|
||||
return masks
|
||||
|
||||
def _predict_server(self):
|
||||
while not self._thread_exit.is_set():
|
||||
if not self._input_queue.empty():
|
||||
spec, rgb = self._input_queue.get()
|
||||
mask = self.predict(spec, rgb)
|
||||
self._output_queue.put(mask)
|
||||
self._thread_exit.clear()
|
||||
masks = self.predict(spec, rgb)
|
||||
self._output_queue.put(masks[:])
|
||||
self._thread_exit.clear()
|
||||
|
||||
|
||||
class SplitMidware(Transmitter):
|
||||
def set_source(self, mask_source: ImgQueue):
|
||||
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def stop(self, *args, **kwargs):
|
||||
pass
|
||||
Loading…
Reference in New Issue
Block a user