mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
fix:
添加了lab颜色空间识别绿色杂质、蓝色杂质,添加a、b的阈值为127,尺寸限制改为6。 优化了lab色彩空间内绘制3维数据分布情况。 喷阀横向膨胀尺寸改为5。 添加test_rgb测试单张图片
This commit is contained in:
parent
f2c614dbcf
commit
c4191be192
@ -33,7 +33,10 @@ class Config:
|
||||
# rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径
|
||||
threshold_low, threshold_high = 10, 230
|
||||
threshold_s = 190 # 饱和度的最高允许值
|
||||
rgb_size_threshold = 4 # rgb的尺寸限制
|
||||
threshold_a = 127 # a的最高允许值
|
||||
threshold_b = 127 # b的最高允许值
|
||||
rgb_size_threshold = 6 # rgb的尺寸限制
|
||||
lab_size_threshold = 6 # lab的尺寸限制
|
||||
ai_path = 'weights/best0827.pt' # 开发时的路径
|
||||
# ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径
|
||||
ai_conf_threshold = 0.6
|
||||
@ -41,7 +44,7 @@ class Config:
|
||||
# mask parameter
|
||||
target_size = (1024, 1024) # (Width, Height) of mask
|
||||
valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质
|
||||
valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1
|
||||
valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5)
|
||||
max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A
|
||||
max_time_spent = 200
|
||||
# save part
|
||||
|
||||
@ -15,6 +15,7 @@ from scipy.ndimage import binary_dilation
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
import time
|
||||
|
||||
from config import Config
|
||||
from detector import SugarDetect
|
||||
@ -328,6 +329,18 @@ class RgbDetector(Detector):
|
||||
mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold)
|
||||
mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0]))
|
||||
mask_rgb = mask_ai | mask_rgb
|
||||
# # 测试时间
|
||||
# start = time.time()
|
||||
# 转换为lab,提取a通道,识别绿色杂质
|
||||
lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a
|
||||
# 转换为lab,提取b通道,识别蓝色杂质
|
||||
lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b
|
||||
lab_predict_result = lab_a | lab_b
|
||||
mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold)
|
||||
mask_rgb = mask_rgb | mask_lab
|
||||
# # 测试时间
|
||||
# end = time.time()
|
||||
# print("lab time: ", end - start)
|
||||
return mask_rgb
|
||||
|
||||
def load(self, tobacco_model_path, background_model_path):
|
||||
@ -456,15 +469,17 @@ class DecisionTree(DecisionTreeClassifier):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_dir = "data/dataset"
|
||||
import os
|
||||
data_dir = os.path.join('E:\Tobacco\data', 'dataset')
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
ground_truth = dataset['yangeng']
|
||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
||||
detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model')
|
||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3,
|
||||
outside_alpha=0.01)
|
||||
temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||
x, y = temp['x'], temp['y']
|
||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||
|
||||
47
tests/test_rgb.py
Normal file
47
tests/test_rgb.py
Normal file
@ -0,0 +1,47 @@
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from config import Config
|
||||
from models import Detector, AnonymousColorDetector, RgbDetector
|
||||
import cv2
|
||||
|
||||
# 测试单张图片使用RGB进行预测的效果
|
||||
|
||||
# # 测试时间
|
||||
# import time
|
||||
# start_time = time.time()
|
||||
# 读取图片
|
||||
file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp"
|
||||
img = cv2.imread(file_path)[..., ::-1]
|
||||
print("img.shape:", img.shape)
|
||||
|
||||
# 初始化和加载色彩模型
|
||||
print('Initializing color model...')
|
||||
rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model',
|
||||
background_model_path=r"../weights/background_dt_2022-08-22_22-15.model",
|
||||
ai_path='../weights/best0827.pt')
|
||||
_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40)
|
||||
print('Color model loaded.')
|
||||
|
||||
# 预测单张图片
|
||||
print('Predicting...')
|
||||
mask_rgb = rgb_detector.predict(img).astype(np.uint8)
|
||||
|
||||
# # 测试时间
|
||||
# end_time = time.time()
|
||||
# print("time cost:", end_time - start_time)
|
||||
|
||||
# 使用matplotlib展示两个图片的对比
|
||||
import matplotlib.pyplot as plt
|
||||
# 切换matplotlib的后端为qt,否则会报错
|
||||
plt.switch_backend('qt5agg')
|
||||
|
||||
fig, ax = plt.subplots(1, 2)
|
||||
ax[0].imshow(img)
|
||||
ax[1].matshow(mask_rgb)
|
||||
plt.show()
|
||||
|
||||
|
||||
|
||||
|
||||
@ -114,12 +114,28 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
:return: None
|
||||
"""
|
||||
# 观察色彩分布情况
|
||||
if "alpha" not in kwargs.keys():
|
||||
kwargs["alpha"] = 0.1
|
||||
if 'inside_alpha' in kwargs.keys():
|
||||
inside_alpha = kwargs['inside_alpha']
|
||||
else:
|
||||
inside_alpha = kwargs["alpha"]
|
||||
if 'outside_alpha' in kwargs.keys():
|
||||
outside_alpha = kwargs['outside_alpha']
|
||||
else:
|
||||
outside_alpha = kwargs["alpha"]
|
||||
fig = plt.figure()
|
||||
if is_3d:
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
else:
|
||||
ax = fig.add_subplot()
|
||||
for label, data in dataset.items():
|
||||
if label == 'Inside':
|
||||
alpha = inside_alpha
|
||||
elif label == 'Outside':
|
||||
alpha = outside_alpha
|
||||
else:
|
||||
alpha = kwargs["alpha"]
|
||||
if class_max_num is not None:
|
||||
assert isinstance(class_max_num, int)
|
||||
if data.shape[0] > class_max_num:
|
||||
@ -128,9 +144,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac
|
||||
data = data[sample_idx, :]
|
||||
l, a, b = [data[:, i] for i in range(3)]
|
||||
if is_3d:
|
||||
ax.scatter(a, b, l, label=label, alpha=0.1)
|
||||
ax.scatter(a, b, l, label=label, alpha=alpha)
|
||||
else:
|
||||
ax.scatter(a, b, label=label, alpha=0.1)
|
||||
ax.scatter(a, b, label=label, alpha=alpha)
|
||||
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
||||
[255, 0, 255, 0, 255, 0]
|
||||
ax.set_xlim(x_min, x_max)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user