mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
模型测试
This commit is contained in:
parent
99fc21cb84
commit
73973216f6
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,7 +3,7 @@
|
|||||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||||
|
|
||||||
data/*
|
data/*
|
||||||
|
models/*
|
||||||
# User-specific stuff
|
# User-specific stuff
|
||||||
.idea/**/workspace.xml
|
.idea/**/workspace.xml
|
||||||
.idea/**/tasks.xml
|
.idea/**/tasks.xml
|
||||||
|
|||||||
6
elm.py
6
elm.py
@ -36,10 +36,16 @@ class ELM:
|
|||||||
:param bias: [array] shape: 1 x node_num
|
:param bias: [array] shape: 1 x node_num
|
||||||
:param beta: [array] shape: node_num, output_num
|
:param beta: [array] shape: node_num, output_num
|
||||||
:param rand_seed: [int] the random seed
|
:param rand_seed: [int] the random seed
|
||||||
|
:param model_path: [str] the trained model path
|
||||||
"""
|
"""
|
||||||
if rand_seed is not None:
|
if rand_seed is not None:
|
||||||
np.random.seed(rand_seed)
|
np.random.seed(rand_seed)
|
||||||
|
|
||||||
|
if 'model_path' in kwargs:
|
||||||
|
data = scipy.io.loadmat(kwargs['model_path'])
|
||||||
|
self.w, self.b, self.beta = data['w'], data['b'], data['beta']
|
||||||
|
return
|
||||||
|
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
self.w = weight
|
self.w = weight
|
||||||
else:
|
else:
|
||||||
|
|||||||
47
main_test.py
Normal file
47
main_test.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/7/19 10:49
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: main_test.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import time
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from models import Detector, AnonymousColorDetector
|
||||||
|
|
||||||
|
|
||||||
|
def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
|
||||||
|
"""
|
||||||
|
虚拟读图测试程序
|
||||||
|
|
||||||
|
:param detector: 杂质探测器,需要继承Detector类
|
||||||
|
:param test_img: 测试图像,rgb格式的图片或者路径
|
||||||
|
:param test_img_dir: 测试图像文件夹
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if test_img is not None:
|
||||||
|
if isinstance(test_img, str):
|
||||||
|
img = cv2.imread(test_img)[:, :, ::-1]
|
||||||
|
elif isinstance(test_img, np.ndarray):
|
||||||
|
img = test_img
|
||||||
|
else:
|
||||||
|
raise TypeError("test img should be np.ndarray or str")
|
||||||
|
t1 = time.time()
|
||||||
|
result = detector.predict(img)
|
||||||
|
t2 = time.time()
|
||||||
|
fig, axs = plt.subplots(3, 1)
|
||||||
|
axs[0].imshow(img)
|
||||||
|
axs[1].imshow(result)
|
||||||
|
mask_color = np.zeros_like(img)
|
||||||
|
mask_color[result > 0] = (0, 0, 255)
|
||||||
|
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
|
||||||
|
axs[2].imshow(result_show)
|
||||||
|
plt.title(f'{(t2 - t1) * 1000:.2f} ms')
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat')
|
||||||
|
virtual_main(detector, test_img='data/dataset/img/yangeng.bmp')
|
||||||
53
models.py
53
models.py
@ -3,32 +3,54 @@
|
|||||||
# @Auther : zhouchao
|
# @Auther : zhouchao
|
||||||
# @File: models.py
|
# @File: models.py
|
||||||
# @Software:PyCharm、
|
# @Software:PyCharm、
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import cv2
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import scipy.io
|
||||||
from sklearn.metrics import classification_report
|
from sklearn.metrics import classification_report
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
|
|
||||||
from elm import ELM
|
from elm import ELM
|
||||||
|
|
||||||
|
|
||||||
class AnonymousColorDetector(object):
|
class Detector(object):
|
||||||
def __init__(self, file_path=None):
|
def __int__(self, *args, **kwargs):
|
||||||
self.model = None
|
raise NotImplementedError
|
||||||
|
|
||||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
|
def predict(self, *args, **kwargs):
|
||||||
negative_sample_size: int = 1000, train_size: float = 0.8, **kwargs):
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def load(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def save(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def fit(self, *args, **kwargs):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
class AnonymousColorDetector(Detector):
|
||||||
|
def __init__(self, file_path: str = None):
|
||||||
|
self.model = None
|
||||||
|
if file_path is not None:
|
||||||
|
self.model = ELM(model_path=file_path)
|
||||||
|
|
||||||
|
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float,
|
||||||
|
negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
拟合到指定的样本分布情况下,根据x进行分布的变化。
|
拟合到指定的样本分布情况下,根据x进行分布的变化。
|
||||||
|
|
||||||
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
||||||
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限
|
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限
|
||||||
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
|
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
|
||||||
:param model_selected: 选择模型,默认为elm
|
|
||||||
:param negative_sample_size: 负样本的数量
|
:param negative_sample_size: 负样本的数量
|
||||||
|
:param train_size: 训练集的比例, float
|
||||||
|
:param is_save_dataset: 是否保存数据集
|
||||||
:param kwargs: 与模型相对应的参数
|
:param kwargs: 与模型相对应的参数
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
assert model_selected in ['elm']
|
|
||||||
if model_selected == 'elm':
|
|
||||||
node_num = kwargs.get('node_num', 10)
|
node_num = kwargs.get('node_num', 10)
|
||||||
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
||||||
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
||||||
@ -36,14 +58,25 @@ class AnonymousColorDetector(object):
|
|||||||
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
||||||
np.concatenate([np.ones(x.shape[0], dtype=int),
|
np.concatenate([np.ones(x.shape[0], dtype=int),
|
||||||
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
||||||
|
if is_save_dataset:
|
||||||
|
path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat")
|
||||||
|
scipy.io.savemat(path, {'x': data_x, 'y': data_y})
|
||||||
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
|
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
|
||||||
|
|
||||||
self.model.fit(x_train, y_train)
|
self.model.fit(x_train, y_train)
|
||||||
y_predict = self.model.predict(x_val)
|
y_predict = self.model.predict(x_val)
|
||||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||||
|
|
||||||
def predict(self, x):
|
def predict(self, x):
|
||||||
return self.model.predict(x)
|
"""
|
||||||
|
输入rgb彩色图像
|
||||||
|
|
||||||
|
:param x: rgb彩色图像,np.ndarray
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
w, h = x.shape[1], x.shape[0]
|
||||||
|
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
|
||||||
|
result = self.model.predict(x.reshape(w * h, -1))
|
||||||
|
return result.reshape(h, w)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):
|
def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user