{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "# 训练像素模型\n",
    "用本笔记本训练光谱像素点分类模型（`DecisionTree`），评估其多分类与二分类精度，并将模型保存到 `models/` 目录。"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import pickle\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "from utils import read_envi_ascii\n",
    "from config import Config"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 一些变量"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "data_path = r'data/envi20220802.txt'\n",
    "# Class name -> integer training label. Classes present in the data file but not\n",
    "# listed here (e.g. 'zibian') are skipped when the dataset is built.\n",
    "name_dict = {'tobacco': 1, 'yantou': 2, 'kazhi': 3, 'bomo': 4, 'jiaodai': 5, 'background': 0}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "# 构建数据集"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# Dict keyed by class name; each value is a 2-D array (rows are presumably pixel\n",
    "# spectra - see the shapes printed in the next cell).\n",
    "data = read_envi_ascii(data_path)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "zibian (569, 448)\n",
      "tobacco (1457, 448)\n",
      "yantou (354, 448)\n",
      "kazhi (449, 448)\n",
      "bomo (1154, 448)\n",
      "jiaodai (566, 448)\n",
      "background (1235, 448)\n"
     ]
    }
   ],
   "source": [
    "for class_name, d in data.items():\n",
    "    print(class_name, d.shape)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
"outputs": [], "source": [ "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n", "data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n", "data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 取出需要的22个特征谱段" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "这些是现在的数据: (5215, 448) (5215,)\n", "截取其中需要的部分后: (5215, 22) (5215,)\n" ] } ], "source": [ "print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n", "data_x_cut = data_x[..., Config.bands]\n", "print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 进行样本平衡" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "这是重采样后的数据: (8742, 22) (8742,)\n" ] } ], "source": [ "from imblearn.over_sampling import RandomOverSampler\n", "ros = RandomOverSampler(random_state=0)\n", "x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n", "print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 进行模型训练\n", "分出一部分数据进行训练" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 10, "outputs": [], "source": [ "from models import DecisionTree\n", "from sklearn.model_selection import train_test_split" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 11, "outputs": [], "source": [ 
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 12, "outputs": [], "source": [ "tree = DecisionTree(class_weight={1:20})" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "tree = tree.fit(train_x, train_y)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 模型评估" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "## 多分类精度" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 14, "outputs": [], "source": [ "pred_y = tree.predict(test_x)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 15, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 312\n", " 1.0 0.98 0.97 0.98 289\n", " 2.0 0.97 1.00 0.99 312\n", " 3.0 0.95 0.99 0.97 278\n", " 4.0 0.98 0.92 0.95 288\n", " 5.0 0.98 0.98 0.98 270\n", "\n", " accuracy 0.98 1749\n", " macro avg 0.98 0.98 0.98 1749\n", "weighted avg 0.98 0.98 0.98 1749\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "\n", "print(classification_report(y_pred=pred_y, y_true=test_y))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 二分类精度" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 16, "outputs": [], "source": [ "test_y[test_y <= 1] = 0\n", "test_y[test_y > 1] = 1\n", "pred_y = tree.predict_bin(test_x)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { 
"cell_type": "code", "execution_count": 17, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 598\n", " 1.0 1.00 1.00 1.00 1151\n", "\n", " accuracy 1.00 1749\n", " macro avg 1.00 1.00 1.00 1749\n", "weighted avg 1.00 1.00 1.00 1749\n", "\n" ] } ], "source": [ "print(classification_report(y_true=pred_y, y_pred=pred_y))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 模型保存" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 18, "outputs": [], "source": [ "import datetime\n", "\n", "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n", "with open(path, 'wb') as f:\n", " pickle.dump(tree, f)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 8, "outputs": [], "source": [ "from models import DecisionTree\n", "from sklearn.model_selection import train_test_split" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 9, "outputs": [], "source": [ "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 11, "outputs": [], "source": [ "tree = DecisionTree(class_weight={1:20})" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 12, "outputs": [], "source": [ "tree = tree.fit(train_x, train_y)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 模型评估" ], "metadata": { "collapsed": false, 
"pycharm": { "name": "#%% md\n" } } }, { "cell_type": "markdown", "source": [ "## 多分类精度" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 13, "outputs": [], "source": [ "pred_y = tree.predict(test_x)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 14, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 304\n", " 1.0 0.97 0.98 0.97 275\n", " 2.0 0.98 1.00 0.99 297\n", " 3.0 0.95 0.99 0.97 293\n", " 4.0 0.98 0.91 0.95 316\n", " 5.0 0.98 0.98 0.98 264\n", "\n", " accuracy 0.98 1749\n", " macro avg 0.98 0.98 0.98 1749\n", "weighted avg 0.98 0.98 0.98 1749\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "\n", "print(classification_report(y_pred=pred_y, y_true=test_y))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 二分类精度" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 19, "outputs": [], "source": [ "test_y[test_y <= 1] = 0\n", "test_y[test_y > 1] = 1\n", "pred_y = tree.predict_bin(test_x)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 20, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0.0 1.00 1.00 1.00 582\n", " 1.0 1.00 1.00 1.00 1167\n", "\n", " accuracy 1.00 1749\n", " macro avg 1.00 1.00 1.00 1749\n", "weighted avg 1.00 1.00 1.00 1749\n", "\n" ] } ], "source": [ "print(classification_report(y_true=pred_y, y_pred=pred_y))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "# 模型保存" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", 
"execution_count": 22, "outputs": [], "source": [ "import datetime\n", "\n", "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n", "with open(path, 'wb') as f:\n", " pickle.dump(tree, f)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": null, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }