Added the dt (decision tree) method

FEIJINTI 2022-07-19 15:39:51 +08:00
parent 73973216f6
commit bfba067285
4 changed files with 99 additions and 28 deletions

View File

@@ -7,7 +7,8 @@
 import numpy as np
+import scipy
+from imblearn.under_sampling import RandomUnderSampler
 from models import AnonymousColorDetector
 from utils import read_labeled_img
@@ -17,8 +18,15 @@ from utils import read_labeled_img
 data_dir = "data/dataset"
-color_dict = {(0, 0, 255): "yangeng"}
+color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'}
+label_index = {"yangeng": 1, "beijing": 0}
 dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
+rus = RandomUnderSampler(random_state=0)
+x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
+    np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()
+x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
+dataset = {"inside": np.array(x_resampled)}

 # ## Model training
@@ -32,10 +40,14 @@ negative_sample_num = None  # None or a number
 world_boundary = np.array([0, 0, 0, 255, 255, 255])

 # Preprocess the data
 x = np.concatenate([v for k, v in dataset.items()], axis=0)
-negative_sample_num = int(x.shape[0] * 0.7) if negative_sample_num is None else negative_sample_num
+negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num

 model = AnonymousColorDetector()

-model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)
+model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
+          is_save_dataset=True, model_selection='dt')
+# data = scipy.io.loadmat('dataset_2022-07-19_15-07.mat')
+# x, y = data['x'], data['y'].ravel()
+# model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
 model.save()
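In plain terms, the new training path above under-samples the majority class with imblearn and then hands the balanced data to the 'dt' branch of AnonymousColorDetector.fit. Below is a minimal, self-contained sketch of that flow, using synthetic colors in place of read_labeled_img; the class names simply mirror the ones in the diff.

    # Sketch only: synthetic data stands in for the repo's labeled images.
    import numpy as np
    from imblearn.under_sampling import RandomUnderSampler
    from sklearn.tree import DecisionTreeClassifier

    rng = np.random.default_rng(0)
    x = np.vstack([rng.uniform(0, 255, (900, 3)),   # majority class ("beijing")
                   rng.uniform(0, 255, (100, 3))])  # minority class ("yangeng")
    y = np.array([0] * 900 + [1] * 100)

    # RandomUnderSampler drops majority samples until the classes are balanced.
    x_res, y_res = RandomUnderSampler(random_state=0).fit_resample(x, y)
    print(np.bincount(y_res))  # -> [100 100]

    # The 'dt' branch of fit() ultimately wraps a plain DecisionTreeClassifier.
    clf = DecisionTreeClassifier().fit(x_res, y_res)
    print(clf.predict(x_res[:5]))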

View File

@@ -10,15 +10,17 @@ import matplotlib.pyplot as plt
 import numpy as np

 from models import Detector, AnonymousColorDetector
+from utils import read_labeled_img


-def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
+def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_model=False):
     """
     Virtual image-reading test program.
     :param detector: impurity detector; must inherit from the Detector class
     :param test_img: test image, an RGB image or a path to one
     :param test_img_dir: directory of test images
+    :param test_model: whether to run the model constraint test
     :return:
     """
     if test_img is not None:
@@ -29,8 +31,10 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
     else:
         raise TypeError("test img should be np.ndarray or str")
     t1 = time.time()
-    result = detector.predict(img)
+    img = cv2.resize(img, (1024, 256))
     t2 = time.time()
+    result = 1 - detector.predict(img)
+    t3 = time.time()
     fig, axs = plt.subplots(3, 1)
     axs[0].imshow(img)
     axs[1].imshow(result)
@@ -38,10 +42,18 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
     mask_color[result > 0] = (0, 0, 255)
     result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
     axs[2].imshow(result_show)
-    plt.title(f'{(t2 - t1) * 1000:.2f} ms')
+    axs[0].set_title(
+        f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
     plt.show()
+    if test_model:
+        data_dir = "data/dataset"
+        color_dict = {(0, 0, 255): "yangeng"}
+        dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
+        ground_truth = dataset['yangeng']
+        world_boundary = np.array([0, 0, 0, 255, 255, 255])
+        detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)


 if __name__ == '__main__':
-    detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat')
-    virtual_main(detector, test_img='data/dataset/img/yangeng.bmp')
+    detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
+    virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
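Side note on the `result = 1 - detector.predict(img)` change: fit() labels the target class as 1, so inverting the prediction appears intended to make the overlay highlight the pixels that are not recognized as the target. A toy sketch of that overlay logic, with a dummy frame and prediction map in place of a real image and model:

    # Dummy data only; in the script, `pred` comes from detector.predict(img).
    import cv2
    import numpy as np

    img = np.zeros((256, 1024, 3), dtype=np.uint8)       # placeholder RGB frame
    pred = np.zeros((256, 1024), dtype=np.uint8)
    pred[:, :512] = 1                                     # left half predicted as target

    result = 1 - pred                                     # 1 now marks non-target pixels
    mask_color = np.zeros_like(img)
    mask_color[result > 0] = (0, 0, 255)                  # paint the non-target area
    result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
    print(result_show.shape, result_show.dtype)           # (256, 1024, 3) uint8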

View File

@@ -4,12 +4,17 @@
 # @File: models.py
 # @Software:PyCharm、
 import datetime
+import pickle

 import cv2
 import numpy as np
 import scipy.io
+import tqdm
+from sklearn.tree import DecisionTreeClassifier
 from sklearn.metrics import classification_report
 from sklearn.model_selection import train_test_split
+from utils import lab_scatter, read_labeled_img
+from tqdm import tqdm


 from elm import ELM
@@ -34,10 +39,12 @@ class Detector(object):
 class AnonymousColorDetector(Detector):
     def __init__(self, file_path: str = None):
         self.model = None
+        self.model_type = 'None'
         if file_path is not None:
-            self.model = ELM(model_path=file_path)
+            self.load(file_path)

-    def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float,
+    def fit(self, x: np.ndarray, world_boundary: np.ndarray = None, threshold: float = None,
+            is_generate_negative: bool = True, y: np.ndarray = None, model_selection='elm',
             negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs):
         """
         Fit the detector to the distribution of the given samples x.
@@ -45,23 +52,36 @@ class AnonymousColorDetector(Detector):
         :param x: positive samples, ndarray of shape n x feature_num
         :param world_boundary: boundary of the whole feature space: feature_num lower bounds followed by feature_num upper bounds
         :param threshold: distance threshold to the positive samples; anything farther is not treated as the target class
+        :param is_generate_negative: whether to generate negative samples
+        :param y: labels corresponding to x
+        :param model_selection: which model to use, in ['elm', 'dt']
         :param negative_sample_size: number of negative samples to generate
         :param train_size: proportion of the data used for training, float
        :param is_save_dataset: whether to save the generated dataset
         :param kwargs: extra parameters for the chosen model
         :return:
         """
+        if model_selection == 'elm':
             node_num = kwargs.get('node_num', 10)
             self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
+        elif model_selection == 'dt':
+            self.model = DecisionTreeClassifier(**kwargs)
+        else:
+            raise ValueError("Unknown model_selection, expected 'elm' or 'dt'")
+        self.model_type = model_selection
+        if is_generate_negative:
             negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
                                                               sample_size=negative_sample_size)
             data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
                              np.concatenate([np.ones(x.shape[0], dtype=int),
                                              np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
+        else:
+            data_x, data_y = x, y
         if is_save_dataset:
             path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat")
             scipy.io.savemat(path, {'x': data_x, 'y': data_y})
-        x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
+        x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True,
+                                                          stratify=data_y)
         self.model.fit(x_train, y_train)
         y_predict = self.model.predict(x_val)
         print(classification_report(y_true=y_val, y_pred=y_predict))
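The added stratify=data_y keeps the positive/negative ratio identical in the train and validation splits, which matters now that generated negatives can outnumber positives (negative_sample_num defaults to 1.2x the positive count). A quick illustration with scikit-learn:

    # Toy labels: 90 positives, 10 negatives; stratify preserves that 9:1 ratio.
    import numpy as np
    from sklearn.model_selection import train_test_split

    x = np.arange(100).reshape(-1, 1)
    y = np.array([1] * 90 + [0] * 10)
    _, _, y_train, y_val = train_test_split(x, y, train_size=0.8, shuffle=True, stratify=y)
    print(np.bincount(y_train), np.bincount(y_val))  # -> [ 8 72] [ 2 18]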
@@ -88,9 +108,11 @@
         :param threshold: distance limit to the positive samples x
         :return: negative samples, ndarray of shape (sample_size, feature_num)
         """
         feature_num = x.shape[1]
         negative_samples = np.zeros((sample_size, feature_num), dtype=x.dtype)
         generated_sample_num = 0
+        bar = tqdm(total=sample_size, ncols=100)
         while generated_sample_num <= sample_size:
             generated_data = np.random.uniform(world_boundary[:feature_num], world_boundary[feature_num:],
                                                size=(sample_size, feature_num))
@@ -100,20 +122,45 @@
             if not in_threshold:
                 negative_samples[sample_idx, :] = sample
                 generated_sample_num += 1
+                bar.update()
             if generated_sample_num >= sample_size:
                 break
+        bar.close()
         return negative_samples

-    def save(self, file_path=None):
-        self.model.save(file_path)
+    def save(self):
+        path = datetime.datetime.now().strftime(f"{self.model_type}_%Y-%m-%d_%H-%M.model")
+        with open(path, 'wb') as f:
+            pickle.dump((self.model_type, self.model), f)

     def load(self, file_path):
-        self.model.load(file_path)
+        with open(file_path, 'rb') as model_file:
+            data = pickle.load(model_file)
+        self.model_type, self.model = data

+    def visualize(self, world_boundary: np.ndarray, sample_size: int, ground_truth=None,
+                  **kwargs):
+        feature_num = world_boundary.shape[0] // 2
+        x = np.random.uniform(world_boundary[:feature_num], world_boundary[feature_num:],
+                              size=(sample_size, feature_num))
+        pred_y = self.model.predict(x)
+        draw_dataset = {'Inside': x[pred_y == 1, :], 'Outside': x[pred_y == 0, :]}
+        if ground_truth is not None:
+            draw_dataset.update({'Given': ground_truth})
+        lab_scatter(draw_dataset, is_3d=True, is_ps_color_space=False, **kwargs)


 if __name__ == '__main__':
-    detector = AnonymousColorDetector()
-    x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
-    world_boundary = np.array([0, -127, -127, 100, 127, 127])
-    detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
-    detector.load('ELM_2022-07-18_17-01.mat')
+    data_dir = "data/dataset"
+    color_dict = {(0, 0, 255): "yangeng"}
+    dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
+    ground_truth = dataset['yangeng']
+    detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
+    # x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
+    world_boundary = np.array([0, 0, 0, 255, 255, 255])
+    # detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
+    detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
+    data = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
+    x, y = data['x'], data['y']
+    dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
+    lab_scatter(dataset, class_max_num=5000, is_3d=True, is_ps_color_space=False)
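The new save()/load() pair persists the detector as a pickled (model_type, model) tuple instead of the old ELM .mat file. A round-trip sketch with a toy tree; the file name here is only illustrative:

    # Toy round trip of the (model_type, model) pickle format used by save()/load().
    import pickle
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier

    toy_model = DecisionTreeClassifier().fit(np.array([[0, 0, 0], [255, 255, 255]]), [0, 1])
    with open('dt_example.model', 'wb') as f:       # illustrative path
        pickle.dump(('dt', toy_model), f)           # same tuple layout as save()

    with open('dt_example.model', 'rb') as f:
        model_type, restored = pickle.load(f)
    print(model_type, restored.predict([[200, 200, 200]]))  # -> dt [1]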

View File

@@ -55,7 +55,7 @@ def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color
     return total_dataset


-def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True):
+def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True, **kwargs):
     """
     Plot the 3D data distribution in the Lab color space