From 00a46f010bac711fcc6759628bfef221548dda3d Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:04:54 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E8=AE=AD=E7=BB=83=E4=BF=9D?= =?UTF-8?q?=E5=AD=98=E4=B8=8E=E5=8A=A0=E8=BD=BD,=20=E5=90=8E=E5=8F=B0?= =?UTF-8?q?=E7=89=88?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 02_classification.ipynb | 100 ++++++++++++++++++---------------------- 02_classification.py | 41 ++++++++++++++++ models.py | 9 +++- 3 files changed, 95 insertions(+), 55 deletions(-) create mode 100644 02_classification.py diff --git a/02_classification.ipynb b/02_classification.ipynb index 33b8d5b..2b3812e 100644 --- a/02_classification.ipynb +++ b/02_classification.ipynb @@ -2,76 +2,76 @@ "cells": [ { "cell_type": "markdown", - "source": [ - "# 模型的训练" - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "# 模型的训练" + ] }, { "cell_type": "code", "execution_count": 16, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "import numpy as np\n", "\n", "from models import AnonymousColorDetector\n", "from utils import read_labeled_img" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "## 读取数据与构建数据集" - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "## 读取数据与构建数据集" + ] }, { "cell_type": "code", "execution_count": 17, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "data_dir = \"data/dataset\"\n", "color_dict = {(0, 0, 255): \"yangeng\"}\n", "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "markdown", - "source": [ - "## 模型训练" - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%% md\n" } - } + }, + "source": [ + "## 模型训练" + ] }, { "cell_type": "code", "execution_count": 18, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [], "source": [ "# 定义一些常量\n", @@ -82,31 +82,29 @@ "# 对数据进行预处理\n", "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n", "negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] }, { "cell_type": "code", "execution_count": 19, - "outputs": [], - "source": [ - "model = AnonymousColorDetector()" - ], "metadata": { - "collapsed": false, "pycharm": { "name": "#%%\n" } - } + }, + "outputs": [], + "source": [ + "model = AnonymousColorDetector()" + ] }, { "cell_type": "code", "execution_count": 20, + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, "outputs": [ { "ename": "KeyboardInterrupt", @@ -124,34 +122,28 @@ ], "source": [ "model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } + ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", - "version": 2 + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", - "pygments_lexer": "ipython2", - "version": "2.7.6" + "pygments_lexer": "ipython3", + "version": "3.10.0" } }, "nbformat": 4, - "nbformat_minor": 0 + "nbformat_minor": 1 } \ No newline at end of file diff --git a/02_classification.py b/02_classification.py new file mode 100644 index 0000000..fc0aff8 --- /dev/null +++ b/02_classification.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # 模型的训练 + +# In[16]: + + +import numpy as np + +from models import AnonymousColorDetector +from utils import read_labeled_img + +# ## 读取数据与构建数据集 + +# In[17]: + + +data_dir = "data/dataset" +color_dict = {(0, 0, 255): "yangeng"} +dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False) + +# ## 模型训练 + +# In[18]: + + +# 定义一些常量 +threshold = 5 +node_num = 20 +negative_sample_num = None # None或者一个数字 +world_boundary = np.array([0, 0, 0, 255, 255, 255]) +# 对数据进行预处理 +x = np.concatenate([v for k, v in dataset.items()], axis=0) +negative_sample_num = int(x.shape[0] * 0.7) if negative_sample_num is None else negative_sample_num + +model = AnonymousColorDetector() + +model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7) + +model.save() diff --git a/models.py b/models.py index 5318b2b..17b6de4 100644 --- a/models.py +++ b/models.py @@ -11,7 +11,7 @@ from elm import ELM class AnonymousColorDetector(object): - def __init__(self): + def __init__(self, file_path=None): self.model = None def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm', @@ -71,9 +71,16 @@ class AnonymousColorDetector(object): break return negative_samples + def save(self, file_path=None): + self.model.save(file_path) + + def load(self, file_path): + self.model.load(file_path) + if __name__ == '__main__': detector = AnonymousColorDetector() x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]]) world_boundary = np.array([0, -127, -127, 100, 127, 127]) detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000) + detector.load('ELM_2022-07-18_17-01.mat')