# 训练像素模型
用这个文件可以训练出需要使用的光谱像素点模型

In [3]:
import numpy as np
import pickle
from utils import read_envi_ascii
from config import Config
from models import ManualTree

  from .autonotebook import tqdm as notebook_tqdm


# 一些变量

In [4]:
data_path = r'data/envi20220802.txt'
name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}

# 构建数据集

In [5]:
data = read_envi_ascii(data_path)

In [6]:
_ = [print(class_name, d.shape) for class_name, d in data.items()]

zibian (569, 448)
tobacco (1457, 448)
yantou (354, 448)
kazhi (449, 448)
bomo (1154, 448)
jiaodai (566, 448)
background (1235, 448)


In [7]:
data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]
data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]
data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)

## 取出需要的22个特征谱段

In [8]:
print("这些是现在的数据: ", data_x.shape, data_y.shape)
data_x_cut = data_x[..., Config.bands]
print("截取其中需要的部分后: ", data_x_cut.shape, data_y.shape)

这些是现在的数据:  (5215, 448) (5215,)
截取其中需要的部分后:  (5215, 22) (5215,)


## 进行样本平衡

In [9]:
from imblearn.over_sampling import RandomOverSampler
ros = RandomOverSampler(random_state=0)
x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)
print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)

这是重采样后的数据:  (8742, 22) (8742,)


# 进行模型训练
分出一部分数据进行训练

In [10]:
from models import DecisionTree
from sklearn.model_selection import train_test_split

In [11]:
train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)

In [12]:
tree = DecisionTree(class_weight={1:20})

In [13]:
tree = tree.fit(train_x, train_y)

# 模型评估

## 多分类精度

In [14]:
pred_y = tree.predict(test_x)

In [15]:
from sklearn.metrics import classification_report

print(classification_report(y_pred=pred_y, y_true=test_y))

              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       312
         1.0       0.98      0.97      0.98       289
         2.0       0.97      1.00      0.99       312
         3.0       0.95      0.99      0.97       278
         4.0       0.98      0.92      0.95       288
         5.0       0.98      0.98      0.98       270

    accuracy                           0.98      1749
   macro avg       0.98      0.98      0.98      1749
weighted avg       0.98      0.98      0.98      1749



## 二分类精度

In [16]:
test_y[test_y <= 1] = 0
test_y[test_y > 1] = 1
pred_y = tree.predict_bin(test_x)

In [17]:
print(classification_report(y_true=pred_y, y_pred=pred_y))

              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       598
         1.0       1.00      1.00      1.00      1151

    accuracy                           1.00      1749
   macro avg       1.00      1.00      1.00      1749
weighted avg       1.00      1.00      1.00      1749



# 模型保存

In [18]:
import datetime

path = datetime.datetime.now().strftime(f"models/pixel_%Y-%m-%d_%H-%M.model")
with open(path, 'wb') as f:
    pickle.dump(tree, f)

In [8]:
from models import DecisionTree
from sklearn.model_selection import train_test_split

In [9]:
train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)

In [11]:
tree = DecisionTree(class_weight={1:20})

In [12]:
tree = tree.fit(train_x, train_y)

# 模型评估

## 多分类精度

In [13]:
pred_y = tree.predict(test_x)

In [14]:
from sklearn.metrics import classification_report

print(classification_report(y_pred=pred_y, y_true=test_y))

              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       304
         1.0       0.97      0.98      0.97       275
         2.0       0.98      1.00      0.99       297
         3.0       0.95      0.99      0.97       293
         4.0       0.98      0.91      0.95       316
         5.0       0.98      0.98      0.98       264

    accuracy                           0.98      1749
   macro avg       0.98      0.98      0.98      1749
weighted avg       0.98      0.98      0.98      1749



## 二分类精度

In [19]:
test_y[test_y <= 1] = 0
test_y[test_y > 1] = 1
pred_y = tree.predict_bin(test_x)

In [20]:
print(classification_report(y_true=pred_y, y_pred=pred_y))

              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       582
         1.0       1.00      1.00      1.00      1167

    accuracy                           1.00      1749
   macro avg       1.00      1.00      1.00      1749
weighted avg       1.00      1.00      1.00      1749



# 模型保存

In [22]:
import datetime

path = datetime.datetime.now().strftime(f"models/pixel_%Y-%m-%d_%H-%M.model")
with open(path, 'wb') as f:
    pickle.dump(tree, f)