mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
模型测试
This commit is contained in:
parent
99fc21cb84
commit
73973216f6
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,7 +3,7 @@
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
data/*
|
||||
|
||||
models/*
|
||||
# User-specific stuff
|
||||
.idea/**/workspace.xml
|
||||
.idea/**/tasks.xml
|
||||
|
||||
6
elm.py
6
elm.py
@ -36,10 +36,16 @@ class ELM:
|
||||
:param bias: [array] shape: 1 x node_num
|
||||
:param beta: [array] shape: node_num, output_num
|
||||
:param rand_seed: [int] the random seed
|
||||
:param model_path: [str] the trained model path
|
||||
"""
|
||||
if rand_seed is not None:
|
||||
np.random.seed(rand_seed)
|
||||
|
||||
if 'model_path' in kwargs:
|
||||
data = scipy.io.loadmat(kwargs['model_path'])
|
||||
self.w, self.b, self.beta = data['w'], data['b'], data['beta']
|
||||
return
|
||||
|
||||
if weight is not None:
|
||||
self.w = weight
|
||||
else:
|
||||
|
||||
47
main_test.py
Normal file
47
main_test.py
Normal file
@ -0,0 +1,47 @@
|
||||
# -*- codeing = utf-8 -*-
|
||||
# Time : 2022/7/19 10:49
|
||||
# @Auther : zhouchao
|
||||
# @File: main_test.py
|
||||
# @Software:PyCharm
|
||||
import time
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
||||
from models import Detector, AnonymousColorDetector
|
||||
|
||||
|
||||
def virtual_main(detector: Detector, test_img=None, test_img_dir=None):
|
||||
"""
|
||||
虚拟读图测试程序
|
||||
|
||||
:param detector: 杂质探测器,需要继承Detector类
|
||||
:param test_img: 测试图像,rgb格式的图片或者路径
|
||||
:param test_img_dir: 测试图像文件夹
|
||||
:return:
|
||||
"""
|
||||
if test_img is not None:
|
||||
if isinstance(test_img, str):
|
||||
img = cv2.imread(test_img)[:, :, ::-1]
|
||||
elif isinstance(test_img, np.ndarray):
|
||||
img = test_img
|
||||
else:
|
||||
raise TypeError("test img should be np.ndarray or str")
|
||||
t1 = time.time()
|
||||
result = detector.predict(img)
|
||||
t2 = time.time()
|
||||
fig, axs = plt.subplots(3, 1)
|
||||
axs[0].imshow(img)
|
||||
axs[1].imshow(result)
|
||||
mask_color = np.zeros_like(img)
|
||||
mask_color[result > 0] = (0, 0, 255)
|
||||
result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0)
|
||||
axs[2].imshow(result_show)
|
||||
plt.title(f'{(t2 - t1) * 1000:.2f} ms')
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat')
|
||||
virtual_main(detector, test_img='data/dataset/img/yangeng.bmp')
|
||||
53
models.py
53
models.py
@ -3,32 +3,54 @@
|
||||
# @Auther : zhouchao
|
||||
# @File: models.py
|
||||
# @Software:PyCharm、
|
||||
import datetime
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from elm import ELM
|
||||
|
||||
|
||||
class AnonymousColorDetector(object):
|
||||
def __init__(self, file_path=None):
|
||||
self.model = None
|
||||
class Detector(object):
|
||||
def __int__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
|
||||
negative_sample_size: int = 1000, train_size: float = 0.8, **kwargs):
|
||||
def predict(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def load(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def fit(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class AnonymousColorDetector(Detector):
|
||||
def __init__(self, file_path: str = None):
|
||||
self.model = None
|
||||
if file_path is not None:
|
||||
self.model = ELM(model_path=file_path)
|
||||
|
||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float,
|
||||
negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs):
|
||||
"""
|
||||
拟合到指定的样本分布情况下,根据x进行分布的变化。
|
||||
|
||||
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
||||
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限
|
||||
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
|
||||
:param model_selected: 选择模型,默认为elm
|
||||
:param negative_sample_size: 负样本的数量
|
||||
:param train_size: 训练集的比例, float
|
||||
:param is_save_dataset: 是否保存数据集
|
||||
:param kwargs: 与模型相对应的参数
|
||||
:return:
|
||||
"""
|
||||
assert model_selected in ['elm']
|
||||
if model_selected == 'elm':
|
||||
node_num = kwargs.get('node_num', 10)
|
||||
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
||||
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
||||
@ -36,14 +58,25 @@ class AnonymousColorDetector(object):
|
||||
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
||||
np.concatenate([np.ones(x.shape[0], dtype=int),
|
||||
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
||||
if is_save_dataset:
|
||||
path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat")
|
||||
scipy.io.savemat(path, {'x': data_x, 'y': data_y})
|
||||
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
|
||||
|
||||
self.model.fit(x_train, y_train)
|
||||
y_predict = self.model.predict(x_val)
|
||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||
|
||||
def predict(self, x):
|
||||
return self.model.predict(x)
|
||||
"""
|
||||
输入rgb彩色图像
|
||||
|
||||
:param x: rgb彩色图像,np.ndarray
|
||||
:return:
|
||||
"""
|
||||
w, h = x.shape[1], x.shape[0]
|
||||
x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB)
|
||||
result = self.model.predict(x.reshape(w * h, -1))
|
||||
return result.reshape(h, w)
|
||||
|
||||
@staticmethod
|
||||
def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user