Compare commits

..

7 Commits

Author SHA1 Message Date
duanmu
3b89dff026 fix:更新requirements.txt 2024-01-05 14:48:25 +08:00
duanmu
c391cfa7d9 fix:更新23年12月现场参数:
光谱和rgb的lab参数
更新rgb背景模型

未在现场更新的参数:
yolo模型
阈值

修正了yolo不起作用的问题
2024-01-03 14:50:58 +08:00
feijinti
192bb8db99 fix:增加了光谱伪彩色图lab 2023-09-24 12:59:41 +08:00
duanmu
c4191be192 fix:
添加了lab颜色空间识别绿色杂质、蓝色杂质,添加a、b的阈值为127,尺寸限制改为6。
优化了lab色彩空间内绘制3维数据分布情况。
喷阀横向膨胀尺寸改为5。
添加test_rgb测试单张图片
2023-09-20 20:50:14 +08:00
FEIJINTI
f2c614dbcf 屏蔽了右侧10个喷阀 2023-03-17 13:25:05 +08:00
FEIJINTI
2af3d8091f 屏蔽了左侧10个喷阀 2022-11-15 12:47:11 +08:00
FEIJINTI
edd3b0f9c8 屏蔽了左侧8个喷阀 2022-11-15 12:46:20 +08:00
9 changed files with 144 additions and 59 deletions

View File

@ -12,24 +12,8 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 7,
"outputs": [ "outputs": [],
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training env\n"
]
}
],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import scipy\n", "import scipy\n",
@ -46,7 +30,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 8,
"outputs": [], "outputs": [],
"source": [ "source": [
"train_from_existed = False # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据\n", "train_from_existed = False # 是否从现有数据训练如果是的话那就从dataset_file训练否则就用data_dir里头的数据\n",
@ -54,8 +38,8 @@
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n", "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
"\n", "\n",
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n", "# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
"color_dict = {(0, 0, 255): \"yangeng\"}\n", "# color_dict = {(0, 0, 255): \"yangeng\"}\n",
"# color_dict = {(255, 0, 0): 'beijing'}\n", "color_dict = {(255, 0, 0): 'beijing'}\n",
"# color_dict = {(0, 255, 0): \"zibian\"}\n", "# color_dict = {(0, 255, 0): \"zibian\"}\n",
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n", "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
"show_samples = False # 是否展示样本\n", "show_samples = False # 是否展示样本\n",
@ -86,7 +70,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 9,
"outputs": [], "outputs": [],
"source": [ "source": [
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n", "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
@ -169,7 +153,7 @@
"name": "stderr", "name": "stderr",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"23961it [00:46, 519.87it/s] " "152147it [42:48, 59.24it/s] \n"
] ]
}, },
{ {
@ -178,19 +162,12 @@
"text": [ "text": [
" precision recall f1-score support\n", " precision recall f1-score support\n",
"\n", "\n",
" 0 1.00 1.00 1.00 7188\n", " 0 1.00 1.00 1.00 45644\n",
" 1 1.00 1.00 1.00 5991\n", " 1 1.00 1.00 1.00 38037\n",
"\n", "\n",
" accuracy 1.00 13179\n", " accuracy 1.00 83681\n",
" macro avg 1.00 1.00 1.00 13179\n", " macro avg 1.00 1.00 1.00 83681\n",
"weighted avg 1.00 1.00 1.00 13179\n", "weighted avg 1.00 1.00 1.00 83681\n",
"\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\n" "\n"
] ]
} }

View File

@ -15,14 +15,20 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 1,
"outputs": [ "outputs": [
{ {
"name": "stderr", "ename": "NotImplementedError",
"output_type": "stream", "evalue": "开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
"text": [ "output_type": "error",
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", "traceback": [
" from .autonotebook import tqdm as notebook_tqdm\n" "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)",
"Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
"\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
] ]
} }
], ],

View File

