模型训练保存与加载, 后台版

This commit is contained in:
FEIJINTI 2022-07-18 17:04:54 +08:00
parent d612222752
commit 00a46f010b
3 changed files with 95 additions and 55 deletions

View File

@ -2,76 +2,76 @@
"cells": [ "cells": [
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"# 模型的训练"
],
"metadata": { "metadata": {
"collapsed": false,
"pycharm": { "pycharm": {
"name": "#%% md\n" "name": "#%% md\n"
} }
} },
"source": [
"# 模型的训练"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 16, "execution_count": 16,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"\n", "\n",
"from models import AnonymousColorDetector\n", "from models import AnonymousColorDetector\n",
"from utils import read_labeled_img" "from utils import read_labeled_img"
], ]
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"## 读取数据与构建数据集"
],
"metadata": { "metadata": {
"collapsed": false,
"pycharm": { "pycharm": {
"name": "#%% md\n" "name": "#%% md\n"
} }
} },
"source": [
"## 读取数据与构建数据集"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 17, "execution_count": 17,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"data_dir = \"data/dataset\"\n", "data_dir = \"data/dataset\"\n",
"color_dict = {(0, 0, 255): \"yangeng\"}\n", "color_dict = {(0, 0, 255): \"yangeng\"}\n",
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)" "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
], ]
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"source": [
"## 模型训练"
],
"metadata": { "metadata": {
"collapsed": false,
"pycharm": { "pycharm": {
"name": "#%% md\n" "name": "#%% md\n"
} }
} },
"source": [
"## 模型训练"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 18, "execution_count": 18,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [], "outputs": [],
"source": [ "source": [
"# 定义一些常量\n", "# 定义一些常量\n",
@ -82,31 +82,29 @@
"# 对数据进行预处理\n", "# 对数据进行预处理\n",
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n", "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num" "negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
], ]
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 19, "execution_count": 19,
"outputs": [],
"source": [
"model = AnonymousColorDetector()"
],
"metadata": { "metadata": {
"collapsed": false,
"pycharm": { "pycharm": {
"name": "#%%\n" "name": "#%%\n"
} }
} },
"outputs": [],
"source": [
"model = AnonymousColorDetector()"
]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 20, "execution_count": 20,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [ "outputs": [
{ {
"ename": "KeyboardInterrupt", "ename": "KeyboardInterrupt",
@ -124,34 +122,28 @@
], ],
"source": [ "source": [
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)" "model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
], ]
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
} }
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
"language_info": { "language_info": {
"codemirror_mode": { "codemirror_mode": {
"name": "ipython", "name": "ipython",
"version": 2 "version": 3
}, },
"file_extension": ".py", "file_extension": ".py",
"mimetype": "text/x-python", "mimetype": "text/x-python",
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython2", "pygments_lexer": "ipython3",
"version": "2.7.6" "version": "3.10.0"
} }
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 1
} }

41
02_classification.py Normal file
View File

@ -0,0 +1,41 @@
#!/usr/bin/env python
# coding: utf-8
# # 模型的训练
# In[16]:
import numpy as np
from models import AnonymousColorDetector
from utils import read_labeled_img
# ## 读取数据与构建数据集
# In[17]:
data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"}
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
# ## 模型训练
# In[18]:
# 定义一些常量
threshold = 5
node_num = 20
negative_sample_num = None # None或者一个数字
world_boundary = np.array([0, 0, 0, 255, 255, 255])
# 对数据进行预处理
x = np.concatenate([v for k, v in dataset.items()], axis=0)
negative_sample_num = int(x.shape[0] * 0.7) if negative_sample_num is None else negative_sample_num
model = AnonymousColorDetector()
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)
model.save()

View File

@ -11,7 +11,7 @@ from elm import ELM
class AnonymousColorDetector(object): class AnonymousColorDetector(object):
def __init__(self): def __init__(self, file_path=None):
self.model = None self.model = None
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm', def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
@ -71,9 +71,16 @@ class AnonymousColorDetector(object):
break break
return negative_samples return negative_samples
def save(self, file_path=None):
self.model.save(file_path)
def load(self, file_path):
self.model.load(file_path)
if __name__ == '__main__': if __name__ == '__main__':
detector = AnonymousColorDetector() detector = AnonymousColorDetector()
x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]]) x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
world_boundary = np.array([0, -127, -127, 100, 127, 127]) world_boundary = np.array([0, -127, -127, 100, 127, 127])
detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000) detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
detector.load('ELM_2022-07-18_17-01.mat')