supermachine-tobacco/05_model_training.ipynb
2022-11-15 12:46:20 +08:00

717 lines
19 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"source": [
"# 训练像素模型\n",
"用这个文件可以训练出需要使用的光谱像素点模型"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"ename": "NotImplementedError",
"evalue": "开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)",
"Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
"\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
]
}
],
"source": [
"import numpy as np\n",
"import pickle\n",
"from utils import read_envi_ascii\n",
"from config import Config\n",
"from models import ManualTree"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 一些变量"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"data_path = r'data/envi20220802.txt'\n",
"name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 构建数据集"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"data = read_envi_ascii(data_path)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"zibian (569, 448)\n",
"tobacco (1457, 448)\n",
"yantou (354, 448)\n",
"kazhi (449, 448)\n",
"bomo (1154, 448)\n",
"jiaodai (566, 448)\n",
"background (1235, 448)\n"
]
}
],
"source": [
"# Print each class name together with the shape of its sample array.\n",
"# A plain loop is used so no throwaway list of None values is built.\n",
"for class_name, d in data.items():\n",
"    print(class_name, d.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 7,
"outputs": [],
"source": [
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
"data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n",
"data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 取出需要的22个特征谱段"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"这些是现在的数据: (5215, 448) (5215,)\n",
"截取其中需要的部分后: (5215, 22) (5215,)\n"
]
}
],
"source": [
"print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n",
"data_x_cut = data_x[..., Config.bands]\n",
"print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 进行样本平衡"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"这是重采样后的数据: (8742, 22) (8742,)\n"
]
}
],
"source": [
"from imblearn.over_sampling import RandomOverSampler\n",
"ros = RandomOverSampler(random_state=0)\n",
"x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n",
"print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 进行模型训练\n",
"分出一部分数据进行训练"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"from models import DecisionTree\n",
"from sklearn.model_selection import train_test_split"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"# Hold out 20% of the resampled data for evaluation. The seed is fixed (same value\n",
"# as the RandomOverSampler) so the split and the reported metrics are reproducible.\n",
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2, random_state=0)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"tree = DecisionTree(class_weight={1:20})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"tree = tree.fit(train_x, train_y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型评估"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 多分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [],
"source": [
"pred_y = tree.predict(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 1.00 1.00 1.00 312\n",
" 1.0 0.98 0.97 0.98 289\n",
" 2.0 0.97 1.00 0.99 312\n",
" 3.0 0.95 0.99 0.97 278\n",
" 4.0 0.98 0.92 0.95 288\n",
" 5.0 0.98 0.98 0.98 270\n",
"\n",
" accuracy 0.98 1749\n",
" macro avg 0.98 0.98 0.98 1749\n",
"weighted avg 0.98 0.98 0.98 1749\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(y_pred=pred_y, y_true=test_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 二分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [],
"source": [
"test_y[test_y <= 1] = 0\n",
"test_y[test_y > 1] = 1\n",
"pred_y = tree.predict_bin(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Score the binary predictions against the binarised ground truth (test_y was\n",
"# collapsed to 0/1 in the previous cell). Passing pred_y for both arguments\n",
"# compares the predictions with themselves and always reports a perfect score.\n",
"print(classification_report(y_true=test_y, y_pred=pred_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型保存"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [],
"source": [
"import datetime\n",
"import os\n",
"\n",
"# The file name carries a minute-resolution timestamp of the moment the cell runs.\n",
"path = datetime.datetime.now().strftime(\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
"# Create the target directory if it is missing so open() does not fail on a fresh checkout.\n",
"os.makedirs(os.path.dirname(path), exist_ok=True)\n",
"with open(path, 'wb') as f:\n",
"    pickle.dump(tree, f)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 8,
"outputs": [],
"source": [
"from models import DecisionTree\n",
"from sklearn.model_selection import train_test_split"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 9,
"outputs": [],
"source": [
"# Hold out 20% of the resampled data for evaluation. The seed is fixed (same value\n",
"# as the RandomOverSampler) so the split and the reported metrics are reproducible.\n",
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2, random_state=0)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"tree = DecisionTree(class_weight={1:20})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"tree = tree.fit(train_x, train_y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型评估"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 多分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"pred_y = tree.predict(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 1.00 1.00 1.00 304\n",
" 1.0 0.97 0.98 0.97 275\n",
" 2.0 0.98 1.00 0.99 297\n",
" 3.0 0.95 0.99 0.97 293\n",
" 4.0 0.98 0.91 0.95 316\n",
" 5.0 0.98 0.98 0.98 264\n",
"\n",
" accuracy 0.98 1749\n",
" macro avg 0.98 0.98 0.98 1749\n",
"weighted avg 0.98 0.98 0.98 1749\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(y_pred=pred_y, y_true=test_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 二分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 19,
"outputs": [],
"source": [
"test_y[test_y <= 1] = 0\n",
"test_y[test_y > 1] = 1\n",
"pred_y = tree.predict_bin(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"# Score the binary predictions against the binarised ground truth (test_y was\n",
"# collapsed to 0/1 in the previous cell). Passing pred_y for both arguments\n",
"# compares the predictions with themselves and always reports a perfect score.\n",
"print(classification_report(y_true=test_y, y_pred=pred_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型保存"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [],
"source": [
"import datetime\n",
"import os\n",
"\n",
"# The file name carries a minute-resolution timestamp of the moment the cell runs.\n",
"path = datetime.datetime.now().strftime(\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
"# Create the target directory if it is missing so open() does not fail on a fresh checkout.\n",
"os.makedirs(os.path.dirname(path), exist_ok=True)\n",
"with open(path, 'wb') as f:\n",
"    pickle.dump(tree, f)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}