@ -20,30 +20,34 @@ class Config:
# 光谱模型参数 # 光谱模型参数
blk_size = 4 # 必须是2的倍数不然会出错 blk_size = 4 # 必须是2的倍数不然会出错
raise NotImplementedError("开发时和机器上部署的路径不同请注意选择pixel_model_path、blk_model_path后删除本行") pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
# pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
# pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径 # pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径
# blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径 blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径
# blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径 # blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径
spec_size_threshold = 3 spec_size_threshold = 3
s_threshold_a = 124 # s_a的最高允许值
s_threshold_b = 124 # s_b的最高允许值
# rgb模型参数 # rgb模型参数
raise NotImplementedError("开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行") rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
# rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
# rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径 # rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径
# rgb_background_model_path = r"weights/background_dt_2022-08-22_22-15.model" # 开发时的路径 rgb_background_model_path = r"weights/background_dt_2023-12-26_20-39.model" # 开发时的路径
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径 # rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
threshold_low, threshold_high = 10, 230 threshold_low, threshold_high = 10, 230
threshold_s = 190 # 饱和度的最高允许值 threshold_s = 190 # 饱和度的最高允许值
rgb_size_threshold = 4 # rgb的尺寸限制 threshold_a = 125 # a的最高允许值
# ai_path = 'weights/best0827.pt' # 开发时的路径 threshold_b = 126 # b的最高允许值
rgb_size_threshold = 6 # rgb的尺寸限制
lab_size_threshold = 6 # lab的尺寸限制
ai_path = 'weights/best1227.pt' # 开发时的路径
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径 # ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
ai_conf_threshold = 0.6 ai_conf_threshold = 0.8
# mask parameter # mask parameter
target_size = (1024, 1024) # (Width, Height) of mask target_size = (1024, 1024) # (Width, Height) of mask
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质 valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1 valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5)
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
max_time_spent = 200 max_time_spent = 200
# save part # save part

View File

@ -48,7 +48,8 @@ class SugarDetect(object):
img = letterbox(img, (imgsz, imgsz), stride=stride)[0] img = letterbox(img, (imgsz, imgsz), stride=stride)[0]
# Convert # Convert
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 # img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
img = img.transpose(2, 0, 1) # to 3x416x416
img = np.ascontiguousarray(img) img = np.ascontiguousarray(img)
# Preprocess # Preprocess

View File

@ -124,6 +124,7 @@ def main(only_spec=False, only_color=False, if_merge=False, interval_time=None,
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8) mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
masks = [mask_spec, mask_rgb] masks = [mask_spec, mask_rgb]
# 进行多个喷阀的合并 # 进行多个喷阀的合并
masks = [utils_customized.shield_valve(mask, left_shield=10, right_shield=10) for mask in masks]
masks = [utils_customized.valve_expend(mask) for mask in masks] masks = [utils_customized.valve_expend(mask) for mask in masks]
mask_nums = sum([np.sum(np.sum(mask)) for mask in masks]) mask_nums = sum([np.sum(np.sum(mask)) for mask in masks])
log_time_count += 1 log_time_count += 1

View File

@ -15,6 +15,7 @@ from scipy.ndimage import binary_dilation
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
import time
from config import Config from config import Config
from detector import SugarDetect from detector import SugarDetect
@ -328,6 +329,18 @@ class RgbDetector(Detector):
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold) mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0])) mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
mask_rgb = mask_ai | mask_rgb mask_rgb = mask_ai | mask_rgb
# # 测试时间
# start = time.time()
# 转换为lab提取a通道识别绿色杂质
lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a
# 转换为lab提取b通道识别蓝色杂质
lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b
lab_predict_result = lab_a | lab_b
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
mask_rgb = mask_rgb.astype(np.uint8) | mask_lab.astype(np.uint8)
# # 测试时间
# end = time.time()
# print("lab time: ", end - start)
return mask_rgb return mask_rgb
def load(self, tobacco_model_path, background_model_path): def load(self, tobacco_model_path, background_model_path):
@ -357,9 +370,19 @@ class SpecDetector(Detector):
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1) pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
blk_predict_result = self.blk_predict(data=img_data) blk_predict_result = self.blk_predict(data=img_data)
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8) mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
spec_cv = np.clip(img_data[..., [21, 3, 0]], a_min=0, a_max=1) * 255
spec_cv = spec_cv.astype(np.uint8)
# spec转lab 提取a通道识别绿色杂质
lab_a = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 1] < Config.s_threshold_a
# spec转lab 提取b通道识别蓝色杂质
lab_b = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 2] < Config.s_threshold_b
lab_predict_result = lab_a | lab_b
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
if save_part: if save_part:
self.spare_part = mask[-(Config.blk_size//2):, :] self.spare_part = mask[-(Config.blk_size//2):, :]
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part) mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
mask = mask.astype(np.uint8) | mask_lab.astype(np.uint8)
return mask return mask
def save(self, *args, **kwargs): def save(self, *args, **kwargs):
@ -456,15 +479,17 @@ class DecisionTree(DecisionTreeClassifier):
if __name__ == '__main__': if __name__ == '__main__':
data_dir = "data/dataset" import os
data_dir = os.path.join('E:\Tobacco\data', 'dataset')
color_dict = {(0, 0, 255): "yangeng"} color_dict = {(0, 0, 255): "yangeng"}
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False) dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
ground_truth = dataset['yangeng'] ground_truth = dataset['yangeng']
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model') detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model')
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]]) # x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
boundary = np.array([0, 0, 0, 255, 255, 255]) boundary = np.array([0, 0, 0, 255, 255, 255])
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000) # detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth) detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3,
outside_alpha=0.01)
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat') temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
x, y = temp['x'], temp['y'] x, y = temp['x'], temp['y']
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]} dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}

