mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Merge remote-tracking branch 'originn/master'
This commit is contained in:
commit
30041c1aed
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import os
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
@ -36,3 +37,6 @@ class Config:
|
|||||||
|
|
||||||
# save part
|
# save part
|
||||||
offset_vertical = 0
|
offset_vertical = 0
|
||||||
|
|
||||||
|
# logging
|
||||||
|
root_dir = os.path.split(os.path.realpath(__file__))[0]
|
||||||
|
|||||||
46
main.py
46
main.py
@ -1,4 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
import time
|
import time
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -6,6 +8,7 @@ import numpy as np
|
|||||||
import utils
|
import utils
|
||||||
from config import Config
|
from config import Config
|
||||||
from models import RgbDetector, SpecDetector
|
from models import RgbDetector, SpecDetector
|
||||||
|
import logging
|
||||||
|
|
||||||
|
|
||||||
def main(only_spec=False, only_color=False):
|
def main(only_spec=False, only_color=False):
|
||||||
@ -34,27 +37,34 @@ def main(only_spec=False, only_color=False):
|
|||||||
if len(data) < 3:
|
if len(data) < 3:
|
||||||
threshold = int(float(data))
|
threshold = int(float(data))
|
||||||
Config.spec_size_threshold = threshold
|
Config.spec_size_threshold = threshold
|
||||||
print("[INFO] Get spec threshold: ", threshold)
|
logging.info('[INFO] Get spec threshold: ', threshold)
|
||||||
else:
|
else:
|
||||||
data_total = data
|
data_total = data
|
||||||
os.close(fd_img)
|
os.close(fd_img)
|
||||||
|
|
||||||
# rgb data read
|
# rgb data read
|
||||||
rgb_data = os.read(fd_rgb, total_rgb)
|
rgb_data = os.read(fd_rgb, total_rgb)
|
||||||
if len(rgb_data) < 3:
|
if len(rgb_data) < 3:
|
||||||
|
try:
|
||||||
rgb_threshold = int(float(rgb_data))
|
rgb_threshold = int(float(rgb_data))
|
||||||
Config.rgb_size_threshold = rgb_threshold
|
Config.rgb_size_threshold = rgb_threshold
|
||||||
print("[INFO] Get rgb threshold", rgb_threshold)
|
logging.info(f'Get rgb threshold: {rgb_threshold}')
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'毁灭性错误:收到长度小于3却无法转化为整数的数据,{e}.')
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
rgb_data_total = rgb_data
|
rgb_data_total = rgb_data
|
||||||
os.close(fd_rgb)
|
os.close(fd_rgb)
|
||||||
|
|
||||||
# 识别
|
# 识别
|
||||||
t1 = time.time()
|
since = time.time()
|
||||||
|
try:
|
||||||
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
|
img_data = np.frombuffer(data_total, dtype=np.float32).reshape((Config.nRows, Config.nBands, -1)) \
|
||||||
.transpose(0, 2, 1)
|
.transpose(0, 2, 1)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'毁灭性错误!收到的光谱数据长度为{len(data_total)}无法转化成指定的形状 {e}')
|
||||||
|
try:
|
||||||
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
rgb_data = np.frombuffer(rgb_data_total, dtype=np.uint8).reshape((Config.nRgbRows, Config.nRgbCols, -1))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'毁灭性错误!收到的rgb数据长度为{len(rgb_data)}无法转化成指定形状 {e}')
|
||||||
if only_spec:
|
if only_spec:
|
||||||
# 光谱识别
|
# 光谱识别
|
||||||
mask_spec = spec_detector.predict(img_data)
|
mask_spec = spec_detector.predict(img_data)
|
||||||
@ -68,7 +78,6 @@ def main(only_spec=False, only_color=False):
|
|||||||
else:
|
else:
|
||||||
mask_spec = spec_detector.predict(img_data)
|
mask_spec = spec_detector.predict(img_data)
|
||||||
mask_rgb = rgb_detector.predict(rgb_data)
|
mask_rgb = rgb_detector.predict(rgb_data)
|
||||||
|
|
||||||
# 进行喷阀的合并
|
# 进行喷阀的合并
|
||||||
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
|
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
|
||||||
# control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的
|
# control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的
|
||||||
@ -79,8 +88,10 @@ def main(only_spec=False, only_color=False):
|
|||||||
fd_mask = os.open(fifo, os.O_WRONLY)
|
fd_mask = os.open(fifo, os.O_WRONLY)
|
||||||
os.write(fd_mask, mask.tobytes())
|
os.write(fd_mask, mask.tobytes())
|
||||||
os.close(fd_mask)
|
os.close(fd_mask)
|
||||||
t3 = time.time()
|
time_spent = (time.time() - since) * 1000
|
||||||
print(f'total time is:{t3 - t1}')
|
logging.info(f'Total time is: {time_spent:.2f} ms')
|
||||||
|
if time_spent > 200:
|
||||||
|
logging.warning(f'警告预测超时,预测耗时超过了200ms,The prediction time is {time_spent:.2f} ms.')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
@ -88,11 +99,22 @@ if __name__ == '__main__':
|
|||||||
parser = argparse.ArgumentParser(description='主程序')
|
parser = argparse.ArgumentParser(description='主程序')
|
||||||
parser.add_argument('-oc', default=False, action='store_true', help='只进行RGB彩色预测 only rgb', required=False)
|
parser.add_argument('-oc', default=False, action='store_true', help='只进行RGB彩色预测 only rgb', required=False)
|
||||||
parser.add_argument('-os', default=False, action='store_true', help='只进行光谱预测 only spec', required=False)
|
parser.add_argument('-os', default=False, action='store_true', help='只进行光谱预测 only spec', required=False)
|
||||||
|
parser.add_argument('-d', default=False, action='store_true', help='是否使用DEBUG模式', required=False)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
# fifo 参数
|
# fifo 参数
|
||||||
img_fifo_path = "/tmp/dkimg.fifo"
|
img_fifo_path = '/tmp/dkimg.fifo'
|
||||||
rgb_fifo_path = "/tmp/dkrgb.fifo"
|
rgb_fifo_path = '/tmp/dkrgb.fifo'
|
||||||
# mask fifo
|
# mask fifo
|
||||||
mask_fifo_path = "/tmp/dkmask.fifo"
|
mask_fifo_path = '/tmp/dkmask.fifo'
|
||||||
rgb_mask_fifo_path = "/tmp/dkmask_rgb.fifo"
|
rgb_mask_fifo_path = '/tmp/dkmask_rgb.fifo'
|
||||||
|
# logging相关
|
||||||
|
file_handler = logging.FileHandler(os.path.join(Config.root_dir, '.tobacco_algorithm.log'))
|
||||||
|
file_handler.setLevel(logging.WARNING)
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
|
if args.d:
|
||||||
|
console_handler.setLevel(logging.DEBUG)
|
||||||
|
else:
|
||||||
|
console_handler.setLevel(logging.WARNING)
|
||||||
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
|
||||||
|
handlers=[file_handler, console_handler])
|
||||||
main(only_spec=args.os, only_color=args.oc)
|
main(only_spec=args.os, only_color=args.oc)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user