Merge remote-tracking branch 'origin/master'

This commit is contained in:
li.zhenye 2022-08-03 10:30:04 +08:00
commit 660711f6e6
4 changed files with 252 additions and 12 deletions

View File

@ -15,8 +15,17 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 3,
"outputs": [], "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [ "source": [
"import numpy as np\n", "import numpy as np\n",
"import pickle\n", "import pickle\n",
@ -45,7 +54,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 4,
"outputs": [], "outputs": [],
"source": [ "source": [
"data_path = r'data/envi20220802.txt'\n", "data_path = r'data/envi20220802.txt'\n",
@ -72,7 +81,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 5,
"outputs": [], "outputs": [],
"source": [ "source": [
"data = read_envi_ascii(data_path)" "data = read_envi_ascii(data_path)"
@ -86,7 +95,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 6,
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -114,7 +123,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 7,
"outputs": [], "outputs": [],
"source": [ "source": [
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n", "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
@ -142,7 +151,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 8,
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -179,7 +188,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 9,
"outputs": [ "outputs": [
{ {
"name": "stdout", "name": "stdout",
@ -215,6 +224,237 @@
} }
} }
}, },
{
"cell_type": "code",
"execution_count": 10,
"outputs": [],
"source": [
"from models import DecisionTree\n",
"from sklearn.model_selection import train_test_split"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 11,
"outputs": [],
"source": [
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 12,
"outputs": [],
"source": [
"tree = DecisionTree(class_weight={1:20})"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 13,
"outputs": [],
"source": [
"tree = tree.fit(train_x, train_y)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型评估"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 多分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 14,
"outputs": [],
"source": [
"pred_y = tree.predict(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 15,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 1.00 1.00 1.00 312\n",
" 1.0 0.98 0.97 0.98 289\n",
" 2.0 0.97 1.00 0.99 312\n",
" 3.0 0.95 0.99 0.97 278\n",
" 4.0 0.98 0.92 0.95 288\n",
" 5.0 0.98 0.98 0.98 270\n",
"\n",
" accuracy 0.98 1749\n",
" macro avg 0.98 0.98 0.98 1749\n",
"weighted avg 0.98 0.98 0.98 1749\n",
"\n"
]
}
],
"source": [
"from sklearn.metrics import classification_report\n",
"\n",
"print(classification_report(y_pred=pred_y, y_true=test_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 二分类精度"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 16,
"outputs": [],
"source": [
"test_y[test_y <= 1] = 0\n",
"test_y[test_y > 1] = 1\n",
"pred_y = tree.predict_bin(test_x)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0.0 1.00 1.00 1.00 598\n",
" 1.0 1.00 1.00 1.00 1151\n",
"\n",
" accuracy 1.00 1749\n",
" macro avg 1.00 1.00 1.00 1749\n",
"weighted avg 1.00 1.00 1.00 1749\n",
"\n"
]
}
],
"source": [
"print(classification_report(y_true=pred_y, y_pred=pred_y))"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"# 模型保存"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 18,
"outputs": [],
"source": [
"import datetime\n",
"\n",
"path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
"with open(path, 'wb') as f:\n",
" pickle.dump(tree, f)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 8,

View File

@ -19,7 +19,7 @@ class Config:
# 光谱模型参数 # 光谱模型参数
blk_size = 4 blk_size = 4
pixel_model_path = r"./models/dt.p" pixel_model_path = r"./models/pixel_2022-08-02_15-22.model"
blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model" blk_model_path = r"./models/rf_4x4_c22_20_sen8_9.model"
spec_size_threshold = 3 spec_size_threshold = 3

View File

@ -168,5 +168,5 @@ if __name__ == '__main__':
parser = argparse.ArgumentParser(description='Run image test or ') parser = argparse.ArgumentParser(description='Run image test or ')
tester = TestMain() tester = TestMain()
tester.pony_run(test_path=r'/home/lzy/2022.7.30/tobacco_v1_0/saved_img/', tester.pony_run(test_path=r'/home/lzy/2022.7.30/tobacco_v1_0/saved_img/',
test_rgb=False, test_spectra=False, get_delta=False) test_rgb=True, test_spectra=True, get_delta=False)

View File

@ -260,7 +260,7 @@ class PixelModelML:
self.dt = pickle.load(f) self.dt = pickle.load(f)
def predict(self, feature): def predict(self, feature):
pixel_result_array = self.dt.predict(feature) pixel_result_array = self.dt.predict_bin(feature)
return pixel_result_array return pixel_result_array
@ -409,7 +409,7 @@ class SpecDetector(Detector):
if x_yellow.shape[0] == 0: if x_yellow.shape[0] == 0:
return non_yellow_things return non_yellow_things
else: else:
tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5 tobacco = self.pixel_model_ml.predict(x_yellow) < 0.5
non_yellow_things[yellow_things] = ~tobacco non_yellow_things[yellow_things] = ~tobacco
# 杂质mask中将背景赋值为0,将杂质赋值为1 # 杂质mask中将背景赋值为0,将杂质赋值为1