mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 22:33:54 +00:00
717 lines
19 KiB
Plaintext
717 lines
19 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 训练像素模型\n",
|
||
"用这个文件可以训练出需要使用的光谱像素点模型"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 1,
|
||
"outputs": [
|
||
{
|
||
"ename": "NotImplementedError",
|
||
"evalue": "开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行",
|
||
"output_type": "error",
|
||
"traceback": [
|
||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||
"\u001B[1;31mNotImplementedError\u001B[0m Traceback (most recent call last)",
|
||
"Input \u001B[1;32mIn [1]\u001B[0m, in \u001B[0;36m<cell line: 3>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[0;32m 2\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mpickle\u001B[39;00m\n\u001B[1;32m----> 3\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mutils\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m read_envi_ascii\n\u001B[0;32m 4\u001B[0m \u001B[38;5;66;03m# from config import Config\u001B[39;00m\n\u001B[0;32m 5\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmodels\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m ManualTree\n",
|
||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\utils\\__init__.py:20\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 17\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mmatplotlib\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m pyplot \u001B[38;5;28;01mas\u001B[39;00m plt\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mre\u001B[39;00m\n\u001B[1;32m---> 20\u001B[0m \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01mconfig\u001B[39;00m \u001B[38;5;28;01mimport\u001B[39;00m Config\n\u001B[0;32m 23\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mnatural_sort\u001B[39m(l):\n\u001B[0;32m 24\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 25\u001B[0m \u001B[38;5;124;03m 自然排序\u001B[39;00m\n\u001B[0;32m 26\u001B[0m \u001B[38;5;124;03m :param l: 待排序\u001B[39;00m\n\u001B[0;32m 27\u001B[0m \u001B[38;5;124;03m :return:\u001B[39;00m\n\u001B[0;32m 28\u001B[0m \u001B[38;5;124;03m \"\"\"\u001B[39;00m\n",
|
||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:6\u001B[0m, in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m 1\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mos\u001B[39;00m\n\u001B[0;32m 3\u001B[0m \u001B[38;5;28;01mimport\u001B[39;00m \u001B[38;5;21;01mnumpy\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m \u001B[38;5;21;01mnp\u001B[39;00m\n\u001B[1;32m----> 6\u001B[0m \u001B[38;5;28;01mclass\u001B[39;00m \u001B[38;5;21;01mConfig\u001B[39;00m:\n\u001B[0;32m 7\u001B[0m \u001B[38;5;66;03m# 文件相关参数\u001B[39;00m\n\u001B[0;32m 8\u001B[0m nRows, nCols, nBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m256\u001B[39m, \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m22\u001B[39m\n\u001B[0;32m 9\u001B[0m nRgbRows, nRgbCols, nRgbBands \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m1024\u001B[39m, \u001B[38;5;241m4096\u001B[39m, \u001B[38;5;241m3\u001B[39m\n",
|
||
"File \u001B[1;32m~\\OneDrive - macrosolid\\PycharmProjects\\tobacco_color\\config.py:31\u001B[0m, in \u001B[0;36mConfig\u001B[1;34m()\u001B[0m\n\u001B[0;32m 28\u001B[0m spec_size_threshold \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m3\u001B[39m\n\u001B[0;32m 30\u001B[0m \u001B[38;5;66;03m# rgb模型参数\u001B[39;00m\n\u001B[1;32m---> 31\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mNotImplementedError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124m开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[0;32m 32\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"weights/tobacco_dt_2022-08-27_14-43.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 33\u001B[0m \u001B[38;5;66;03m# rgb_tobacco_model_path = r\"/home/dt/tobacco-color/weights/tobacco_dt_2022-08-27_14-43.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 34\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"weights/background_dt_2022-08-22_22-15.model\" # 开发时的路径\u001B[39;00m\n\u001B[0;32m 35\u001B[0m \u001B[38;5;66;03m# rgb_background_model_path = r\"/home/dt/tobacco-color/weights/background_dt_2022-08-22_22-15.model\" # 机器上部署的路径\u001B[39;00m\n\u001B[0;32m 36\u001B[0m threshold_low, threshold_high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m10\u001B[39m, \u001B[38;5;241m230\u001B[39m\n",
|
||
"\u001B[1;31mNotImplementedError\u001B[0m: 开发时和机器上部署的路径不同,请注意选择rgb_tobacco_model_path、rgb_background_model_path、ai_path后删除本行"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import pickle\n",
|
||
"from utils import read_envi_ascii\n",
|
||
"from config import Config\n",
|
||
"from models import ManualTree"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 一些变量"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 4,
|
||
"outputs": [],
|
||
"source": [
|
||
"data_path = r'data/envi20220802.txt'\n",
|
||
"name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 构建数据集"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 5,
|
||
"outputs": [],
|
||
"source": [
|
||
"data = read_envi_ascii(data_path)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 6,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"zibian (569, 448)\n",
|
||
"tobacco (1457, 448)\n",
|
||
"yantou (354, 448)\n",
|
||
"kazhi (449, 448)\n",
|
||
"bomo (1154, 448)\n",
|
||
"jiaodai (566, 448)\n",
|
||
"background (1235, 448)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"_ = [print(class_name, d.shape) for class_name, d in data.items()]"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"outputs": [],
|
||
"source": [
|
||
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||
"data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||
"data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 取出需要的22个特征谱段"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"这些是现在的数据: (5215, 448) (5215,)\n",
|
||
"截取其中需要的部分后: (5215, 22) (5215,)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n",
|
||
"data_x_cut = data_x[..., Config.bands]\n",
|
||
"print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 进行样本平衡"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"这是重采样后的数据: (8742, 22) (8742,)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from imblearn.over_sampling import RandomOverSampler\n",
|
||
"ros = RandomOverSampler(random_state=0)\n",
|
||
"x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n",
|
||
"print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 进行模型训练\n",
|
||
"分出一部分数据进行训练"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"outputs": [],
|
||
"source": [
|
||
"from models import DecisionTree\n",
|
||
"from sklearn.model_selection import train_test_split"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"outputs": [],
|
||
"source": [
|
||
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2, random_state=0)  # fixed seed so the split and the metrics below are reproducible"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"outputs": [],
|
||
"source": [
|
||
"tree = DecisionTree(class_weight={1:20})"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"outputs": [],
|
||
"source": [
|
||
"tree = tree.fit(train_x, train_y)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 模型评估"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 多分类精度"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"outputs": [],
|
||
"source": [
|
||
"pred_y = tree.predict(test_x)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 15,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0.0 1.00 1.00 1.00 312\n",
|
||
" 1.0 0.98 0.97 0.98 289\n",
|
||
" 2.0 0.97 1.00 0.99 312\n",
|
||
" 3.0 0.95 0.99 0.97 278\n",
|
||
" 4.0 0.98 0.92 0.95 288\n",
|
||
" 5.0 0.98 0.98 0.98 270\n",
|
||
"\n",
|
||
" accuracy 0.98 1749\n",
|
||
" macro avg 0.98 0.98 0.98 1749\n",
|
||
"weighted avg 0.98 0.98 0.98 1749\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import classification_report\n",
|
||
"\n",
|
||
"print(classification_report(y_pred=pred_y, y_true=test_y))"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 二分类精度"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 16,
|
||
"outputs": [],
|
||
"source": [
|
||
"test_y[test_y <= 1] = 0\n",
|
||
"test_y[test_y > 1] = 1\n",
|
||
"pred_y = tree.predict_bin(test_x)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 17,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0.0 1.00 1.00 1.00 598\n",
|
||
" 1.0 1.00 1.00 1.00 1151\n",
|
||
"\n",
|
||
" accuracy 1.00 1749\n",
|
||
" macro avg 1.00 1.00 1.00 1749\n",
|
||
"weighted avg 1.00 1.00 1.00 1749\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(classification_report(y_true=test_y, y_pred=pred_y))  # score against the binarized ground-truth labels, not the predictions themselves"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 模型保存"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"outputs": [],
|
||
"source": [
|
||
"import datetime\n",
|
||
"\n",
|
||
"path = datetime.datetime.now().strftime(\"models/pixel_%Y-%m-%d_%H-%M.model\")  # timestamped output path; plain literal, nothing to interpolate\n",
|
||
"with open(path, 'wb') as f:\n",
|
||
" pickle.dump(tree, f)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"source": [],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"outputs": [],
|
||
"source": [
|
||
"from models import DecisionTree\n",
|
||
"from sklearn.model_selection import train_test_split"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"outputs": [],
|
||
"source": [
|
||
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2, random_state=0)  # fixed seed so the split and the metrics below are reproducible"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"outputs": [],
|
||
"source": [
|
||
"tree = DecisionTree(class_weight={1:20})"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"outputs": [],
|
||
"source": [
|
||
"tree = tree.fit(train_x, train_y)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 模型评估"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 多分类精度"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 13,
|
||
"outputs": [],
|
||
"source": [
|
||
"pred_y = tree.predict(test_x)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0.0 1.00 1.00 1.00 304\n",
|
||
" 1.0 0.97 0.98 0.97 275\n",
|
||
" 2.0 0.98 1.00 0.99 297\n",
|
||
" 3.0 0.95 0.99 0.97 293\n",
|
||
" 4.0 0.98 0.91 0.95 316\n",
|
||
" 5.0 0.98 0.98 0.98 264\n",
|
||
"\n",
|
||
" accuracy 0.98 1749\n",
|
||
" macro avg 0.98 0.98 0.98 1749\n",
|
||
"weighted avg 0.98 0.98 0.98 1749\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"from sklearn.metrics import classification_report\n",
|
||
"\n",
|
||
"print(classification_report(y_pred=pred_y, y_true=test_y))"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 二分类精度"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 19,
|
||
"outputs": [],
|
||
"source": [
|
||
"test_y[test_y <= 1] = 0\n",
|
||
"test_y[test_y > 1] = 1\n",
|
||
"pred_y = tree.predict_bin(test_x)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 20,
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0.0 1.00 1.00 1.00 582\n",
|
||
" 1.0 1.00 1.00 1.00 1167\n",
|
||
"\n",
|
||
" accuracy 1.00 1749\n",
|
||
" macro avg 1.00 1.00 1.00 1749\n",
|
||
"weighted avg 1.00 1.00 1.00 1749\n",
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"print(classification_report(y_true=test_y, y_pred=pred_y))  # score against the binarized ground-truth labels, not the predictions themselves"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"# 模型保存"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 22,
|
||
"outputs": [],
|
||
"source": [
|
||
"import datetime\n",
|
||
"\n",
|
||
"path = datetime.datetime.now().strftime(\"models/pixel_%Y-%m-%d_%H-%M.model\")  # timestamped output path; plain literal, nothing to interpolate\n",
|
||
"with open(path, 'wb') as f:\n",
|
||
" pickle.dump(tree, f)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": null,
|
||
"outputs": [],
|
||
"source": [],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 2
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython2",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 0
|
||
} |