mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
完成新的dt模型
This commit is contained in:
parent
e4cedc2516
commit
0e297c892c
@ -15,8 +15,17 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 1,
|
"execution_count": 3,
|
||||||
"outputs": [],
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stderr",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||||
|
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
"source": [
|
"source": [
|
||||||
"import numpy as np\n",
|
"import numpy as np\n",
|
||||||
"import pickle\n",
|
"import pickle\n",
|
||||||
@ -45,7 +54,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 2,
|
"execution_count": 4,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"data_path = r'data/envi20220802.txt'\n",
|
"data_path = r'data/envi20220802.txt'\n",
|
||||||
@ -72,7 +81,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 3,
|
"execution_count": 5,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"data = read_envi_ascii(data_path)"
|
"data = read_envi_ascii(data_path)"
|
||||||
@ -86,7 +95,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 4,
|
"execution_count": 6,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
@ -114,7 +123,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 5,
|
"execution_count": 7,
|
||||||
"outputs": [],
|
"outputs": [],
|
||||||
"source": [
|
"source": [
|
||||||
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||||||
@ -142,7 +151,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 6,
|
"execution_count": 8,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
@ -179,7 +188,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 7,
|
"execution_count": 9,
|
||||||
"outputs": [
|
"outputs": [
|
||||||
{
|
{
|
||||||
"name": "stdout",
|
"name": "stdout",
|
||||||
@ -215,6 +224,237 @@
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 10,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"from models import DecisionTree\n",
|
||||||
|
"from sklearn.model_selection import train_test_split"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 11,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 12,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tree = DecisionTree(class_weight={1:20})"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 13,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"tree = tree.fit(train_x, train_y)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# 模型评估"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 多分类精度"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 14,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"pred_y = tree.predict(test_x)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" precision recall f1-score support\n",
|
||||||
|
"\n",
|
||||||
|
" 0.0 1.00 1.00 1.00 312\n",
|
||||||
|
" 1.0 0.98 0.97 0.98 289\n",
|
||||||
|
" 2.0 0.97 1.00 0.99 312\n",
|
||||||
|
" 3.0 0.95 0.99 0.97 278\n",
|
||||||
|
" 4.0 0.98 0.92 0.95 288\n",
|
||||||
|
" 5.0 0.98 0.98 0.98 270\n",
|
||||||
|
"\n",
|
||||||
|
" accuracy 0.98 1749\n",
|
||||||
|
" macro avg 0.98 0.98 0.98 1749\n",
|
||||||
|
"weighted avg 0.98 0.98 0.98 1749\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from sklearn.metrics import classification_report\n",
|
||||||
|
"\n",
|
||||||
|
"print(classification_report(y_pred=pred_y, y_true=test_y))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"## 二分类精度"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"test_y[test_y <= 1] = 0\n",
|
||||||
|
"test_y[test_y > 1] = 1\n",
|
||||||
|
"pred_y = tree.predict_bin(test_x)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 17,
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
" precision recall f1-score support\n",
|
||||||
|
"\n",
|
||||||
|
" 0.0 1.00 1.00 1.00 598\n",
|
||||||
|
" 1.0 1.00 1.00 1.00 1151\n",
|
||||||
|
"\n",
|
||||||
|
" accuracy 1.00 1749\n",
|
||||||
|
" macro avg 1.00 1.00 1.00 1749\n",
|
||||||
|
"weighted avg 1.00 1.00 1.00 1749\n",
|
||||||
|
"\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"print(classification_report(y_true=pred_y, y_pred=pred_y))"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"source": [
|
||||||
|
"# 模型保存"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%% md\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 18,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import datetime\n",
|
||||||
|
"\n",
|
||||||
|
"path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
|
||||||
|
"with open(path, 'wb') as f:\n",
|
||||||
|
" pickle.dump(tree, f)"
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"outputs": [],
|
||||||
|
"source": [],
|
||||||
|
"metadata": {
|
||||||
|
"collapsed": false,
|
||||||
|
"pycharm": {
|
||||||
|
"name": "#%%\n"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 8,
|
"execution_count": 8,
|
||||||
|
|||||||
@ -19,7 +19,7 @@ class Config:
|
|||||||
|
|
||||||
# 光谱模型参数
|
# 光谱模型参数
|
||||||
blk_size = 4
|
blk_size = 4
|
||||||
pixel_model_path = r"./models/dt.p"
|
pixel_model_path = r"./models/pixel_2022-08-02_15-22.model"
|
||||||
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
|
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
|
||||||
spec_size_threshold = 3
|
spec_size_threshold = 3
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user