diff --git a/04_multi_classification.ipynb b/04_multi_classification.ipynb index 1ce3ea4..6762671 100644 --- a/04_multi_classification.ipynb +++ b/04_multi_classification.ipynb @@ -23,10 +23,6 @@ "import scipy.io\n", "import cv2\n", "import numpy as np\n", - "import pickle\n", - "from sklearn.tree import DecisionTreeClassifier\n", - "# %matplotlib notebook\n", - "from main_test import pony_run\n", "from models import AnonymousColorDetector\n", "from utils import lab_scatter" ], diff --git a/05_model_training.ipynb b/05_model_training.ipynb new file mode 100644 index 0000000..01aa777 --- /dev/null +++ b/05_model_training.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# 训练像素模型\n", + "用这个文件可以训练出需要使用的光谱像素点模型" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pickle\n", + "from utils import read_envi_ascii\n", + "from config import Config\n", + "from models import ManualTree" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 一些变量" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 23, + "outputs": [], + "source": [ + "data_path = r'data/envi20220802.txt'\n", + "name_dict = {'tobacco': 1, 'yantou':2, 'kazhi':3, 'bomo':4, 'jiaodai':5, 'background':0}" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 构建数据集" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 24, + "outputs": [], + "source": [ + "data = read_envi_ascii(data_path)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 25, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "zibian (569, 448)\n", + "tobacco (1457, 448)\n", + "yantou (354, 448)\n", + "kazhi (449, 448)\n", + "bomo (1154, 448)\n", + "jiaodai (566, 448)\n", + "background (1235, 448)\n" + ] + } + ], + "source": [ + "_ = [print(class_name, d.shape) for class_name, d in data.items()]" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 26, + "outputs": [], + "source": [ + "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n", + "data_y = [np.ones((d.shape[0], ))*name_dict[class_name] for class_name, d in data.items() if class_name in name_dict.keys()]\n", + "data_x, data_y = np.concatenate(data_x), np.concatenate(data_y)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 取出需要的22个特征谱段" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 27, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "这些是现在的数据: (5215, 448) (5215,)\n", + "截取其中需要的部分后: (5215, 22) (5215,)\n" + ] + } + ], + "source": [ + "print(\"这些是现在的数据: \", data_x.shape, data_y.shape)\n", + "data_x_cut = data_x[..., Config.bands]\n", + "print(\"截取其中需要的部分后: \", data_x_cut.shape, data_y.shape)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 进行样本平衡" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 30, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "这是重采样后的数据: (8742, 22) (8742,)\n" + ] + } + ], + "source": [ + "from imblearn.over_sampling import RandomOverSampler\n", + "ros = RandomOverSampler(random_state=0)\n", + "x_resampled, y_resampled = ros.fit_resample(data_x_cut, data_y)\n", + "print('这是重采样后的数据: ', x_resampled.shape, y_resampled.shape)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "#" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/utils.py b/utils.py index 9359bc4..298d869 100755 --- a/utils.py +++ b/utils.py @@ -140,6 +140,61 @@ def size_threshold(img, blk_size, threshold): return mask +def read_envi_ascii(file_name, save_xy=False, hdr_file_name=None): + """ + Read envi ascii file. Use ENVI ROI Tool -> File -> output ROIs to ASCII... + + :param file_name: file name of ENVI ascii file + :param hdr_file_name: hdr file name for a "BANDS" vector in the output + :param save_xy: save the x, y position on the first two cols of the result vector + :return: dict {class_name: vector, ...} + """ + number_line_start_with = "; Number of ROIs: " + roi_name_start_with, roi_npts_start_with = "; ROI name: ", "; ROI npts: " + data_start_with = "; ID" + class_num, class_names, class_nums, vectors = 0, [], [], [] + with open(file_name, 'r') as f: + for line_text in f: + if line_text.startswith(number_line_start_with): + class_num = int(line_text[len(number_line_start_with):]) + elif line_text.startswith(roi_name_start_with): + class_names.append(line_text[len(roi_name_start_with):-1]) + elif line_text.startswith(roi_npts_start_with): + class_nums.append(int(line_text[len(roi_name_start_with):-1])) + elif line_text.startswith(data_start_with): + col_list = list(filter(None, line_text[1:].split(" "))) + assert (len(class_names) == class_num) and (len(class_names) == len(class_nums)) + break + elif line_text.startswith(";"): + continue + for vector_rows in class_nums: + vector_str = '' + for i in range(vector_rows): + vector_str += f.readline() + vector = np.fromstring(vector_str, dtype=np.float, sep=" ").reshape(-1, len(col_list)) + assert vector.shape[0] == vector_rows + vector = vector[:, 3:] if not save_xy else vector[:, 1:] + vectors.append(vector) + f.readline() # suppose to read a blank line + if hdr_file_name is not None: + bands = [] + with open(hdr_file_name, 'r') as f: + start_bands = False + for line_text in f: + if start_bands: + if line_text.endswith(",\n"): + bands.append(float(line_text[:-2])) + else: + bands.append(float(line_text)) + break + elif line_text.startswith("wavelength ="): + start_bands = True + bands = np.array(bands, dtype=np.float) + vectors.append(bands) + class_names.append("BANDS") + return dict(zip(class_names, vectors)) + + if __name__ == '__main__': color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian", (255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}