supermachine-tobacco/05_model_training.ipynb

471 lines
9.2 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 训练像素模型\n",
"用这个文件可以训练出需要使用的光谱像素点模型"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [],
"source": [
"import numpy as np\n",
"import pickle\n",
"from utils import read_envi_ascii\n",
"from config import Config\n",
"from models import ManualTree"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 一些变量"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# Input file handed to read_envi_ascii below.\n",
"data_path = r'data/envi20220802.txt'\n",
"\n",
"# Class name -> integer label. A class that appears in the data file but is\n",
"# not listed here is left out when the dataset is assembled.\n",
"name_dict = {\n",
"    'tobacco': 1,\n",
"    'yantou': 2,\n",
"    'kazhi': 3,\n",
"    'bomo': 4,\n",
"    'jiaodai': 5,\n",
"    'background': 0,\n",
"}"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 构建数据集"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Load the labelled spectra. The result is used below as a mapping of\n",
"# class name -> array (it is iterated with .items() and each value has .shape).\n",
"data = read_envi_ascii(data_path)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"zibian (569, 448)\n",
"tobacco (1457, 448)\n",
"yantou (354, 448)\n",
"kazhi (449, 448)\n",
"bomo (1154, 448)\n",
"jiaodai (566, 448)\n",
"background (1235, 448)\n"
]
}
],
"source": [
"# Show how many samples (rows) and bands (columns) each class provides.\n",
"# A plain loop is used because the prints are side effects, not list items.\n",
"for class_name, d in data.items():\n",
"    print(class_name, d.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"# Keep only the classes listed in name_dict and pair every sample row with\n",
"# the label of its class (stored as float64).\n",
"features, labels = [], []\n",
"for class_name, d in data.items():\n",
"    if class_name not in name_dict:\n",
"        continue\n",
"    features.append(d)\n",
"    labels.append(np.full((d.shape[0],), name_dict[class_name], dtype=np.float64))\n",
"data_x = np.concatenate(features)\n",
"data_y = np.concatenate(labels)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 取出需要的22个特征谱段"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"这些是现在的数据: (5215, 448) (5215,)\n",
"截取其中需要的部分后: (5215, 22) (5215,)\n"
]
}
],
"source": [
"# Report the shapes before the band selection.\n",
"print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n",
"# Keep only the bands listed in Config.bands (indexing along the last axis).\n",
"data_x_cut = data_x[..., Config.bands]\n",
"# Report the shapes after the selection; the labels are not touched.\n",
"print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 进行样本平衡"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"这是重采样后的数据: (8742, 22) (8742,)\n"
]
}
],
"source": [
"from imblearn.over_sampling import RandomOverSampler\n",
"# Balance the classes by random over-sampling; the seed is fixed so the\n",
"# resampled set is the same on every run for the same input.\n",
"ros = RandomOverSampler(random_state=0)\n",
"x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n",
"# Report the shapes of the balanced dataset.\n",
"print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 进行模型训练\n",
"分出一部分数据进行训练"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"from models import DecisionTree\n",
"from sklearn.model_selection import train_test_split"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# Hold out 20% of the samples for evaluation. The split is seeded so that the\n",
"# metrics below and the saved model can be reproduced on a re-run.\n",
"# NOTE(review): the split happens after over-sampling, so copies of one sample\n",
"# can land in both the training and the test set, which may inflate the\n",
"# scores below. Consider splitting before resampling.\n",
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2, random_state=0)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# Label 1 is 'tobacco' in name_dict; it is given a weight of 20 here.\n",
"# NOTE(review): presumably the other classes keep a default weight; confirm in models.DecisionTree.\n",
"tree = DecisionTree(class_weight={1:20})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"# Fit on the training split. The return value of fit() is re-bound to `tree`\n",
"# and used as the model in the cells below.\n",
"tree = tree.fit(train_x, train_y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型评估"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 多分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"# Multi-class predictions for the held-out test split.\n",
"pred_y = tree.predict(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 1.00 1.00 1.00 304\n",
" 1.0 0.97 0.98 0.97 275\n",
" 2.0 0.98 1.00 0.99 297\n",
" 3.0 0.95 0.99 0.97 293\n",
" 4.0 0.98 0.91 0.95 316\n",
" 5.0 0.98 0.98 0.98 264\n",
"\n",
" accuracy 0.98 1749\n",
" macro avg 0.98 0.98 0.98 1749\n",
"weighted avg 0.98 0.98 0.98 1749\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"# Per-class precision / recall / F1 of the multi-class predictions on the test split.\n",
"multiclass_report = classification_report(y_true=test_y, y_pred=pred_y)\n",
"print(multiclass_report)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 二分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Collapse the labels into two groups: background (0) and tobacco (1) become 0,\n",
"# every other class becomes 1. New names are used instead of overwriting\n",
"# test_y / pred_y, so this cell gives the same result when it is re-run and\n",
"# the multi-class cells above stay valid.\n",
"test_y_bin = (test_y > 1).astype(test_y.dtype)\n",
"pred_y_bin = tree.predict_bin(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Score the binary predictions against the ground-truth labels. y_true must\n",
"# not be the predictions themselves, or every metric is trivially 1.00.\n",
"print(classification_report(y_true=test_y_bin, y_pred=pred_y_bin))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型保存"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"import datetime\n",
"import os\n",
"\n",
"# Timestamped file name (minute resolution), so runs started in different\n",
"# minutes do not overwrite each other.\n",
"path = datetime.datetime.now().strftime(\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
"# Create the target directory if it is missing; open() would fail otherwise.\n",
"os.makedirs(os.path.dirname(path), exist_ok=True)\n",
"with open(path, 'wb') as f:\n",
"    pickle.dump(tree, f)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}