mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
写了一半的dt
This commit is contained in:
parent
39d3aa4b59
commit
8caffa55bc
@ -23,10 +23,6 @@
|
||||
"import scipy.io\n",
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"import pickle\n",
|
||||
"from sklearn.tree import DecisionTreeClassifier\n",
|
||||
"# %matplotlib notebook\n",
|
||||
"from main_test import pony_run\n",
|
||||
"from models import AnonymousColorDetector\n",
|
||||
"from utils import lab_scatter"
|
||||
],
|
||||
|
||||
239
05_model_training.ipynb
Normal file
239
05_model_training.ipynb
Normal file
@ -0,0 +1,239 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 训练像素模型\n",
|
||||
"用这个文件可以训练出需要使用的光谱像素点模型"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"import pickle\n",
|
||||
"from utils import read_envi_ascii\n",
|
||||
"from config import Config\n",
|
||||
"from models import ManualTree"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 一些变量"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_path = r'data/envi20220802.txt'\n",
|
||||
"name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 构建数据集"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = read_envi_ascii(data_path)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"zibian (569, 448)\n",
|
||||
"tobacco (1457, 448)\n",
|
||||
"yantou (354, 448)\n",
|
||||
"kazhi (449, 448)\n",
|
||||
"bomo (1154, 448)\n",
|
||||
"jiaodai (566, 448)\n",
|
||||
"background (1235, 448)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"_ = [print(class_name, d.shape) for class_name, d in data.items()]"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||||
"data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||||
"data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 取出需要的22个特征谱段"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"这些是现在的数据: (5215, 448) (5215,)\n",
|
||||
"截取其中需要的部分后: (5215, 22) (5215,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n",
|
||||
"data_x_cut = data_x[..., Config.bands]\n",
|
||||
"print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 进行样本平衡"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"这是重采样后的数据: (8742, 22) (8742,)\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from imblearn.over_sampling import RandomOverSampler\n",
|
||||
"ros = RandomOverSampler(random_state=0)\n",
|
||||
"x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n",
|
||||
"print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 2
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython2",
|
||||
"version": "2.7.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
55
utils.py
55
utils.py
@ -140,6 +140,61 @@ def size_threshold(img, blk_size, threshold):
|
||||
return mask
|
||||
|
||||
|
||||
def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None):
|
||||
"""
|
||||
Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII...
|
||||
|
||||
:param file_name: file name of ENVI ascii file
|
||||
:param hdr_file_name: hdr file name for a "BANDS" vector in the output
|
||||
:param save_xy: save the x, y position on the first two cols of the result vector
|
||||
:return: dict {class_name: vector, ...}
|
||||
"""
|
||||
number_line_start_with = "; Number of ROIs: "
|
||||
roi_name_start_with, roi_npts_start_with = "; ROI name: ", "; ROI npts: "
|
||||
data_start_with = "; ID"
|
||||
class_num, class_names, class_nums, vectors = 0, [], [], []
|
||||
with open(file_name, 'r') as f:
|
||||
for line_text in f:
|
||||
if line_text.startswith(number_line_start_with):
|
||||
class_num = int(line_text[len(number_line_start_with):])
|
||||
elif line_text.startswith(roi_name_start_with):
|
||||
class_names.append(line_text[len(roi_name_start_with):-1])
|
||||
elif line_text.startswith(roi_npts_start_with):
|
||||
class_nums.append(int(line_text[len(roi_name_start_with):-1]))
|
||||
elif line_text.startswith(data_start_with):
|
||||
col_list = list(filter(None, line_text[1:].split(" ")))
|
||||
assert (len(class_names) == class_num) and (len(class_names) == len(class_nums))
|
||||
break
|
||||
elif line_text.startswith(";"):
|
||||
continue
|
||||
for vector_rows in class_nums:
|
||||
vector_str = ''
|
||||
for i in range(vector_rows):
|
||||
vector_str += f.readline()
|
||||
vector = np.fromstring(vector_str, dtype=np.float, sep=" ").reshape(-1, len(col_list))
|
||||
assert vector.shape[0] == vector_rows
|
||||
vector = vector[:, 3:] if not save_xy else vector[:, 1:]
|
||||
vectors.append(vector)
|
||||
f.readline() # suppose to read a blank line
|
||||
if hdr_file_name is not None:
|
||||
bands = []
|
||||
with open(hdr_file_name, 'r') as f:
|
||||
start_bands = False
|
||||
for line_text in f:
|
||||
if start_bands:
|
||||
if line_text.endswith(",\n"):
|
||||
bands.append(float(line_text[:-2]))
|
||||
else:
|
||||
bands.append(float(line_text))
|
||||
break
|
||||
elif line_text.startswith("wavelength ="):
|
||||
start_bands = True
|
||||
bands = np.array(bands, dtype=np.float)
|
||||
vectors.append(bands)
|
||||
class_names.append("BANDS")
|
||||
return dict(zip(class_names, vectors))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
|
||||
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user