From 92948f73d7a234e6e4ab6a1e562c8f7b786c26cf Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Mon, 8 Aug 2022 00:25:05 +0800
Subject: [PATCH] =?UTF-8?q?[ext]=20=E6=94=AF=E6=8C=81=E6=9C=80=E5=A4=A7?=
=?UTF-8?q?=E5=BC=80=E5=90=AF=E9=98=80=E9=97=A8=E6=95=B0=E9=87=8F=E9=99=90?=
=?UTF-8?q?=E5=88=B6=E5=8A=9F=E8=83=BDBeta?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
config.py | 1 +
main.py | 28 ++++++++++++++++------------
main_test.py | 1 +
tests/test_utils.py | 18 ++++++++++++++++++
utils.py | 33 +++++++++++++++++++++++++++++++++
5 files changed, 69 insertions(+), 12 deletions(-)
create mode 100644 tests/test_utils.py
diff --git a/config.py b/config.py
index 7685ea9..4b5f9b0 100644
--- a/config.py
+++ b/config.py
@@ -34,6 +34,7 @@ class Config:
# mask parameter
target_size = (1024, 1024) # (Width, Height) of mask
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
+ max_open_valve_limit = 49 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
# save part
offset_vertical = 0
diff --git a/main.py b/main.py
index aabfd67..65c5fea 100755
--- a/main.py
+++ b/main.py
@@ -27,7 +27,7 @@ def main(only_spec=False, only_color=False):
os.mkfifo(mask_fifo_path, 0o777)
if not os.access(rgb_mask_fifo_path, os.F_OK):
os.mkfifo(rgb_mask_fifo_path, 0o777)
-
+ logging.info(f"请注意!正在以调试模式运行程序,输出的信息可能较多。")
while True:
fd_img = os.open(img_fifo_path, os.O_RDONLY)
fd_rgb = os.open(rgb_fifo_path, os.O_RDONLY)
@@ -35,9 +35,13 @@ def main(only_spec=False, only_color=False):
# spec data read
data = os.read(fd_img, total_len)
if len(data) < 3:
- threshold = int(float(data))
- Config.spec_size_threshold = threshold
- logging.info('[INFO] Get spec threshold: ', threshold)
+ try:
+ threshold = int(float(data))
+ Config.spec_size_threshold = threshold
+ logging.info('[INFO] Get spec threshold: ', threshold)
+ except Exception as e:
+ logging.error(f'毁灭性错误:收到长度小于3却无法转化为整数spec_size_threshold的网络报文,报文内容为 {data},'
+ f' 错误为 {e}.')
else:
data_total = data
os.close(fd_img)
@@ -49,7 +53,8 @@ def main(only_spec=False, only_color=False):
Config.rgb_size_threshold = rgb_threshold
logging.info(f'Get rgb threshold: {rgb_threshold}')
except Exception as e:
- logging.error(f'毁灭性错误:收到长度小于3却无法转化为整数的数据,{e}.')
+ logging.error(f'毁灭性错误:收到长度小于3却无法转化为整数spec_size_threshold的网络报文,报文内容为 {total_rgb},'
+ f' 错误为 {e}.')
continue
else:
rgb_data_total = rgb_data
@@ -78,8 +83,10 @@ def main(only_spec=False, only_color=False):
else:
mask_spec = spec_detector.predict(img_data).astype(np.uint8)
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
- # 进行喷阀的合并
+ # 进行多个喷阀的合并
masks = [utils.valve_expend(mask) for mask in [mask_spec, mask_rgb]]
+ # 进行喷阀同时开启限制
+
# control the size of the output masks, 在resize前,图像的宽度是和喷阀对应的
masks = [cv2.resize(mask.astype(np.uint8), Config.target_size) for mask in masks]
# 写出
@@ -109,12 +116,9 @@ if __name__ == '__main__':
rgb_mask_fifo_path = '/tmp/dkmask_rgb.fifo'
# logging相关
file_handler = logging.FileHandler(os.path.join(Config.root_dir, '.tobacco_algorithm.log'))
- file_handler.setLevel(logging.WARNING)
+ file_handler.setLevel(logging.DEBUG if args.d else logging.WARNING)
console_handler = logging.StreamHandler(sys.stdout)
- if args.d:
- console_handler.setLevel(logging.DEBUG)
- else:
- console_handler.setLevel(logging.WARNING)
+ console_handler.setLevel(logging.WARNING)
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s',
- handlers=[file_handler, console_handler])
+ handlers=[file_handler, console_handler], level=logging.DEBUG)
main(only_spec=args.os, only_color=args.oc)
diff --git a/main_test.py b/main_test.py
index b244d92..0990385 100644
--- a/main_test.py
+++ b/main_test.py
@@ -60,6 +60,7 @@ class TestMain:
data = f.read()
spec_img = transmit.BeforeAfterMethods.spec_data_post_process(data)
if convert:
+ # TODO: 完善这个文件转换功能
spec_img_show = np.asarray(np.clip(spec_img[..., [21, 3, 0]] * 255, a_min=0, a_max=255),
dtype=np.uint8)
cv2.imwrite(rgb_file_name + '.bmp', spec_img_show[..., ::-1])
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000..a2743a9
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,18 @@
+import unittest
+
+import numpy as np
+
+from utils import valve_limit
+
+
+class UtilTestCase(unittest.TestCase):
+ mask_test = np.zeros((1024, 1024), dtype=np.uint8)
+ mask_test[0:20, :] = 1
+
+ def test_valve_limit(self):
+ mask_result = valve_limit(self.mask_test, max_valve_num=49)
+ self.assertTrue(np.all(np.sum(mask_result, 1) <= 49)) # add assertion here
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/utils.py b/utils.py
index 805a757..9e6c4ea 100755
--- a/utils.py
+++ b/utils.py
@@ -4,16 +4,23 @@
# @File: utils.py
# @Software:PyCharm
import glob
+import logging
import os
from queue import Queue
import cv2
import numpy as np
+from numpy.random import default_rng
from matplotlib import pyplot as plt
import re
def natural_sort(l):
+ """
+ 自然排序
+ :param l: 待排序
+ :return:
+ """
convert = lambda text: int(text) if text.isdigit() else text.lower()
alphanum_key = lambda key: [convert(c) for c in re.split('([0-9]+)', key)]
return sorted(l, key=alphanum_key)
@@ -166,6 +173,32 @@ def valve_expend(img: np.ndarray) -> np.ndarray:
return img
+def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
+ """
+ 用于限制阀门同时开启个数的函数
+ :param mask: 阀门开启mask,0,1格式,每个1对应一个阀门
+ :param max_valve_num: 最大阀门数量
+ :return:
+ """
+ assert (max_valve_num >= 5) and (max_valve_num < 50)
+ row_valve_count = np.sum(mask, axis=1)
+ if np.any(row_valve_count > max_valve_num):
+ over_rows_idx = np.argwhere(row_valve_count > max_valve_num).ravel()
+ logging.warning(f'发现单行喷阀数量{len(over_rows_idx)}超过限制,已限制到最大许可值{max_valve_num}')
+ over_rows = mask[over_rows_idx, :]
+
+ # a simple function to get lucky valves when too many valves appear in the same line
+ def process_row(each_row):
+ valve_idx = np.argwhere(each_row > 0).ravel()
+ lucky_valve_idx = default_rng().choice(valve_idx, max_valve_num)
+ new_row = np.zeros_like(each_row)
+ new_row[lucky_valve_idx] = 1
+ return new_row
+ limited_rows = np.apply_along_axis(process_row, 1, over_rows)
+ mask[over_rows_idx] = limited_rows
+ return mask
+
+
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
"""
Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII...