mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
fix:
添加了lab颜色空间识别绿色杂质、蓝色杂质,添加a、b的阈值为127,尺寸限制改为6。 优化了lab色彩空间内绘制3维数据分布情况。 喷阀横向膨胀尺寸改为5。 添加test_rgb测试单张图片
This commit is contained in:
parent
f2c614dbcf
commit
c4191be192
@ -33,7 +33,10 @@ class Config:
|
|||||||
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
||||||
threshold_low, threshold_high = 10, 230
|
threshold_low, threshold_high = 10, 230
|
||||||
threshold_s = 190 # 饱和度的最高允许值
|
threshold_s = 190 # 饱和度的最高允许值
|
||||||
rgb_size_threshold = 4 # rgb的尺寸限制
|
threshold_a = 127 # a的最高允许值
|
||||||
|
threshold_b = 127 # b的最高允许值
|
||||||
|
rgb_size_threshold = 6 # rgb的尺寸限制
|
||||||
|
lab_size_threshold = 6 # lab的尺寸限制
|
||||||
ai_path = 'weights/best0827.pt' # 开发时的路径
|
ai_path = 'weights/best0827.pt' # 开发时的路径
|
||||||
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
||||||
ai_conf_threshold = 0.6
|
ai_conf_threshold = 0.6
|
||||||
@ -41,7 +44,7 @@ class Config:
|
|||||||
# mask parameter
|
# mask parameter
|
||||||
target_size = (1024, 1024) # (Width, Height) of mask
|
target_size = (1024, 1024) # (Width, Height) of mask
|
||||||
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
||||||
valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1
|
valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5)
|
||||||
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
||||||
max_time_spent = 200
|
max_time_spent = 200
|
||||||
# save part
|
# save part
|
||||||
|
|||||||
@ -15,6 +15,7 @@ from scipy.ndimage import binary_dilation
|
|||||||
from sklearn.tree import DecisionTreeClassifier
|
from sklearn.tree import DecisionTreeClassifier
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
import time
|
||||||
|
|
||||||
from config import Config
|
from config import Config
|
||||||
from detector import SugarDetect
|
from detector import SugarDetect
|
||||||
@ -328,6 +329,18 @@ class RgbDetector(Detector):
|
|||||||
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
||||||
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
||||||
mask_rgb = mask_ai | mask_rgb
|
mask_rgb = mask_ai | mask_rgb
|
||||||
|
# # 测试时间
|
||||||
|
# start = time.time()
|
||||||
|
# 转换为lab,提取a通道,识别绿色杂质
|
||||||
|
lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a
|
||||||
|
# 转换为lab,提取b通道,识别蓝色杂质
|
||||||
|
lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b
|
||||||
|
lab_predict_result = lab_a | lab_b
|
||||||
|
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||||
|
mask_rgb = mask_rgb | mask_lab
|
||||||
|
# # 测试时间
|
||||||
|
# end = time.time()
|
||||||
|
# print("lab time: ", end - start)
|
||||||
return mask_rgb
|
return mask_rgb
|
||||||
|
|
||||||
def load(self, tobacco_model_path, background_model_path):
|
def load(self, tobacco_model_path, background_model_path):
|
||||||
@ -456,15 +469,17 @@ class DecisionTree(DecisionTreeClassifier):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
data_dir = "data/dataset"
|
import os
|
||||||
|
data_dir = os.path.join('E:\Tobacco\data', 'dataset')
|
||||||
color_dict = {(0, 0, 255): "yangeng"}
|
color_dict = {(0, 0, 255): "yangeng"}
|
||||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||||
ground_truth = dataset['yangeng']
|
ground_truth = dataset['yangeng']
|
||||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model')
|
||||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||||
boundary = np.array([0, 0, 0, 255, 255, 255])
|
boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||||
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3,
|
||||||
|
outside_alpha=0.01)
|
||||||
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||||
x, y = temp['x'], temp['y']
|
x, y = temp['x'], temp['y']
|
||||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||||
|
|||||||
47
tests/test_rgb.py
Normal file
47
tests/test_rgb.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from config import Config
|
||||||
|
from models import Detector, AnonymousColorDetector, RgbDetector
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# 测试单张图片使用RGB进行预测的效果
|
||||||
|
|
||||||
|
# # 测试时间
|
||||||
|
# import time
|
||||||
|
# start_time = time.time()
|
||||||
|
# 读取图片
|
||||||
|
file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp"
|
||||||
|
img = cv2.imread(file_path)[..., ::-1]
|
||||||
|
print("img.shape:", img.shape)
|
||||||
|
|
||||||
|
# 初始化和加载色彩模型
|
||||||
|
print('Initializing color model...')
|
||||||
|
rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model',
|
||||||
|
background_model_path=r"../weights/background_dt_2022-08-22_22-15.model",
|
||||||
|
ai_path='../weights/best0827.pt')
|
||||||
|
_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40)
|
||||||
|
print('Color model loaded.')
|
||||||
|
|
||||||
|
# 预测单张图片
|
||||||
|
print('Predicting...')
|
||||||
|
mask_rgb = rgb_detector.predict(img).astype(np.uint8)
|
||||||
|
|
||||||
|
# # 测试时间
|
||||||
|
# end_time = time.time()
|
||||||
|
# print("time cost:", end_time - start_time)
|
||||||
|
|
||||||
|
# 使用matplotlib展示两个图片的对比
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
# 切换matplotlib的后端为qt,否则会报错
|
||||||
|
plt.switch_backend('qt5agg')
|
||||||
|
|
||||||
|
fig, ax = plt.subplots(1, 2)
|
||||||
|
ax[0].imshow(img)
|
||||||
|
ax[1].matshow(mask_rgb)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@ -114,12 +114,28 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
|||||||
:return: None
|
:return: None
|
||||||
"""
|
"""
|
||||||
# 观察色彩分布情况
|
# 观察色彩分布情况
|
||||||
|
if "alpha" not in kwargs.keys():
|
||||||
|
kwargs["alpha"] = 0.1
|
||||||
|
if 'inside_alpha' in kwargs.keys():
|
||||||
|
inside_alpha = kwargs['inside_alpha']
|
||||||
|
else:
|
||||||
|
inside_alpha = kwargs["alpha"]
|
||||||
|
if 'outside_alpha' in kwargs.keys():
|
||||||
|
outside_alpha = kwargs['outside_alpha']
|
||||||
|
else:
|
||||||
|
outside_alpha = kwargs["alpha"]
|
||||||
fig = plt.figure()
|
fig = plt.figure()
|
||||||
if is_3d:
|
if is_3d:
|
||||||
ax = fig.add_subplot(projection='3d')
|
ax = fig.add_subplot(projection='3d')
|
||||||
else:
|
else:
|
||||||
ax = fig.add_subplot()
|
ax = fig.add_subplot()
|
||||||
for label, data in dataset.items():
|
for label, data in dataset.items():
|
||||||
|
if label == 'Inside':
|
||||||
|
alpha = inside_alpha
|
||||||
|
elif label == 'Outside':
|
||||||
|
alpha = outside_alpha
|
||||||
|
else:
|
||||||
|
alpha = kwargs["alpha"]
|
||||||
if class_max_num is not None:
|
if class_max_num is not None:
|
||||||
assert isinstance(class_max_num, int)
|
assert isinstance(class_max_num, int)
|
||||||
if data.shape[0] > class_max_num:
|
if data.shape[0] > class_max_num:
|
||||||
@ -128,9 +144,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
|||||||
data = data[sample_idx, :]
|
data = data[sample_idx, :]
|
||||||
l, a, b = [data[:, i] for i in range(3)]
|
l, a, b = [data[:, i] for i in range(3)]
|
||||||
if is_3d:
|
if is_3d:
|
||||||
ax.scatter(a, b, l, label=label, alpha=0.1)
|
ax.scatter(a, b, l, label=label, alpha=alpha)
|
||||||
else:
|
else:
|
||||||
ax.scatter(a, b, label=label, alpha=0.1)
|
ax.scatter(a, b, label=label, alpha=alpha)
|
||||||
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
||||||
[255, 0, 255, 0, 255, 0]
|
[255, 0, 255, 0, 255, 0]
|
||||||
ax.set_xlim(x_min, x_max)
|
ax.set_xlim(x_min, x_max)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user