模型测试

This commit is contained in:
FEIJINTI 2022-07-19 11:24:53 +08:00
parent 99fc21cb84
commit 73973216f6
4 changed files with 99 additions and 13 deletions

2
.gitignore vendored
View File

@ -3,7 +3,7 @@
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
data/* data/*
models/*
# User-specific stuff # User-specific stuff
.idea/**/workspace.xml .idea/**/workspace.xml
.idea/**/tasks.xml .idea/**/tasks.xml

6
elm.py
View File

@ -36,10 +36,16 @@ class ELM:
:param bias: [array] shape: 1 x node_num :param bias: [array] shape: 1 x node_num
:param beta: [array] shape: node_num, output_num :param beta: [array] shape: node_num, output_num
:param rand_seed: [int] the random seed :param rand_seed: [int] the random seed
:param model_path: [str] the trained model path
""" """
if rand_seed is not None: if rand_seed is not None:
np.random.seed(rand_seed) np.random.seed(rand_seed)
if 'model_path' in kwargs:
data = scipy.io.loadmat(kwargs['model_path'])
self.w, self.b, self.beta = data['w'], data['b'], data['beta']
return
if weight is not None: if weight is not None:
self.w = weight self.w = weight
else: else:

47
main_test.py Normal file
View File

@ -0,0 +1,47 @@
# -*- codeing = utf-8 -*-
# Time : 2022/7/19 10:49
# @Auther : zhouchao
# @File: main_test.py
# @Software:PyCharm
import time
import cv2
import matplotlib.pyplot as plt
import numpy as np
from models import Detector, AnonymousColorDetector
def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
"""
虚拟读图测试程序
:param detector: 杂质探测器需要继承Detector类
:param test_img: 测试图像rgb格式的图片或者路径
:param test_img_dir: 测试图像文件夹
:return:
"""
if test_img is not None:
if isinstance(test_img, str):
img = cv2.imread(test_img)[:, :, ::-1]
elif isinstance(test_img, np.ndarray):
img = test_img
else:
raise TypeError("test img should be np.ndarray or str")
t1 = time.time()
result = detector.predict(img)
t2 = time.time()
fig, axs = plt.subplots(3, 1)
axs[0].imshow(img)
axs[1].imshow(result)
mask_color = np.zeros_like(img)
mask_color[result > 0] = (0, 0, 255)
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
axs[2].imshow(result_show)
plt.title(f'{(t2 - t1) * 1000:.2f} ms')
plt.show()
if __name__ == '__main__':
detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat')
virtual_main(detector, test_img='data/dataset/img/yangeng.bmp')

View File

@ -3,32 +3,54 @@
# @Auther : zhouchao # @Auther : zhouchao
# @File: models.py # @File: models.py
# @Software:PyCharm、 # @Software:PyCharm、
import datetime
import cv2
import numpy as np import numpy as np
import scipy.io
from sklearn.metrics import classification_report from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from elm import ELM from elm import ELM
class AnonymousColorDetector(object): class Detector(object):
def __init__(self, file_path=None): def __int__(self, *args, **kwargs):
self.model = None raise NotImplementedError
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm', def predict(self, *args, **kwargs):
negative_sample_size: int = 1000, train_size: float = 0.8, **kwargs): raise NotImplementedError
def load(self, *args, **kwargs):
raise NotImplementedError
def save(self, *args, **kwargs):
raise NotImplementedError
def fit(self, *args, **kwargs):
raise NotImplementedError
class AnonymousColorDetector(Detector):
def __init__(self, file_path: str = None):
self.model = None
if file_path is not None:
self.model = ELM(model_path=file_path)
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float,
negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs):
""" """
拟合到指定的样本分布情况下根据x进行分布的变化 拟合到指定的样本分布情况下根据x进行分布的变化
:param x: ndarray类型的正样本数据给出的正样本形状为 n x feature_num :param x: ndarray类型的正样本数据给出的正样本形状为 n x feature_num
:param world_boundary: 整个世界的边界边界形状为 feature_num个下限, feature_num个上限 :param world_boundary: 整个世界的边界边界形状为 feature_num个下限, feature_num个上限
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别 :param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
:param model_selected: 选择模型默认为elm
:param negative_sample_size: 负样本的数量 :param negative_sample_size: 负样本的数量
:param train_size: 训练集的比例, float
:param is_save_dataset: 是否保存数据集
:param kwargs: 与模型相对应的参数 :param kwargs: 与模型相对应的参数
:return: :return:
""" """
assert model_selected in ['elm']
if model_selected == 'elm':
node_num = kwargs.get('node_num', 10) node_num = kwargs.get('node_num', 10)
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs) self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
negative_samples = self.generate_negative_samples(x, world_boundary, threshold, negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
@ -36,14 +58,25 @@ class AnonymousColorDetector(object):
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \ data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
np.concatenate([np.ones(x.shape[0], dtype=int), np.concatenate([np.ones(x.shape[0], dtype=int),
np.zeros(negative_samples.shape[0], dtype=int)], axis=0) np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
if is_save_dataset:
path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat")
scipy.io.savemat(path, {'x': data_x, 'y': data_y})
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True) x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
self.model.fit(x_train, y_train) self.model.fit(x_train, y_train)
y_predict = self.model.predict(x_val) y_predict = self.model.predict(x_val)
print(classification_report(y_true=y_val, y_pred=y_predict)) print(classification_report(y_true=y_val, y_pred=y_predict))
def predict(self, x): def predict(self, x):
return self.model.predict(x) """
输入rgb彩色图像
:param x: rgb彩色图像,np.ndarray
:return:
"""
w, h = x.shape[1], x.shape[0]
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
result = self.model.predict(x.reshape(w * h, -1))
return result.reshape(h, w)
@staticmethod @staticmethod
def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int): def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):