模型训练

This commit is contained in:
FEIJINTI 2022-07-18 16:55:05 +08:00
parent a4652427c7
commit d612222752
2 changed files with 89 additions and 19 deletions

View File

@ -14,10 +14,74 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 16,
"outputs": [], "outputs": [],
"source": [ "source": [
"from model import AnonymousColorDetector" "import numpy as np\n",
"\n",
"from models import AnonymousColorDetector\n",
"from utils import read_labeled_img"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 读取数据与构建数据集"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [],
"source": [
"data_dir = \"data/dataset\"\n",
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 模型训练"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [],
"source": [
"# 定义一些常量\n",
"threshold = 5\n",
"node_num = 20\n",
"negative_sample_num = None # None或者一个数字\n",
"world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
"# 对数据进行预处理\n",
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
"negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num"
], ],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
@ -28,9 +92,11 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 19,
"outputs": [], "outputs": [],
"source": [], "source": [
"model = AnonymousColorDetector()"
],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"pycharm": { "pycharm": {
@ -40,21 +106,25 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 20,
"outputs": [], "outputs": [
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{ {
"cell_type": "code", "ename": "KeyboardInterrupt",
"execution_count": null, "evalue": "",
"outputs": [], "output_type": "error",
"source": [], "traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
"Input \u001B[1;32mIn [20]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnegative_sample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_num\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.7\u001B[39;49m\u001B[43m)\u001B[49m\n",
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:34\u001B[0m, in \u001B[0;36mAnonymousColorDetector.fit\u001B[1;34m(self, x, world_boundary, threshold, model_selected, negative_sample_size, train_size, **kwargs)\u001B[0m\n\u001B[0;32m 32\u001B[0m node_num \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mnode_num\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;241m10\u001B[39m)\n\u001B[0;32m 33\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m ELM(input_size\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m1\u001B[39m], node_num\u001B[38;5;241m=\u001B[39mnode_num, output_num\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 34\u001B[0m negative_samples \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgenerate_negative_samples\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 35\u001B[0m \u001B[43m \u001B[49m\u001B[43msample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 36\u001B[0m data_x, data_y \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mconcatenate([x, negative_samples], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m), \\\n\u001B[0;32m 37\u001B[0m np\u001B[38;5;241m.\u001B[39mconcatenate([np\u001B[38;5;241m.\u001B[39mones(x\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m),\n\u001B[0;32m 38\u001B[0m np\u001B[38;5;241m.\u001B[39mzeros(negative_samples\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m)], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m)\n\u001B[0;32m 39\u001B[0m x_train, x_val, y_train, y_val \u001B[38;5;241m=\u001B[39m train_test_split(data_x, data_y, train_size\u001B[38;5;241m=\u001B[39mtrain_size, shuffle\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n",
"File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:66\u001B[0m, in \u001B[0;36mAnonymousColorDetector.generate_negative_samples\u001B[1;34m(x, world_boundary, threshold, sample_size)\u001B[0m\n\u001B[0;32m 64\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m sample_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(generated_data\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]):\n\u001B[0;32m 65\u001B[0m sample \u001B[38;5;241m=\u001B[39m generated_data[sample_idx, :]\n\u001B[1;32m---> 66\u001B[0m in_threshold \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39many(np\u001B[38;5;241m.\u001B[39msum(\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpower\u001B[49m\u001B[43m(\u001B[49m\u001B[43msample\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m<\u001B[39m threshold)\n\u001B[0;32m 67\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m in_threshold:\n\u001B[0;32m 68\u001B[0m negative_samples[sample_idx, :] \u001B[38;5;241m=\u001B[39m sample\n",
"\u001B[1;31mKeyboardInterrupt\u001B[0m: "
]
}
],
"source": [
"model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
],
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"pycharm": { "pycharm": {

View File

@ -1,7 +1,7 @@
# -*- codeing = utf-8 -*- # -*- codeing = utf-8 -*-
# Time : 2022/7/18 14:03 # Time : 2022/7/18 14:03
# @Auther : zhouchao # @Auther : zhouchao
# @File: model.py # @File: models.py
# @Software:PyCharm、 # @Software:PyCharm、
import numpy as np import numpy as np
from sklearn.metrics import classification_report from sklearn.metrics import classification_report