mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 22:33:54 +00:00
添加了dt方法
This commit is contained in:
parent
73973216f6
commit
bfba067285
@ -7,7 +7,8 @@
|
||||
|
||||
|
||||
import numpy as np
|
||||
|
||||
import scipy
|
||||
from imblearn.under_sampling import RandomUnderSampler
|
||||
from models import AnonymousColorDetector
|
||||
from utils import read_labeled_img
|
||||
|
||||
@ -17,8 +18,15 @@ from utils import read_labeled_img
|
||||
|
||||
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'}
|
||||
label_index = {"yangeng": 1, "beijing": 0}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
rus = RandomUnderSampler(random_state=0)
|
||||
x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
|
||||
np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()
|
||||
|
||||
x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
|
||||
dataset = {"inside": np.array(x_resampled)}
|
||||
|
||||
# ## 模型训练
|
||||
|
||||
@ -32,10 +40,14 @@ negative_sample_num = None # None或者一个数字
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
# 对数据进行预处理
|
||||
x = np.concatenate([v for k, v in dataset.items()], axis=0)
|
||||
negative_sample_num = int(x.shape[0] * 0.7) if negative_sample_num is None else negative_sample_num
|
||||
negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num
|
||||
|
||||
model = AnonymousColorDetector()
|
||||
|
||||
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)
|
||||
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
|
||||
is_save_dataset=True, model_selection='dt')
|
||||
# data = scipy.io.loadmat('dataset_2022-07-19_15-07.mat')
|
||||
# x, y = data['x'], data['y'].ravel()
|
||||
# model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
|
||||
|
||||
model.save()
|
||||
|
||||
22
main_test.py
22
main_test.py
@ -10,15 +10,17 @@ import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from models import Detector, AnonymousColorDetector
|
||||
from utils import read_labeled_img
|
||||
|
||||
|
||||
def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
|
||||
def virtual_main(detector: Detector, test_img=None, test_img_dir=None, test_model=False):
|
||||
"""
|
||||
虚拟读图测试程序
|
||||
|
||||
:param detector: 杂质探测器,需要继承Detector类
|
||||
:param test_img: 测试图像,rgb格式的图片或者路径
|
||||
:param test_img_dir: 测试图像文件夹
|
||||
:param test_model: 是否进行模型约束性测试
|
||||
:return:
|
||||
"""
|
||||
if test_img is not None:
|
||||
@ -29,8 +31,10 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
|
||||
else:
|
||||
raise TypeError("test img should be np.ndarray or str")
|
||||
t1 = time.time()
|
||||
result = detector.predict(img)
|
||||
img = cv2.resize(img, (1024, 256))
|
||||
t2 = time.time()
|
||||
result = 1 - detector.predict(img)
|
||||
t3 = time.time()
|
||||
fig, axs = plt.subplots(3, 1)
|
||||
axs[0].imshow(img)
|
||||
axs[1].imshow(result)
|
||||
@ -38,10 +42,18 @@ def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
|
||||
mask_color[result > 0] = (0, 0, 255)
|
||||
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
|
||||
axs[2].imshow(result_show)
|
||||
plt.title(f'{(t2 - t1) * 1000:.2f} ms')
|
||||
axs[0].set_title(
|
||||
f' resize {(t2 - t1) * 1000:.2f} ms, predict {(t3 - t2) * 1000:.2f} ms, total {(t3 - t1) * 1000:.2f} ms')
|
||||
plt.show()
|
||||
if test_model:
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
ground_truth = dataset['yangeng']
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat')
|
||||
virtual_main(detector, test_img='data/dataset/img/yangeng.bmp')
|
||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
||||
virtual_main(detector, test_img=r'data/dataset/img/yangeng.bmp', test_model=True)
|
||||
|
||||
83
models.py
83
models.py
@ -4,12 +4,17 @@
|
||||
# @File: models.py
|
||||
# @Software:PyCharm、
|
||||
import datetime
|
||||
import pickle
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import tqdm
|
||||
from sklearn.tree import DecisionTreeClassifier
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
from utils import lab_scatter, read_labeled_img
|
||||
from tqdm import tqdm
|
||||
|
||||
from elm import ELM
|
||||
|
||||
@ -34,10 +39,12 @@ class Detector(object):
|
||||
class AnonymousColorDetector(Detector):
|
||||
def __init__(self, file_path: str = None):
|
||||
self.model = None
|
||||
self.model_type = 'None'
|
||||
if file_path is not None:
|
||||
self.model = ELM(model_path=file_path)
|
||||
self.load(file_path)
|
||||
|
||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float,
|
||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray = None, threshold: float = None,
|
||||
is_generate_negative: bool = True, y: np.ndarray = None, model_selection='elm',
|
||||
negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs):
|
||||
"""
|
||||
拟合到指定的样本分布情况下,根据x进行分布的变化。
|
||||
@ -45,23 +52,36 @@ class AnonymousColorDetector(Detector):
|
||||
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
||||
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限
|
||||
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
|
||||
:param is_generate_negative: 是否生成负样本
|
||||
:param y: 给出x对应的样本y
|
||||
:param model_selection: 模型的选择, in ['elm', 'decision tree']
|
||||
:param negative_sample_size: 负样本的数量
|
||||
:param train_size: 训练集的比例, float
|
||||
:param is_save_dataset: 是否保存数据集
|
||||
:param kwargs: 与模型相对应的参数
|
||||
:return:
|
||||
"""
|
||||
node_num = kwargs.get('node_num', 10)
|
||||
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
||||
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
||||
sample_size=negative_sample_size)
|
||||
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
||||
np.concatenate([np.ones(x.shape[0], dtype=int),
|
||||
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
||||
if model_selection == 'elm':
|
||||
node_num = kwargs.get('node_num', 10)
|
||||
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
||||
elif model_selection == 'dt':
|
||||
self.model = DecisionTreeClassifier(**kwargs)
|
||||
else:
|
||||
raise ValueError("你看看我要的是啥")
|
||||
self.model_type = model_selection
|
||||
if is_generate_negative:
|
||||
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
||||
sample_size=negative_sample_size)
|
||||
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
||||
np.concatenate([np.ones(x.shape[0], dtype=int),
|
||||
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
||||
else:
|
||||
data_x, data_y = x, y
|
||||
if is_save_dataset:
|
||||
path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat")
|
||||
scipy.io.savemat(path, {'x': data_x, 'y': data_y})
|
||||
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
|
||||
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True,
|
||||
stratify=data_y)
|
||||
self.model.fit(x_train, y_train)
|
||||
y_predict = self.model.predict(x_val)
|
||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||
@ -88,9 +108,11 @@ class AnonymousColorDetector(Detector):
|
||||
:param threshold: 与正样本x之间的距离限制
|
||||
:return: 负样本形状为:(sample_size, feature_num)
|
||||
"""
|
||||
|
||||
feature_num = x.shape[1]
|
||||
negative_samples = np.zeros((sample_size, feature_num), dtype=x.dtype)
|
||||
generated_sample_num = 0
|
||||
bar = tqdm(total=sample_size, ncols=100)
|
||||
while generated_sample_num <= sample_size:
|
||||
generated_data = np.random.uniform(world_boundary[:feature_num], world_boundary[feature_num:],
|
||||
size=(sample_size, feature_num))
|
||||
@ -100,20 +122,45 @@ class AnonymousColorDetector(Detector):
|
||||
if not in_threshold:
|
||||
negative_samples[sample_idx, :] = sample
|
||||
generated_sample_num += 1
|
||||
bar.update()
|
||||
if generated_sample_num >= sample_size:
|
||||
break
|
||||
bar.close()
|
||||
return negative_samples
|
||||
|
||||
def save(self, file_path=None):
|
||||
self.model.save(file_path)
|
||||
def save(self):
|
||||
path = datetime.datetime.now().strftime(f"{self.model_type}_%Y-%m-%d_%H-%M.model")
|
||||
with open(path, 'wb') as f:
|
||||
pickle.dump((self.model_type, self.model), f)
|
||||
|
||||
def load(self, file_path):
|
||||
self.model.load(file_path)
|
||||
with open(file_path, 'rb') as model_file:
|
||||
data = pickle.load(model_file)
|
||||
self.model_type, self.model = data
|
||||
|
||||
def visualize(self, world_boundary: np.ndarray, sample_size: int, ground_truth=None,
|
||||
**kwargs):
|
||||
feature_num = world_boundary.shape[0] // 2
|
||||
x = np.random.uniform(world_boundary[:feature_num], world_boundary[feature_num:],
|
||||
size=(sample_size, feature_num))
|
||||
pred_y = self.model.predict(x)
|
||||
draw_dataset = {'Inside': x[pred_y == 1, :], 'Outside': x[pred_y == 0, :]}
|
||||
if ground_truth is not None:
|
||||
draw_dataset.update({'Given': ground_truth})
|
||||
lab_scatter(draw_dataset, is_3d=True, is_ps_color_space=False, **kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
detector = AnonymousColorDetector()
|
||||
x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
world_boundary = np.array([0, -127, -127, 100, 127, 127])
|
||||
detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
detector.load('ELM_2022-07-18_17-01.mat')
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||
ground_truth = dataset['yangeng']
|
||||
detector = AnonymousColorDetector(file_path='models/dt_2022-07-19_14-38.model')
|
||||
# x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||
# detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
detector.visualize(world_boundary, sample_size=50000, class_max_num=5000, ground_truth=ground_truth)
|
||||
data = scipy.io.loadmat('data/dataset_2022-07-19_11-35.mat')
|
||||
x, y = data['x'], data['y']
|
||||
dataset = {'inside': x[y.ravel() == 1, :], "outside": x[y.ravel() == 0, :]}
|
||||
lab_scatter(dataset, class_max_num=5000, is_3d=True, is_ps_color_space=False)
|
||||
|
||||
2
utils.py
2
utils.py
@ -55,7 +55,7 @@ def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color
|
||||
return total_dataset
|
||||
|
||||
|
||||
def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True):
|
||||
def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True, **kwargs):
|
||||
"""
|
||||
在lab色彩空间内绘制3维数据分布情况
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user