mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b89dff026 | ||
|
|
c391cfa7d9 | ||
|
|
192bb8db99 | ||
|
|
c4191be192 | ||
|
|
f2c614dbcf | ||
|
|
2af3d8091f | ||
|
|
edd3b0f9c8 |
@ -12,24 +12,8 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 7,
|
||||||
"outputs": [
|
"outputs": [],
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stdout",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"Training env\n"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import scipy\n",
|
"import scipy\n",
|
||||||
@ -46,7 +30,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 8,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
"train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
||||||
@ -54,8 +38,8 @@
|
|||||||
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
|
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
|
||||||
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||||
"# color_dict = {(255, 0, 0): 'beijing'}\n",
|
"color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||||
"# color_dict = {(0, 255, 0): \"zibian\"}\n",
|
"# color_dict = {(0, 255, 0): \"zibian\"}\n",
|
||||||
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
|
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
|
||||||
"show_samples = False # 是否展示样本\n",
|
"show_samples = False # 是否展示样本\n",
|
||||||
@ -86,7 +70,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 9,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
||||||
@ -169,7 +153,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"23961it [00:46, 519.87it/s] "
|
"152147it [42:48, 59.24it/s] \n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -178,19 +162,12 @@
|
|||||||
"text": [
|
"text": [
|
||||||
" precision recall f1-score support\n",
|
" precision recall f1-score support\n",
|
||||||
"\n",
|
"\n",
|
||||||
" 0 1.00 1.00 1.00 7188\n",
|
" 0 1.00 1.00 1.00 45644\n",
|
||||||
" 1 1.00 1.00 1.00 5991\n",
|
" 1 1.00 1.00 1.00 38037\n",
|
||||||
"\n",
|
"\n",
|
||||||
" accuracy 1.00 13179\n",
|
" accuracy 1.00 83681\n",
|
||||||
" macro avg 1.00 1.00 1.00 13179\n",
|
" macro avg 1.00 1.00 1.00 83681\n",
|
||||||
"weighted avg 1.00 1.00 1.00 13179\n",
|
"weighted avg 1.00 1.00 1.00 83681\n",
|
||||||
"\n"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"name": "stderr",
|
|
||||||
"output_type": "stream",
|
|
||||||
"text": [
|
|
||||||
"\n"
|
"\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
@ -15,14 +15,20 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 1,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stderr",
|
"ename": "NotImplementedError",
|
||||||
"output_type": "stream",
|
"evalue": "开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
|
||||||
"text": [
|
"output_type": "error",
|
||||||
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
"traceback": [
|
||||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||||||
|
"\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)",
|
||||||
|
"Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
|
||||||
|
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n",
|
||||||
|
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
|
||||||
|
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
|
||||||
|
"\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
|||||||
24
config.py
24
config.py
@ -20,30 +20,34 @@ class Config:
|
|||||||
|
|
||||||
# 光谱模型参数
|
# 光谱模型参数
|
||||||
blk_size = 4 # 必须是2的倍数,不然会出错
|
blk_size = 4 # 必须是2的倍数,不然会出错
|
||||||
raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择pixel_model_path、blk_model_path后删除本行")
|
pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
|
||||||
# pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
|
|
||||||
# pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径
|
# pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径
|
||||||
# blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径
|
blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径
|
||||||
# blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径
|
# blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径
|
||||||
spec_size_threshold = 3
|
spec_size_threshold = 3
|
||||||
|
|
||||||
|
s_threshold_a = 124 # s_a的最高允许值
|
||||||
|
s_threshold_b = 124 # s_b的最高允许值
|
||||||
|
|
||||||
# rgb模型参数
|
# rgb模型参数
|
||||||
raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行")
|
rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
|
||||||
# rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
|
|
||||||
# rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径
|
# rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径
|
||||||
# rgb_background_model_path = r"weights/background_dt_2022-08-22_22-15.model" # 开发时的路径
|
rgb_background_model_path = r"weights/background_dt_2023-12-26_20-39.model" # 开发时的路径
|
||||||
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
||||||
threshold_low, threshold_high = 10, 230
|
threshold_low, threshold_high = 10, 230
|
||||||
threshold_s = 190 # 饱和度的最高允许值
|
threshold_s = 190 # 饱和度的最高允许值
|
||||||
rgb_size_threshold = 4 # rgb的尺寸限制
|
threshold_a = 125 # a的最高允许值
|
||||||
# ai_path = 'weights/best0827.pt' # 开发时的路径
|
threshold_b = 126 # b的最高允许值
|
||||||
|
rgb_size_threshold = 6 # rgb的尺寸限制
|
||||||
|
lab_size_threshold = 6 # lab的尺寸限制
|
||||||
|
ai_path = 'weights/best1227.pt' # 开发时的路径
|
||||||
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
||||||
ai_conf_threshold = 0.6
|
ai_conf_threshold = 0.8
|
||||||
|
|
||||||
# mask parameter
|
# mask parameter
|
||||||
target_size = (1024, 1024) # (Width, Height) of mask
|
target_size = (1024, 1024) # (Width, Height) of mask
|
||||||
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
||||||
valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1
|
valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5)
|
||||||
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
||||||
max_time_spent = 200
|
max_time_spent = 200
|
||||||
# save part
|
# save part
|
||||||
|
|||||||
@ -48,7 +48,8 @@ class SugarDetect(object):
|
|||||||
img = letterbox(img, (imgsz, imgsz), stride=stride)[0]
|
img = letterbox(img, (imgsz, imgsz), stride=stride)[0]
|
||||||
|
|
||||||
# Convert
|
# Convert
|
||||||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
# img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||||
|
img = img.transpose(2, 0, 1) # to 3x416x416
|
||||||
img = np.ascontiguousarray(img)
|
img = np.ascontiguousarray(img)
|
||||||
|
|
||||||
# Preprocess
|
# Preprocess
|
||||||
|
|||||||
1
main.py
1
main.py
@ -124,6 +124,7 @@ def main(only_spec=False, only_color=False, if_merge=False, interval_time=None,
|
|||||||
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
|
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
|
||||||
masks = [mask_spec, mask_rgb]
|
masks = [mask_spec, mask_rgb]
|
||||||
# 进行多个喷阀的合并
|
# 进行多个喷阀的合并
|
||||||
|
masks = [utils_customized.shield_valve(mask, left_shield=10, right_shield=10) for mask in masks]
|
||||||
masks = [utils_customized.valve_expend(mask) for mask in masks]
|
masks = [utils_customized.valve_expend(mask) for mask in masks]
|
||||||
mask_nums = sum([np.sum(np.sum(mask)) for mask in masks])
|
mask_nums = sum([np.sum(np.sum(mask)) for mask in masks])
|
||||||
log_time_count += 1
|
log_time_count += 1
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from scipy.ndimage import binary_dilation
|
|||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
import time
|
||||||
|
|
||||||
from config import Config
|
from config import Config
|
||||||
from detector import SugarDetect
|
from detector import SugarDetect
|
||||||
@ -328,6 +329,18 @@ class RgbDetector(Detector):
|
|||||||
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
||||||
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
||||||
mask_rgb = mask_ai | mask_rgb
|
mask_rgb = mask_ai | mask_rgb
|
||||||
|
# # 测试时间
|
||||||
|
# start = time.time()
|
||||||
|
# 转换为lab,提取a通道,识别绿色杂质
|
||||||
|
lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a
|
||||||
|
# 转换为lab,提取b通道,识别蓝色杂质
|
||||||
|
lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b
|
||||||
|
lab_predict_result = lab_a | lab_b
|
||||||
|
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||||
|
mask_rgb = mask_rgb.astype(np.uint8) | mask_lab.astype(np.uint8)
|
||||||
|
# # 测试时间
|
||||||
|
# end = time.time()
|
||||||
|
# print("lab time: ", end - start)
|
||||||
return mask_rgb
|
return mask_rgb
|
||||||
|
|
||||||
def load(self, tobacco_model_path, background_model_path):
|
def load(self, tobacco_model_path, background_model_path):
|
||||||
@ -357,9 +370,19 @@ class SpecDetector(Detector):
|
|||||||
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||||
blk_predict_result = self.blk_predict(data=img_data)
|
blk_predict_result = self.blk_predict(data=img_data)
|
||||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||||
|
spec_cv = np.clip(img_data[..., [21, 3, 0]], a_min=0, a_max=1) * 255
|
||||||
|
spec_cv = spec_cv.astype(np.uint8)
|
||||||
|
# spec转lab 提取a通道,识别绿色杂质
|
||||||
|
lab_a = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 1] < Config.s_threshold_a
|
||||||
|
# spec转lab 提取b通道,识别蓝色杂质
|
||||||
|
lab_b = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 2] < Config.s_threshold_b
|
||||||
|
lab_predict_result = lab_a | lab_b
|
||||||
|
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||||
|
|
||||||
if save_part:
|
if save_part:
|
||||||
self.spare_part = mask[-(Config.blk_size//2):, :]
|
self.spare_part = mask[-(Config.blk_size//2):, :]
|
||||||
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
|
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
|
||||||
|
mask = mask.astype(np.uint8) | mask_lab.astype(np.uint8)
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
def save(self, *args, **kwargs):
|
def save(self, *args, **kwargs):
|
||||||
@ -456,15 +479,17 @@ class DecisionTree(DecisionTreeClassifier):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data_dir = "data/dataset"
|
import os
|
||||||
|
data_dir = os.path.join('E:\Tobacco\data', 'dataset')
|
||||||
color_dict = {(0, 0, 255): "yangeng"}
|
color_dict = {(0, 0, 255): "yangeng"}
|
||||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||||
ground_truth = dataset['yangeng']
|
ground_truth = dataset['yangeng']
|
||||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model')
|
||||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||||
boundary = np.array([0, 0, 0, 255, 255, 255])
|
boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||||
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3,
|
||||||
|
outside_alpha=0.01)
|
||||||
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||||
x, y = temp['x'], temp['y']
|
x, y = temp['x'], temp['y']
|
||||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||||
|
|||||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
47
tests/test_rgb.py
Normal file
47
tests/test_rgb.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import Config
|
||||||
|
from models import Detector, AnonymousColorDetector, RgbDetector
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# 测试单张图片使用RGB进行预测的效果
|
||||||
|
|
||||||
|
# # 测试时间
|
||||||
|
# import time
|
||||||
|
# start_time = time.time()
|
||||||
|
# 读取图片
|
||||||
|
file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp"
|
||||||
|
img = cv2.imread(file_path)[..., ::-1]
|
||||||
|
print("img.shape:", img.shape)
|
||||||
|
|
||||||
|
# 初始化和加载色彩模型
|
||||||
|
print('Initializing color model...')
|
||||||
|
rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model',
|
||||||
|
background_model_path=r"../weights/background_dt_2022-08-22_22-15.model",
|
||||||
|
ai_path='../weights/best0827.pt')
|
||||||
|
_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40)
|
||||||
|
print('Color model loaded.')
|
||||||
|
|
||||||
|
# 预测单张图片
|
||||||
|
print('Predicting...')
|
||||||
|
mask_rgb = rgb_detector.predict(img).astype(np.uint8)
|
||||||
|
|
||||||
|
# # 测试时间
|
||||||
|
# end_time = time.time()
|
||||||
|
# print("time cost:", end_time - start_time)
|
||||||
|
|
||||||
|
# 使用matplotlib展示两个图片的对比
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
# 切换matplotlib的后端为qt,否则会报错
|
||||||
|
plt.switch_backend('qt5agg')
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(1, 2)
|
||||||
|
ax[0].imshow(img)
|
||||||
|
ax[1].matshow(mask_rgb)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -114,12 +114,28 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
# 观察色彩分布情况
|
# 观察色彩分布情况
|
||||||
|
if "alpha" not in kwargs.keys():
|
||||||
|
kwargs["alpha"] = 0.1
|
||||||
|
if 'inside_alpha' in kwargs.keys():
|
||||||
|
inside_alpha = kwargs['inside_alpha']
|
||||||
|
else:
|
||||||
|
inside_alpha = kwargs["alpha"]
|
||||||
|
if 'outside_alpha' in kwargs.keys():
|
||||||
|
outside_alpha = kwargs['outside_alpha']
|
||||||
|
else:
|
||||||
|
outside_alpha = kwargs["alpha"]
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
if is_3d:
|
if is_3d:
|
||||||
ax = fig.add_subplot(projection='3d')
|
ax = fig.add_subplot(projection='3d')
|
||||||
else:
|
else:
|
||||||
ax = fig.add_subplot()
|
ax = fig.add_subplot()
|
||||||
for label, data in dataset.items():
|
for label, data in dataset.items():
|
||||||
|
if label == 'Inside':
|
||||||
|
alpha = inside_alpha
|
||||||
|
elif label == 'Outside':
|
||||||
|
alpha = outside_alpha
|
||||||
|
else:
|
||||||
|
alpha = kwargs["alpha"]
|
||||||
if class_max_num is not None:
|
if class_max_num is not None:
|
||||||
assert isinstance(class_max_num, int)
|
assert isinstance(class_max_num, int)
|
||||||
if data.shape[0] > class_max_num:
|
if data.shape[0] > class_max_num:
|
||||||
@ -128,9 +144,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
|||||||
data = data[sample_idx, :]
|
data = data[sample_idx, :]
|
||||||
l, a, b = [data[:, i] for i in range(3)]
|
l, a, b = [data[:, i] for i in range(3)]
|
||||||
if is_3d:
|
if is_3d:
|
||||||
ax.scatter(a, b, l, label=label, alpha=0.1)
|
ax.scatter(a, b, l, label=label, alpha=alpha)
|
||||||
else:
|
else:
|
||||||
ax.scatter(a, b, label=label, alpha=0.1)
|
ax.scatter(a, b, label=label, alpha=alpha)
|
||||||
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
||||||
[255, 0, 255, 0, 255, 0]
|
[255, 0, 255, 0, 255, 0]
|
||||||
ax.set_xlim(x_min, x_max)
|
ax.set_xlim(x_min, x_max)
|
||||||
@ -178,6 +194,14 @@ def valve_expend(img: np.ndarray) -> np.ndarray:
|
|||||||
return img
|
return img
|
||||||
|
|
||||||
|
|
||||||
|
def shield_valve(mask: np.ndarray, left_shield: int = -1, right_shield: int = -1) -> np.asarray:
|
||||||
|
if (left_shield < mask.shape[1]) & (left_shield > 0):
|
||||||
|
mask[:, :left_shield] = 0
|
||||||
|
if (right_shield < mask.shape[1]) & (right_shield > 0):
|
||||||
|
mask[:, -right_shield:] = 0
|
||||||
|
return mask
|
||||||
|
|
||||||
|
|
||||||
def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
|
def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
用于限制阀门同时开启个数的函数
|
用于限制阀门同时开启个数的函数
|
||||||
@ -282,8 +306,8 @@ def valve_log(log_path: pathlib.Path, valve_num: [int, str]):
|
|||||||
将喷阀的开启次数记录到文件log_path当中。
|
将喷阀的开启次数记录到文件log_path当中。
|
||||||
"""
|
"""
|
||||||
valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.'
|
valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.'
|
||||||
with open(log_path, "w") as f:
|
with open(log_path, "a") as f:
|
||||||
f.write(str(valve_str))
|
f.write(str(valve_str) + "\n")
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user