{ "cells": [
  { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [
    "# 模型的训练\n",
    "\n",
    "Train an `AnonymousColorDetector`, then save it with `model.save()`."
  ] },
  { "cell_type": "code", "execution_count": 6, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [
    "import numpy as np\n",
    "import scipy.io  # explicit submodule import: `import scipy` alone does not guarantee `scipy.io` is loaded\n",
    "from imblearn.under_sampling import RandomUnderSampler\n",
    "from models import AnonymousColorDetector\n",
    "from utils import read_labeled_img"
  ] },
  { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 读取数据与构建数据集" ] },
  { "cell_type": "code", "execution_count": 7, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [
    "data_dir = \"data/dataset\"\n",
    "# Annotation colour -> class name (tuple channel order presumably RGB; confirm against read_labeled_img),\n",
    "# and class name -> numeric label.\n",
    "color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'}\n",
    "label_index = {\"yangeng\": 1, \"beijing\": 0}\n",
    "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
    "\n",
    "# Balance the classes by randomly under-sampling the majority class (seeded for reproducibility).\n",
    "rus = RandomUnderSampler(random_state=0)\n",
    "x_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist()\n",
    "y_list = np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
    "x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
    "\n",
    "# NOTE(review): this keeps the samples of BOTH classes under the single key \"inside\" and discards\n",
    "# y_resampled; confirm that is intended for the `model.fit(x, world_boundary, ...)` path.\n",
    "dataset = {\"inside\": np.array(x_resampled)}"
  ] },
  { "cell_type": "markdown", "metadata": { "pycharm": { "name": "#%% md\n" } }, "source": [ "## 模型训练" ] },
  { "cell_type": "code", "execution_count": 8, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [
    "# Constants for the alternative `model.fit(x, world_boundary, ...)` call (commented out in the training cell).\n",
    "threshold = 5\n",
    "node_num = 20  # not referenced elsewhere in this notebook\n",
    "negative_sample_num = None  # None, or an explicit number of negative samples\n",
    "world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
    "# Preprocess: stack every entry of `dataset` into one array.\n",
    "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
    "# Default: 1.2x as many negative samples as there are rows in x.\n",
    "negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n"
  ] },
  { "cell_type": "code", "execution_count": 9, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [], "source": [ "model = AnonymousColorDetector()" ]
}, { "cell_type": "code", "execution_count": 10, "metadata": { "pycharm": { "name": "#%%\n" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0.0 0.99 0.99 0.99 26314\n", " 1.0 0.99 0.99 0.99 24492\n", "\n", " accuracy 0.99 50806\n", " macro avg 0.99 0.99 0.99 50806\n", "weighted avg 0.99 0.99 0.99 50806\n", "\n" ] } ], "source": [
    "# Alternative path: train on `x` from the cells above, with negative samples presumably generated\n",
    "# inside `world_boundary` by the model itself:\n",
    "# model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
    "#           is_save_dataset=True, model_selection='dt')\n",
    "\n",
    "# Current path: train on a saved labelled dataset (.mat with keys 'x' and 'y'),\n",
    "# presumably written by an earlier `is_save_dataset=True` run.\n",
    "mat_path = 'data/dataset/dataset_2022-07-19_17-06.mat'\n",
    "mat_data = scipy.io.loadmat(mat_path)\n",
    "# Separate names so the earlier `x` is not silently overwritten; ravel() because loadmat returns arrays as >= 2-D.\n",
    "x_train, y_train = mat_data['x'], mat_data['y'].ravel()\n",
    "model.fit(x_train, y=y_train, is_generate_negative=False, model_selection='dt')\n",
    "model.save()"
  ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.0" } }, "nbformat": 4, "nbformat_minor": 1 }