From 0e297c892cefbaa90ee0190cada3385e4fe35e2c Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Tue, 2 Aug 2022 15:23:13 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=96=B0=E7=9A=84dt?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 05_model_training.ipynb | 256 ++++++++++++++++++++++++++++++++++++++-- config.py | 2 +- 2 files changed, 249 insertions(+), 9 deletions(-) diff --git a/05_model_training.ipynb b/05_model_training.ipynb index 57a56e9..d610ed1 100644 --- a/05_model_training.ipynb +++ b/05_model_training.ipynb @@ -15,8 +15,17 @@ }, { "cell_type": "code", - "execution_count": 1, - "outputs": [], + "execution_count": 3, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + } + ], "source": [ "import numpy as np\n", "import pickle\n", @@ -45,7 +54,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 4, "outputs": [], "source": [ "data_path = r'data/envi20220802.txt'\n", @@ -72,7 +81,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "outputs": [], "source": [ "data = read_envi_ascii(data_path)" @@ -86,7 +95,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "outputs": [ { "name": "stdout", @@ -114,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "outputs": [], "source": [ "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n", @@ -142,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "outputs": [ { "name": "stdout", @@ -179,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "outputs": [ { "name": "stdout", @@ -215,6 +224,237 @@ } } }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "from models import DecisionTree\n", + "from sklearn.model_selection import train_test_split" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [], + "source": [ + "tree = DecisionTree(class_weight={1:20})" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "tree = tree.fit(train_x, train_y)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 模型评估" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 多分类精度" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [], + "source": [ + "pred_y = tree.predict(test_x)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 15, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0.0 1.00 1.00 1.00 312\n", + " 1.0 0.98 0.97 0.98 289\n", + " 2.0 0.97 1.00 0.99 312\n", + " 3.0 0.95 0.99 0.97 278\n", + " 4.0 0.98 0.92 0.95 288\n", + " 5.0 0.98 0.98 0.98 270\n", + "\n", + " accuracy 0.98 1749\n", + " macro avg 0.98 0.98 0.98 1749\n", + "weighted avg 0.98 0.98 0.98 1749\n", + "\n" + ] + } + ], + "source": [ + "from sklearn.metrics import classification_report\n", + "\n", + "print(classification_report(y_pred=pred_y, y_true=test_y))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 二分类精度" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 16, + "outputs": [], + "source": [ + "test_y[test_y <= 1] = 0\n", + "test_y[test_y > 1] = 1\n", + "pred_y = tree.predict_bin(test_x)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0.0 1.00 1.00 1.00 598\n", + " 1.0 1.00 1.00 1.00 1151\n", + "\n", + " accuracy 1.00 1749\n", + " macro avg 1.00 1.00 1.00 1749\n", + "weighted avg 1.00 1.00 1.00 1749\n", + "\n" + ] + } + ], + "source": [ + "print(classification_report(y_true=pred_y, y_pred=pred_y))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 模型保存" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n", + "with open(path, 'wb') as f:\n", + " pickle.dump(tree, f)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, { "cell_type": "code", "execution_count": 8, diff --git a/config.py b/config.py index 8965715..1bd2e56 100644 --- a/config.py +++ b/config.py @@ -19,7 +19,7 @@ class Config: # 光谱模型参数 blk_size = 4 - pixel_model_path = r"./models/dt.p" + pixel_model_path = r"./models/pixel_2022-08-02_15-22.model" blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model" spec_size_threshold = 3 From fe5eececc2229e892120832bddbe01956dc7a161 Mon Sep 17 00:00:00 2001 From: "li.zhenye" <李> Date: Tue, 2 Aug 2022 15:41:47 +0800 Subject: [PATCH 2/2] perfect version2 --- main_test.py | 2 +- models.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/main_test.py b/main_test.py index fc3a084..c8e895a 100644 --- a/main_test.py +++ b/main_test.py @@ -168,5 +168,5 @@ if __name__ == '__main__': parser = argparse.ArgumentParser(description='Run image test or ') tester = TestMain() tester.pony_run(test_path=r'/home/lzy/2022.7.30/tobacco_v1_0/saved_img/', - test_rgb=False, test_spectra=False, get_delta=False) + test_rgb=True, test_spectra=True, get_delta=False) diff --git a/models.py b/models.py index f0acf56..9af462f 100755 --- a/models.py +++ b/models.py @@ -260,7 +260,7 @@ class PixelModelML: self.dt = pickle.load(f) def predict(self, feature): - pixel_result_array = self.dt.predict(feature) + pixel_result_array = self.dt.predict_bin(feature) return pixel_result_array @@ -409,7 +409,7 @@ class SpecDetector(Detector): if x_yellow.shape[0] == 0: return non_yellow_things else: - tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5 + tobacco = self.pixel_model_ml.predict(x_yellow) < 0.5 non_yellow_things[yellow_things] = ~tobacco # 杂质mask中将背景赋值为0,将杂质赋值为1