{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 模型的训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "from models import AnonymousColorDetector\n",
    "from utils import read_labeled_img"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 读取数据与构建数据集"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Directory holding the labelled images, and the colour -> class-name mapping\n",
    "# handed to read_labeled_img.\n",
    "data_dir = \"data/dataset\"\n",
    "color_dict = {(0, 0, 255): \"yangeng\"}\n",
    "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Constants\n",
    "threshold = 5\n",
    "# NOTE(review): node_num is not passed to model.fit() below, so changing it\n",
    "# currently has no effect on training - TODO confirm the intended wiring.\n",
    "node_num = 20\n",
    "negative_sample_num = None  # None or a number\n",
    "world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
    "seed = 0  # used to seed NumPy's global RNG right before fitting\n",
    "# Preprocess the data: stack the arrays of every entry in `dataset` along axis 0\n",
    "x = np.concatenate(list(dataset.values()), axis=0)\n",
    "# None means: use as many negative samples as there are rows in x\n",
    "if negative_sample_num is None:\n",
    "    negative_sample_num = x.shape[0]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "model = AnonymousColorDetector()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG so that re-running this cell gives comparable results\n",
    "# (assumes the sampling/splitting inside fit() draws from np.random - TODO confirm).\n",
    "np.random.seed(seed)\n",
    "# NOTE(review): the previously saved run of this cell ended in a manual\n",
    "# KeyboardInterrupt while negative samples were being generated; if it is too\n",
    "# slow, try a smaller negative_sample_num first.\n",
    "model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}