mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
模型训练保存与加载, 后台版
This commit is contained in:
parent
d612222752
commit
00a46f010b
@ -2,76 +2,76 @@
|
|||||||
"cells": [
|
"cells": [
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"# 模型的训练"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%% md\n"
|
"name": "#%% md\n"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"# 模型的训练"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 16,
|
"execution_count": 16,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"\n",
|
"\n",
|
||||||
"from models import AnonymousColorDetector\n",
|
"from models import AnonymousColorDetector\n",
|
||||||
"from utils import read_labeled_img"
|
"from utils import read_labeled_img"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"## 读取数据与构建数据集"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%% md\n"
|
"name": "#%% md\n"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"## 读取数据与构建数据集"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 17,
|
"execution_count": 17,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"data_dir = \"data/dataset\"\n",
|
"data_dir = \"data/dataset\"\n",
|
||||||
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
|
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "markdown",
|
"cell_type": "markdown",
|
||||||
"source": [
|
|
||||||
"## 模型训练"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%% md\n"
|
"name": "#%% md\n"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"source": [
|
||||||
|
"## 模型训练"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 18,
|
"execution_count": 18,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"# 定义一些常量\n",
|
"# 定义一些常量\n",
|
||||||
@ -82,31 +82,29 @@
|
|||||||
"# 对数据进行预处理\n",
|
"# 对数据进行预处理\n",
|
||||||
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
|
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
|
||||||
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
|
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 19,
|
"execution_count": 19,
|
||||||
"outputs": [],
|
|
||||||
"source": [
|
|
||||||
"model = AnonymousColorDetector()"
|
|
||||||
],
|
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
"pycharm": {
|
||||||
"name": "#%%\n"
|
"name": "#%%\n"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"model = AnonymousColorDetector()"
|
||||||
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 20,
|
"execution_count": 20,
|
||||||
|
"metadata": {
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
},
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"ename": "KeyboardInterrupt",
|
"ename": "KeyboardInterrupt",
|
||||||
@ -124,34 +122,28 @@
|
|||||||
],
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
|
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
|
||||||
],
|
]
|
||||||
"metadata": {
|
|
||||||
"collapsed": false,
|
|
||||||
"pycharm": {
|
|
||||||
"name": "#%%\n"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"kernelspec": {
|
"kernelspec": {
|
||||||
"display_name": "Python 3",
|
"display_name": "Python 3 (ipykernel)",
|
||||||
"language": "python",
|
"language": "python",
|
||||||
"name": "python3"
|
"name": "python3"
|
||||||
},
|
},
|
||||||
"language_info": {
|
"language_info": {
|
||||||
"codemirror_mode": {
|
"codemirror_mode": {
|
||||||
"name": "ipython",
|
"name": "ipython",
|
||||||
"version": 2
|
"version": 3
|
||||||
},
|
},
|
||||||
"file_extension": ".py",
|
"file_extension": ".py",
|
||||||
"mimetype": "text/x-python",
|
"mimetype": "text/x-python",
|
||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython2",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "2.7.6"
|
"version": "3.10.0"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
"nbformat_minor": 0
|
"nbformat_minor": 1
|
||||||
}
|
}
|
||||||
41
02_classification.py
Normal file
41
02_classification.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
# # 模型的训练
|
||||||
|
|
||||||
|
# In[16]:
|
||||||
|
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from models import AnonymousColorDetector
|
||||||
|
from utils import read_labeled_img
|
||||||
|
|
||||||
|
# ## 读取数据与构建数据集
|
||||||
|
|
||||||
|
# In[17]:
|
||||||
|
|
||||||
|
|
||||||
|
data_dir = "data/dataset"
|
||||||
|
color_dict = {(0, 0, 255): "yangeng"}
|
||||||
|
dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
|
||||||
|
|
||||||
|
# ## 模型训练
|
||||||
|
|
||||||
|
# In[18]:
|
||||||
|
|
||||||
|
|
||||||
|
# 定义一些常量
|
||||||
|
threshold = 5
|
||||||
|
node_num = 20
|
||||||
|
negative_sample_num = None # None或者一个数字
|
||||||
|
world_boundary = np.array([0, 0, 0, 255, 255, 255])
|
||||||
|
# 对数据进行预处理
|
||||||
|
x = np.concatenate([v for k, v in dataset.items()], axis=0)
|
||||||
|
negative_sample_num = int(x.shape[0] * 0.7) if negative_sample_num is None else negative_sample_num
|
||||||
|
|
||||||
|
model = AnonymousColorDetector()
|
||||||
|
|
||||||
|
model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)
|
||||||
|
|
||||||
|
model.save()
|
||||||
@ -11,7 +11,7 @@ from elm import ELM
|
|||||||
|
|
||||||
|
|
||||||
class AnonymousColorDetector(object):
|
class AnonymousColorDetector(object):
|
||||||
def __init__(self):
|
def __init__(self, file_path=None):
|
||||||
self.model = None
|
self.model = None
|
||||||
|
|
||||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
|
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
|
||||||
@ -71,9 +71,16 @@ class AnonymousColorDetector(object):
|
|||||||
break
|
break
|
||||||
return negative_samples
|
return negative_samples
|
||||||
|
|
||||||
|
def save(self, file_path=None):
|
||||||
|
self.model.save(file_path)
|
||||||
|
|
||||||
|
def load(self, file_path):
|
||||||
|
self.model.load(file_path)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
detector = AnonymousColorDetector()
|
detector = AnonymousColorDetector()
|
||||||
x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||||
world_boundary = np.array([0, -127, -127, 100, 127, 127])
|
world_boundary = np.array([0, -127, -127, 100, 127, 127])
|
||||||
detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||||
|
detector.load('ELM_2022-07-18_17-01.mat')
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user