diff --git a/config.py b/config.py index ce5c46a..cfc6424 100644 --- a/config.py +++ b/config.py @@ -33,7 +33,10 @@ class Config: # rgb_background_model_path = r"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model" # 机器上部署的路径 threshold_low, threshold_high = 10, 230 threshold_s = 190 # 饱和度的最高允许值 - rgb_size_threshold = 4 # rgb的尺寸限制 + threshold_a = 127 # a的最高允许值 + threshold_b = 127 # b的最高允许值 + rgb_size_threshold = 6 # rgb的尺寸限制 + lab_size_threshold = 6 # lab的尺寸限制 ai_path = 'weights/best0827.pt' # 开发时的路径 # ai_path = '/home/dt/tobacco-color/weights/best0827.pt' # 机器上部署的路径 ai_conf_threshold = 0.6 @@ -41,7 +44,7 @@ class Config: # mask parameter target_size = (1024, 1024) # (Width, Height) of mask valve_merge_size = 2 # 每两个喷阀当中有任意一个出现杂质则认为都是杂质 - valve_horizontal_padding = 3 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1 + valve_horizontal_padding = 5 # 喷阀横向膨胀的尺寸,应该是奇数,3时表示左右各膨胀1,5时表示左右各膨胀2(23.9.20老倪要求改为5) max_open_valve_limit = 25 # 最大同时开启喷阀限制,按照电流计算,当前的喷阀可以开启的喷阀 600W的电源 / 12V电源 = 50A, 一个阀门1A max_time_spent = 200 # save part diff --git a/models/__init__.py b/models/__init__.py index 0a53b49..b52b2e5 100755 --- a/models/__init__.py +++ b/models/__init__.py @@ -15,6 +15,7 @@ from scipy.ndimage import binary_dilation from sklearn.tree import DecisionTreeClassifier from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split +import time from config import Config from detector import SugarDetect @@ -328,6 +329,18 @@ class RgbDetector(Detector): mask_ai = self.ai_detector.detect(rgb_data, Config.ai_conf_threshold) mask_ai = cv2.resize(mask_ai, dsize=(mask_rgb.shape[1], mask_rgb.shape[0])) mask_rgb = mask_ai | mask_rgb + # # 测试时间 + # start = time.time() + # 转换为lab,提取a通道,识别绿色杂质 + lab_a = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 1] < Config.threshold_a + # 转换为lab,提取b通道,识别蓝色杂质 + lab_b = cv2.cvtColor(rgb_data, cv2.COLOR_RGB2LAB)[..., 2] < Config.threshold_b + lab_predict_result = lab_a | lab_b + mask_lab = size_threshold(lab_predict_result, Config.blk_size, Config.lab_size_threshold) + mask_rgb = mask_rgb | mask_lab + # # 测试时间 + # end = time.time() + # print("lab time: ", end - start) return mask_rgb def load(self, tobacco_model_path, background_model_path): @@ -456,15 +469,17 @@ class DecisionTree(DecisionTreeClassifier): if __name__ == '__main__': - data_dir = "data/dataset" + import os + data_dir = os.path.join('E:\Tobacco\data', 'dataset') color_dict = {(0, 0, 255): "yangeng"} dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False) ground_truth = dataset['yangeng'] - detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model') + detector = AnonymousColorDetector(file_path=r'E:\Tobacco\weights\tobacco_dt_2022-08-05_10-38.model') # x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]]) boundary = np.array([0, 0, 0, 255, 255, 255]) # detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000) - detector.visualize(boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth) + detector.visualize(boundary, sample_size=500000, class_max_num=5000, ground_truth=ground_truth, inside_alpha=0.3, + outside_alpha=0.01) temp = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat') x, y = temp['x'], temp['y'] dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]} diff --git a/tests/test_rgb.py b/tests/test_rgb.py new file mode 100644 index 0000000..7edf7a0 --- /dev/null +++ b/tests/test_rgb.py @@ -0,0 +1,47 @@ +import os + +import numpy as np + +from config import Config +from models import Detector, AnonymousColorDetector, RgbDetector +import cv2 + +# 测试单张图片使用RGB进行预测的效果 + +# # 测试时间 +# import time +# start_time = time.time() +# 读取图片 +file_path = r"E:\Tobacco\data\testImgs\Image_2022_0726_1413_46_400-001165.bmp" +img = cv2.imread(file_path)[..., ::-1] +print("img.shape:", img.shape) + +# 初始化和加载色彩模型 +print('Initializing color model...') +rgb_detector = RgbDetector(tobacco_model_path=r'../weights/tobacco_dt_2022-08-27_14-43.model', + background_model_path=r"../weights/background_dt_2022-08-22_22-15.model", + ai_path='../weights/best0827.pt') +_ = rgb_detector.predict(np.ones((Config.nRgbRows, Config.nRgbCols, Config.nRgbBands), dtype=np.uint8) * 40) +print('Color model loaded.') + +# 预测单张图片 +print('Predicting...') +mask_rgb = rgb_detector.predict(img).astype(np.uint8) + +# # 测试时间 +# end_time = time.time() +# print("time cost:", end_time - start_time) + +# 使用matplotlib展示两个图片的对比 +import matplotlib.pyplot as plt +# 切换matplotlib的后端为qt,否则会报错 +plt.switch_backend('qt5agg') + +fig, ax = plt.subplots(1, 2) +ax[0].imshow(img) +ax[1].matshow(mask_rgb) +plt.show() + + + + diff --git a/utils/__init__.py b/utils/__init__.py index be7ebc3..d6413d0 100755 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -114,12 +114,28 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac :return: None """ # 观察色彩分布情况 + if "alpha" not in kwargs.keys(): + kwargs["alpha"] = 0.1 + if 'inside_alpha' in kwargs.keys(): + inside_alpha = kwargs['inside_alpha'] + else: + inside_alpha = kwargs["alpha"] + if 'outside_alpha' in kwargs.keys(): + outside_alpha = kwargs['outside_alpha'] + else: + outside_alpha = kwargs["alpha"] fig = plt.figure() if is_3d: ax = fig.add_subplot(projection='3d') else: ax = fig.add_subplot() for label, data in dataset.items(): + if label == 'Inside': + alpha = inside_alpha + elif label == 'Outside': + alpha = outside_alpha + else: + alpha = kwargs["alpha"] if class_max_num is not None: assert isinstance(class_max_num, int) if data.shape[0] > class_max_num: @@ -128,9 +144,9 @@ def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_spac data = data[sample_idx, :] l, a, b = [data[:, i] for i in range(3)] if is_3d: - ax.scatter(a, b, l, label=label, alpha=0.1) + ax.scatter(a, b, l, label=label, alpha=alpha) else: - ax.scatter(a, b, label=label, alpha=0.1) + ax.scatter(a, b, label=label, alpha=alpha) x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \ [255, 0, 255, 0, 255, 0] ax.set_xlim(x_min, x_max)