From 73973216f62ea2a7027e9f0a680a04022cf27331 Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Tue, 19 Jul 2022 11:24:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 +- elm.py | 6 ++++++ main_test.py | 47 +++++++++++++++++++++++++++++++++++++++++++ models.py | 57 +++++++++++++++++++++++++++++++++++++++++----------- 4 files changed, 99 insertions(+), 13 deletions(-) create mode 100644 main_test.py diff --git a/.gitignore b/.gitignore index 501d2c9..0d32a3e 100644 --- a/.gitignore +++ b/.gitignore @@ -3,7 +3,7 @@ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 data/* - +models/* # User-specific stuff .idea/**/workspace.xml .idea/**/tasks.xml diff --git a/elm.py b/elm.py index 87f890d..f35036a 100644 --- a/elm.py +++ b/elm.py @@ -36,10 +36,16 @@ class ELM: :param bias: [array] shape: 1 x node_num :param beta: [array] shape: node_num, output_num :param rand_seed: [int] the random seed + :param model_path: [str] the trained model path """ if rand_seed is not None: np.random.seed(rand_seed) + if 'model_path' in kwargs: + data = scipy.io.loadmat(kwargs['model_path']) + self.w, self.b, self.beta = data['w'], data['b'], data['beta'] + return + if weight is not None: self.w = weight else: diff --git a/main_test.py b/main_test.py new file mode 100644 index 0000000..6626e82 --- /dev/null +++ b/main_test.py @@ -0,0 +1,47 @@ +# -*- codeing = utf-8 -*- +# Time : 2022/7/19 10:49 +# @Auther : zhouchao +# @File: main_test.py +# @Software:PyCharm +import time + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + +from models import Detector, AnonymousColorDetector + + +def virtual_main(detector: Detector, test_img=None, test_img_dir=None): + """ + 虚拟读图测试程序 + + :param detector: 杂质探测器,需要继承Detector类 + :param test_img: 测试图像,rgb格式的图片或者路径 + :param test_img_dir: 测试图像文件夹 + :return: + """ + if test_img is not None: + if isinstance(test_img, str): + img = cv2.imread(test_img)[:, :, ::-1] + elif isinstance(test_img, np.ndarray): + img = test_img + else: + raise TypeError("test img should be np.ndarray or str") + t1 = time.time() + result = detector.predict(img) + t2 = time.time() + fig, axs = plt.subplots(3, 1) + axs[0].imshow(img) + axs[1].imshow(result) + mask_color = np.zeros_like(img) + mask_color[result > 0] = (0, 0, 255) + result_show = cv2.addWeighted(img, 1, mask_color, 0.5, 0) + axs[2].imshow(result_show) + plt.title(f'{(t2 - t1) * 1000:.2f} ms') + plt.show() + + +if __name__ == '__main__': + detector = AnonymousColorDetector(file_path='models/ELM_2022-07-18_17-22.mat') + virtual_main(detector, test_img='data/dataset/img/yangeng.bmp') diff --git a/models.py b/models.py index 17b6de4..99ed01a 100644 --- a/models.py +++ b/models.py @@ -3,47 +3,80 @@ # @Auther : zhouchao # @File: models.py # @Software:PyCharm、 +import datetime + +import cv2 import numpy as np +import scipy.io from sklearn.metrics import classification_report from sklearn.model_selection import train_test_split from elm import ELM -class AnonymousColorDetector(object): - def __init__(self, file_path=None): - self.model = None +class Detector(object): + def __int__(self, *args, **kwargs): + raise NotImplementedError - def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm', - negative_sample_size: int = 1000, train_size: float = 0.8, **kwargs): + def predict(self, *args, **kwargs): + raise NotImplementedError + + def load(self, *args, **kwargs): + raise NotImplementedError + + def save(self, *args, **kwargs): + raise NotImplementedError + + def fit(self, *args, **kwargs): + raise NotImplementedError + + +class AnonymousColorDetector(Detector): + def __init__(self, file_path: str = None): + self.model = None + if file_path is not None: + self.model = ELM(model_path=file_path) + + def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, + negative_sample_size: int = 1000, train_size: float = 0.8, is_save_dataset=False, **kwargs): """ 拟合到指定的样本分布情况下,根据x进行分布的变化。 :param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num :param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限 :param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别 - :param model_selected: 选择模型,默认为elm :param negative_sample_size: 负样本的数量 + :param train_size: 训练集的比例, float + :param is_save_dataset: 是否保存数据集 :param kwargs: 与模型相对应的参数 :return: """ - assert model_selected in ['elm'] - if model_selected == 'elm': - node_num = kwargs.get('node_num', 10) - self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs) + node_num = kwargs.get('node_num', 10) + self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs) negative_samples = self.generate_negative_samples(x, world_boundary, threshold, sample_size=negative_sample_size) data_x, data_y = np.concatenate([x, negative_samples], axis=0), \ np.concatenate([np.ones(x.shape[0], dtype=int), np.zeros(negative_samples.shape[0], dtype=int)], axis=0) + if is_save_dataset: + path = datetime.datetime.now().strftime("dataset_%Y-%m-%d_%H-%M.mat") + scipy.io.savemat(path, {'x': data_x, 'y': data_y}) x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True) - self.model.fit(x_train, y_train) y_predict = self.model.predict(x_val) print(classification_report(y_true=y_val, y_pred=y_predict)) def predict(self, x): - return self.model.predict(x) + """ + 输入rgb彩色图像 + + :param x: rgb彩色图像,np.ndarray + :return: + """ + w, h = x.shape[1], x.shape[0] + x = cv2.cvtColor(x, cv2.COLOR_RGB2LAB) + result = self.model.predict(x.reshape(w * h, -1)) + return result.reshape(h, w) @staticmethod def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):