mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
Compare commits
No commits in common. "master" and "v1.0.0" have entirely different histories.
@ -12,8 +12,24 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"outputs": [],
|
||||
"execution_count": 1,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Training env\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import scipy\n",
|
||||
@ -30,7 +46,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
||||
@ -38,8 +54,8 @@
|
||||
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
||||
"\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||
"color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||
"# color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||
"# color_dict = {(0, 255, 0): \"zibian\"}\n",
|
||||
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
|
||||
"show_samples = False # 是否展示样本\n",
|
||||
@ -70,7 +86,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
||||
@ -153,7 +169,7 @@
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"152147it [42:48, 59.24it/s] \n"
|
||||
"23961it [00:46, 519.87it/s] "
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -162,12 +178,19 @@
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 1.00 1.00 1.00 45644\n",
|
||||
" 1 1.00 1.00 1.00 38037\n",
|
||||
" 0 1.00 1.00 1.00 7188\n",
|
||||
" 1 1.00 1.00 1.00 5991\n",
|
||||
"\n",
|
||||
" accuracy 1.00 83681\n",
|
||||
" macro avg 1.00 1.00 1.00 83681\n",
|
||||
"weighted avg 1.00 1.00 1.00 83681\n",
|
||||
" accuracy 1.00 13179\n",
|
||||
" macro avg 1.00 1.00 1.00 13179\n",
|
||||
"weighted avg 1.00 1.00 1.00 13179\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
|
||||
@ -15,20 +15,14 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 3,
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NotImplementedError",
|
||||
"evalue": "开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)",
|
||||
"Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
|
||||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n",
|
||||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
|
||||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
|
||||
"\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
24
config.py
24
config.py
@ -20,34 +20,30 @@ class Config:
|
||||
|
||||
# 光谱模型参数
|
||||
blk_size = 4 # 必须是2的倍数,不然会出错
|
||||
pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
|
||||
raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择pixel_model_path、blk_model_path后删除本行")
|
||||
# pixel_model_path = r"./weights/pixel_2022-08-02_15-22.model" # 开发时的路径
|
||||
# pixel_model_path = r"/home/dt/tobacco-color/weights/pixel_2022-08-02_15-22.model" # 机器上部署的路径
|
||||
blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径
|
||||
# blk_model_path = r"./weights/rf_4x4_c22_20_sen8_9.model" # 开发时的路径
|
||||
# blk_model_path = r"/home/dt/tobacco-color/weights/rf_4x4_c22_20_sen8_9.model" # 机器上部署的路径
|
||||
spec_size_threshold = 3
|
||||
|
||||
s_threshold_a = 124 # s_a的最高允许值
|
||||
s_threshold_b = 124 # s_b的最高允许值
|
||||
|
||||
# rgb模型参数
|
||||
rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
|
||||
raise NotImplementedError("开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行")
|
||||
# rgb_tobacco_model_path = r"weights/tobacco_dt_2022-08-27_14-43.model" # 开发时的路径
|
||||
# rgb_tobacco_model_path = r"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model" # 机器上部署的路径
|
||||
rgb_background_model_path = r"weights/background_dt_2023-12-26_20-39.model" # 开发时的路径
|
||||
# rgb_background_model_path = r"weights/background_dt_2022-08-22_22-15.model" # 开发时的路径
|
||||
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
||||
threshold_low, threshold_high = 10, 230
|
||||
threshold_s = 190 # 饱和度的最高允许值
|
||||
threshold_a = 125 # a的最高允许值
|
||||
threshold_b = 126 # b的最高允许值
|
||||
rgb_size_threshold = 6 # rgb的尺寸限制
|
||||
lab_size_threshold = 6 # lab的尺寸限制
|
||||
ai_path = 'weights/best1227.pt' # 开发时的路径
|
||||
rgb_size_threshold = 4 # rgb的尺寸限制
|
||||
# ai_path = 'weights/best0827.pt' # 开发时的路径
|
||||
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
||||
ai_conf_threshold = 0.8
|
||||
ai_conf_threshold = 0.6
|
||||
|
||||
# mask parameter
|
||||
target_size = (1024, 1024) # (Width, Height) of mask
|
||||
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
||||
valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5)
|
||||
valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1
|
||||
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
||||
max_time_spent = 200
|
||||
# save part
|
||||
|
||||
@ -48,8 +48,7 @@ class SugarDetect(object):
|
||||
img = letterbox(img, (imgsz, imgsz), stride=stride)[0]
|
||||
|
||||
# Convert
|
||||
# img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||
img = img.transpose(2, 0, 1) # to 3x416x416
|
||||
img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416
|
||||
img = np.ascontiguousarray(img)
|
||||
|
||||
# Preprocess
|
||||
|
||||
1
main.py
1
main.py
@ -124,7 +124,6 @@ def main(only_spec=False, only_color=False, if_merge=False, interval_time=None,
|
||||
mask_rgb = rgb_detector.predict(rgb_data).astype(np.uint8)
|
||||
masks = [mask_spec, mask_rgb]
|
||||
# 进行多个喷阀的合并
|
||||
masks = [utils_customized.shield_valve(mask, left_shield=10, right_shield=10) for mask in masks]
|
||||
masks = [utils_customized.valve_expend(mask) for mask in masks]
|
||||
mask_nums = sum([np.sum(np.sum(mask)) for mask in masks])
|
||||
log_time_count += 1
|
||||
|
||||
@ -15,7 +15,6 @@ from scipy.ndimage import binary_dilation
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
import time
|
||||
|
||||
from config import Config
|
||||
from detector import SugarDetect
|
||||
@ -329,18 +328,6 @@ class RgbDetector(Detector):
|
||||
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
||||
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
||||
mask_rgb = mask_ai | mask_rgb
|
||||
# # 测试时间
|
||||
# start = time.time()
|
||||
# 转换为lab,提取a通道,识别绿色杂质
|
||||
lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a
|
||||
# 转换为lab,提取b通道,识别蓝色杂质
|
||||
lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b
|
||||
lab_predict_result = lab_a | lab_b
|
||||
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||
mask_rgb = mask_rgb.astype(np.uint8) | mask_lab.astype(np.uint8)
|
||||
# # 测试时间
|
||||
# end = time.time()
|
||||
# print("lab time: ", end - start)
|
||||
return mask_rgb
|
||||
|
||||
def load(self, tobacco_model_path, background_model_path):
|
||||
@ -370,19 +357,9 @@ class SpecDetector(Detector):
|
||||
pixel_predict_result = self.pixel_predict_ml_dilation(data=img_data, iteration=1)
|
||||
blk_predict_result = self.blk_predict(data=img_data)
|
||||
mask = (pixel_predict_result & blk_predict_result).astype(np.uint8)
|
||||
spec_cv = np.clip(img_data[..., [21, 3, 0]], a_min=0, a_max=1) * 255
|
||||
spec_cv = spec_cv.astype(np.uint8)
|
||||
# spec转lab 提取a通道,识别绿色杂质
|
||||
lab_a = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 1] < Config.s_threshold_a
|
||||
# spec转lab 提取b通道,识别蓝色杂质
|
||||
lab_b = cv2.cvtColor(spec_cv, cv2.COLOR_RGB2LAB)[..., 2] < Config.s_threshold_b
|
||||
lab_predict_result = lab_a | lab_b
|
||||
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||
|
||||
if save_part:
|
||||
self.spare_part = mask[-(Config.blk_size//2):, :]
|
||||
mask = size_threshold(mask, Config.blk_size, Config.spec_size_threshold, self.spare_part)
|
||||
mask = mask.astype(np.uint8) | mask_lab.astype(np.uint8)
|
||||
return mask
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
@ -479,17 +456,15 @@ class DecisionTree(DecisionTreeClassifier):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import os
|
||||
data_dir = os.path.join('E:\Tobacco\data', 'dataset')
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
ground_truth = dataset['yangeng']
|
||||
detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model')
|
||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3,
|
||||
outside_alpha=0.01)
|
||||
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||
x, y = temp['x'], temp['y']
|
||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||
|
||||
BIN
requirements.txt
BIN
requirements.txt
Binary file not shown.
@ -1,47 +0,0 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import Config
|
||||
from models import Detector, AnonymousColorDetector, RgbDetector
|
||||
import cv2
|
||||
|
||||
# 测试单张图片使用RGB进行预测的效果
|
||||
|
||||
# # 测试时间
|
||||
# import time
|
||||
# start_time = time.time()
|
||||
# 读取图片
|
||||
file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp"
|
||||
img = cv2.imread(file_path)[..., ::-1]
|
||||
print("img.shape:", img.shape)
|
||||
|
||||
# 初始化和加载色彩模型
|
||||
print('Initializing color model...')
|
||||
rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model',
|
||||
background_model_path=r"../weights/background_dt_2022-08-22_22-15.model",
|
||||
ai_path='../weights/best0827.pt')
|
||||
_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40)
|
||||
print('Color model loaded.')
|
||||
|
||||
# 预测单张图片
|
||||
print('Predicting...')
|
||||
mask_rgb = rgb_detector.predict(img).astype(np.uint8)
|
||||
|
||||
# # 测试时间
|
||||
# end_time = time.time()
|
||||
# print("time cost:", end_time - start_time)
|
||||
|
||||
# 使用matplotlib展示两个图片的对比
|
||||
import matplotlib.pyplot as plt
|
||||
# 切换matplotlib的后端为qt,否则会报错
|
||||
plt.switch_backend('qt5agg')
|
||||
|
||||
fig, ax = plt.subplots(1, 2)
|
||||
ax[0].imshow(img)
|
||||
ax[1].matshow(mask_rgb)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
@ -114,28 +114,12 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
:return: None
|
||||
"""
|
||||
# 观察色彩分布情况
|
||||
if "alpha" not in kwargs.keys():
|
||||
kwargs["alpha"] = 0.1
|
||||
if 'inside_alpha' in kwargs.keys():
|
||||
inside_alpha = kwargs['inside_alpha']
|
||||
else:
|
||||
inside_alpha = kwargs["alpha"]
|
||||
if 'outside_alpha' in kwargs.keys():
|
||||
outside_alpha = kwargs['outside_alpha']
|
||||
else:
|
||||
outside_alpha = kwargs["alpha"]
|
||||
fig = plt.figure()
|
||||
if is_3d:
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
else:
|
||||
ax = fig.add_subplot()
|
||||
for label, data in dataset.items():
|
||||
if label == 'Inside':
|
||||
alpha = inside_alpha
|
||||
elif label == 'Outside':
|
||||
alpha = outside_alpha
|
||||
else:
|
||||
alpha = kwargs["alpha"]
|
||||
if class_max_num is not None:
|
||||
assert isinstance(class_max_num, int)
|
||||
if data.shape[0] > class_max_num:
|
||||
@ -144,9 +128,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
data = data[sample_idx, :]
|
||||
l, a, b = [data[:, i] for i in range(3)]
|
||||
if is_3d:
|
||||
ax.scatter(a, b, l, label=label, alpha=alpha)
|
||||
ax.scatter(a, b, l, label=label, alpha=0.1)
|
||||
else:
|
||||
ax.scatter(a, b, label=label, alpha=alpha)
|
||||
ax.scatter(a, b, label=label, alpha=0.1)
|
||||
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
||||
[255, 0, 255, 0, 255, 0]
|
||||
ax.set_xlim(x_min, x_max)
|
||||
@ -194,14 +178,6 @@ def valve_expend(img: np.ndarray) -> np.ndarray:
|
||||
return img
|
||||
|
||||
|
||||
def shield_valve(mask: np.ndarray, left_shield: int = -1, right_shield: int = -1) -> np.asarray:
|
||||
if (left_shield < mask.shape[1]) & (left_shield > 0):
|
||||
mask[:, :left_shield] = 0
|
||||
if (right_shield < mask.shape[1]) & (right_shield > 0):
|
||||
mask[:, -right_shield:] = 0
|
||||
return mask
|
||||
|
||||
|
||||
def valve_limit(mask: np.ndarray, max_valve_num: int) -> np.ndarray:
|
||||
"""
|
||||
用于限制阀门同时开启个数的函数
|
||||
@ -306,8 +282,8 @@ def valve_log(log_path: pathlib.Path, valve_num: [int, str]):
|
||||
将喷阀的开启次数记录到文件log_path当中。
|
||||
"""
|
||||
valve_str = "截至 " + datetime.now().strftime('%Y-%m-%d %H:%M:%S') + f' 喷阀使用次数: {valve_num}.'
|
||||
with open(log_path, "a") as f:
|
||||
f.write(str(valve_str) + "\n")
|
||||
with open(log_path, "w") as f:
|
||||
f.write(str(valve_str))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
Loading…
Reference in New Issue
Block a user