{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# Train an `AnonymousColorDetector`\n",
    "\n",
    "Reads labelled pixels from `data_dir`, balances the classes by random under-sampling, then fits and saves the detector. ",
    "When `train_from_existed` is `True`, the detector is fitted on the samples stored in `dataset_file` rather than on the freshly read pixels."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy.io  # import the submodule explicitly: `import scipy` alone is not guaranteed to load `scipy.io`\n",
    "from imblearn.under_sampling import RandomUnderSampler\n",
    "\n",
    "from models import AnonymousColorDetector\n",
    "from utils import lab_scatter, read_labeled_img"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# True: fit on the samples stored in dataset_file; False: fit on the pixels read from data_dir.\n",
    "train_from_existed = True\n",
    "# Dataset folder; it must contain `img` and `label` sub-folders holding images and labels with matching file names.\n",
    "data_dir = \"data/dataset\"\n",
    "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
    "\n",
    "# Label colour -> class name for every annotated class.\n",
    "all_color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): \"beijing\", (0, 255, 0): \"zibian\"}\n",
    "# Classes to read for this run; add more names from all_color_dict to load several classes at once.\n",
    "selected_classes = [\"zibian\"]\n",
    "color_dict = {color: name for color, name in all_color_dict.items() if name in selected_classes}\n",
    "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\": 2}  # class name -> numeric label\n",
    "show_samples = False  # whether to plot the samples after loading\n",
    "\n",
    "# Training parameters\n",
    "threshold = 5  # how far around a positive sample a point still counts as positive\n",
    "node_num = 20  # number of nodes when ELM is used as the classifier\n",
    "negative_sample_num = None  # None, or the number of negative samples to generate"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 读取数据"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# dataset maps each class name to an array of its samples (the cells below rely on `.shape[0]`).\n",
    "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
    "if show_samples:\n",
    "    lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据平衡化"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Stack every class into one sample matrix with a parallel vector of numeric labels.\n",
    "x_all = np.concatenate([samples for samples in dataset.values()], axis=0)\n",
    "y_all = np.concatenate(\n",
    "    [np.ones((samples.shape[0],)) * label_index[name] for name, samples in dataset.items()], axis=0\n",
    ")\n",
    "\n",
    "if np.unique(y_all).size > 1:\n",
    "    rus = RandomUnderSampler(random_state=0)\n",
    "    x_resampled, y_resampled = rus.fit_resample(x_all.tolist(), y_all.tolist())\n",
    "else:\n",
    "    # Only one class was read, so there is nothing to balance; imblearn samplers\n",
    "    # reject a target that contains fewer than two classes.\n",
    "    x_resampled, y_resampled = x_all.tolist(), y_all.tolist()\n",
    "\n",
    "# Class labels are dropped from here on: all balanced samples are pooled under one key.\n",
    "# A new name is used so that re-running this cell still sees the per-class `dataset`.\n",
    "dataset_balanced = {\"inside\": np.array(x_resampled)}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "# Preprocess the data\n",
    "x = np.concatenate([samples for samples in dataset_balanced.values()], axis=0)\n",
    "# Default to 1.2x the number of rows in x when no explicit count is configured.\n",
    "# Stored under a separate name so the configured value is not overwritten on re-run.\n",
    "negative_sample_size = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n",
    "model = AnonymousColorDetector()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "if train_from_existed:\n",
    "    # Fit on a previously saved sample set; the file must provide 'x' and 'y' entries.\n",
    "    saved = scipy.io.loadmat(dataset_file)\n",
    "    x_saved, y_saved = saved['x'], saved['y'].ravel()\n",
    "    model.fit(x_saved, y=y_saved, is_generate_negative=False, model_selection='dt')\n",
    "else:\n",
    "    # NOTE(review): presumably per-channel minima followed by maxima -- confirm against AnonymousColorDetector.fit.\n",
    "    world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
    "    model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_size, train_size=0.7,\n",
    "              is_save_dataset=True, model_selection='dt')\n",
    "model.save()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}