Binary file not shown.

47
tests/test_rgb.py Normal file
View File

@ -0,0 +1,47 @@
import os
import numpy as np
from config import Config
from models import Detector, AnonymousColorDetector, RgbDetector
import cv2
# 测试单张图片使用RGB进行预测的效果
# # 测试时间
# import time
# start_time = time.time()
# 读取图片
file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp"
img = cv2.imread(file_path)[..., ::-1]
print("img.shape:", img.shape)
# 初始化和加载色彩模型
print('Initializing color model...')
rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model',
background_model_path=r"../weights/background_dt_2022-08-22_22-15.model",
ai_path='../weights/best0827.pt')
_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40)
print('Color model loaded.')
# 预测单张图片
print('Predicting...')
mask_rgb = rgb_detector.predict(img).astype(np.uint8)
# # 测试时间
# end_time = time.time()
# print("time cost:", end_time - start_time)
# 使用matplotlib展示两个图片的对比
import matplotlib.pyplot as plt
# 切换matplotlib的后端为qt否则会报错
plt.switch_backend('qt5agg')
fig, ax = plt.subplots(1, 2)
ax[0].imshow(img)
ax[1].matshow(mask_rgb)
plt.show()

View File

@ -114,12 +114,28 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
:return: None :return: None
""" """
# 观察色彩分布情况 # 观察色彩分布情况
if "alpha" not in kwargs.keys():
kwargs["alpha"] = 0.1
if 'inside_alpha' in kwargs.keys():
inside_alpha = kwargs['inside_alpha']
else:
inside_alpha = kwargs["alpha"]
if 'outside_alpha' in kwargs.keys():
outside_alpha = kwargs['outside_alpha']
else:
outside_alpha = kwargs["alpha"]
fig = plt.figure() fig = plt.figure()
if is_3d: if is_3d:
ax = fig.add_subplot(projection='3d') ax = fig.add_subplot(projection='3d')
else: else:
ax = fig.add_subplot() ax = fig.add_subplot()
for label, data in dataset.items(): for label, data in dataset.items():
if label == 'Inside':
alpha = inside_alpha
elif label == 'Outside':
alpha = outside_alpha
else:
alpha = kwargs["alpha"]
if class_max_num is not None: if class_max_num is not None:
assert isinstance(class_max_num, int) assert isinstance(class_max_num, int)
if data.shape[0] > class_max_num: if data.shape[0] > class_max_num:
@ -128,9 +144,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
data = data[sample_idx, :] data = data[sample_idx, :]
l, a, b = [data[:, i] for i in range(3)] l, a, b = [data[:, i] for i in range(3)]
if is_3d: if is_3d:
ax.scatter(a, b, l, label=label, alpha=0.1) ax.scatter(a, b, l, label=label, alpha=alpha)
else: else:
ax.scatter(a, b, label=label, alpha=0.1) ax.scatter(a, b, label=label, alpha=alpha)
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \ x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
[255, 0, 255, 0, 255, 0] [255, 0, 255, 0, 255, 0]
ax.set_xlim(x_min, x_max) ax.set_xlim(x_min, x_max)
@ -178,6 +194,14 @@ def valve_expend(img: np.ndarray) -> np.ndarray:
return img return img
def shield_valve(mask: np.ndarray, left_shield: int = -1, right_shield: int = -1) -> np.asarray:
if (left_shield < mask.shape[1]) & (left_shield > 0):
mask[:, :left_shield] = 0
if (right_shield < mask.shape[1]) & (right_shield > 0):
mask[:, -right_shield:] = 0
return mask
def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray: def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
""" """
用于限制阀门同时开启个数的函数 用于限制阀门同时开启个数的函数
@ -282,8 +306,8 @@ def valve_log(log_path: pathlib.Path, valve_num: [int, str]):
将喷阀的开启次数记录到文件log_path当中 将喷阀的开启次数记录到文件log_path当中
""" """
valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.' valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.'
with open(log_path, "w") as f: with open(log_path, "a") as f:
f.write(str(valve_str)) f.write(str(valve_str) + "\n")
if __name__ == '__main__': if __name__ == '__main__':