mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
分离预测
This commit is contained in:
parent
6444e923e6
commit
b1be6eeb49
@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 7,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
@ -30,16 +30,16 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_from_existed = True # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
||||
"train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
||||
"data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label\n",
|
||||
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
||||
"\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n",
|
||||
"color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||
"color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||||
"# color_dict = {(255, 0, 0): 'beijing'}\n",
|
||||
"color_dict = {(0, 255, 0): \"zibian\"}\n",
|
||||
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
|
||||
"show_samples = False # 是否展示样本\n",
|
||||
@ -70,7 +70,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
||||
@ -99,14 +99,15 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 10,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rus = RandomUnderSampler(random_state=0)\n",
|
||||
"x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
|
||||
"if len(dataset) > 1:\n",
|
||||
" rus = RandomUnderSampler(random_state=0)\n",
|
||||
" x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
|
||||
" np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
|
||||
"x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
|
||||
"dataset = {\"inside\": np.array(x_resampled)}"
|
||||
" x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
|
||||
" dataset = {\"inside\": np.array(x_resampled)}"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@ -129,7 +130,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# 对数据进行预处理\n",
|
||||
@ -146,39 +147,35 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"if train_from_existed:\n",
|
||||
" data = scipy.io.loadmat(dataset_file)\n",
|
||||
" x, y = data['x'], data['y'].ravel()\n",
|
||||
" model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
|
||||
"else:\n",
|
||||
" world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
|
||||
" model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
|
||||
" is_save_dataset=True, model_selection='dt')\n",
|
||||
"model.save()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"execution_count": 12,
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'train_from_existed' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)",
|
||||
"\u001B[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/103680055.py\u001B[0m in \u001B[0;36m<module>\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0;32mif\u001B[0m \u001B[0mtrain_from_existed\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0mdata\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mscipy\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mio\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mloadmat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdataset_file\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0mx\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mdata\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'x'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mdata\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m'y'\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mravel\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfit\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0my\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mis_generate_negative\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mFalse\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mmodel_selection\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m'dt'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mmodel\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0msave\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n",
|
||||
"\u001B[0;31mNameError\u001B[0m: name 'train_from_existed' is not defined"
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1923it [00:00, 5114.42it/s] "
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0 0.99 0.99 0.99 577\n",
|
||||
" 1 0.99 0.99 0.99 481\n",
|
||||
"\n",
|
||||
" accuracy 0.99 1058\n",
|
||||
" macro avg 0.99 0.99 0.99 1058\n",
|
||||
"weighted avg 0.99 0.99 0.99 1058\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
@ -12,8 +12,8 @@ dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat"
|
||||
|
||||
# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} # 颜色对应的类别
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
color_dict = {(255, 0, 0): 'beijing'}
|
||||
color_dict = {(0, 255, 0): "zibian"}
|
||||
# color_dict = {(255, 0, 0): 'beijing'}
|
||||
# color_dict = {(0, 255, 0): "zibian"}
|
||||
label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号
|
||||
show_samples = False # 是否展示样本
|
||||
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import datetime\n",
|
||||
@ -37,7 +37,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model_path = 'models/dt_2022-07-19_16-03.model'\n",
|
||||
@ -52,7 +52,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"execution_count": 5,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -65,7 +65,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<IPython.core.display.HTML object>",
|
||||
"text/html": "<div id='7cca4d64-83fc-4ab9-9e9c-2e378a71e01e'></div>"
|
||||
"text/html": "<div id='2ffc90a2-1e1c-4280-91f7-d085a8204883'></div>"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
@ -81,7 +81,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<IPython.core.display.HTML object>",
|
||||
"text/html": "<div id='b1888224-2f86-48e2-9725-4e1a56eeb4a3'></div>"
|
||||
"text/html": "<div id='a9e43e6c-8500-444b-9c77-886843c0cbeb'></div>"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
@ -102,10 +102,10 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"execution_count": 10,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"dataset_path = 'data/dataset/dataset_2022-07-19_16-03.mat'"
|
||||
"dataset_path = 'data/dataset_old/dataset_2022-07-19_16-03.mat'"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@ -116,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"new_dataset_path = datetime.datetime.now().strftime(\"data/dataset/dataset_%Y-%m-%d_%H-%M.mat\")"
|
||||
@ -130,7 +130,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 12,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = scipy.io.loadmat(dataset_path)\n",
|
||||
@ -145,7 +145,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"nimg = img[result == 0]"
|
||||
@ -159,7 +159,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
@ -172,7 +172,7 @@
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<IPython.core.display.HTML object>",
|
||||
"text/html": "<div id='8f2acac1-57e7-41a6-9cc3-3d514a75fcd7'></div>"
|
||||
"text/html": "<div id='cdb3654b-8248-4ed4-897c-b920b33dbd8c'></div>"
|
||||
},
|
||||
"metadata": {},
|
||||
"output_type": "display_data"
|
||||
@ -224,7 +224,7 @@
|
||||
"execution_count": 28,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"scipy.io.savemat(new_dataset_path, {'x': x, 'y': y})"
|
||||
"# scipy.io.savemat(new_dataset_path, {'x': x, 'y': y})"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
|
||||
193
04_multi_classification.ipynb
Normal file
193
04_multi_classification.ipynb
Normal file
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user