mirror of https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 22:33:54 +00:00
Add the purple-edge (zibian) class; separate the models
This commit is contained in:
parent 62440b5bb9
commit 6444e923e6
@@ -2,21 +2,17 @@
  "cells": [
   {
    "cell_type": "markdown",
+   "source": [],
    "metadata": {
+    "collapsed": false,
     "pycharm": {
      "name": "#%% md\n"
     }
-   },
-   "source": []
+   }
   },
   {
    "cell_type": "code",
-   "execution_count": 5,
-   "metadata": {
-    "pycharm": {
-     "name": "#%%\n"
-    }
-   },
+   "execution_count": null,
    "outputs": [],
    "source": [
     "import numpy as np\n",
@@ -24,7 +20,13 @@
     "from imblearn.under_sampling import RandomUnderSampler\n",
     "from models import AnonymousColorDetector\n",
     "from utils import read_labeled_img"
-   ]
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
   },
   {
    "cell_type": "code",
@@ -35,8 +37,11 @@
     "data_dir = \"data/dataset\"  # dataset folder; it must contain `img` and `label` subfolders holding identically named images and labels\n",
     "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
     "\n",
-    "color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'}  # class for each color\n",
-    "label_index = {\"yangeng\": 1, \"beijing\": 0}  # index for each class\n",
+    "# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'}  # class for each color\n",
+    "color_dict = {(0, 0, 255): \"yangeng\"}\n",
+    "color_dict = {(255, 0, 0): 'beijing'}\n",
+    "color_dict = {(0, 255, 0): \"zibian\"}\n",
+    "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\": 2}  # index for each class\n",
     "show_samples = False  # whether to display the samples\n",
     "\n",
     "# define some training parameters\n",
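A note on the three consecutive `color_dict` assignments added above: in Python each assignment rebinds the name, so only the last mapping, `{(0, 255, 0): "zibian"}`, is live in a given run; choosing which class to train means editing which line comes last (or uncommenting the combined dict). That matches the commit message's "separate the models": each run trains a single-class detector. The class names are pinyin, likely tobacco stem (yangeng), background (beijing), and the new purple-edge class (zibian). A minimal sketch of iterating all three mappings instead; the loop and the `train_one_class` helper are illustrative, not part of this commit:

def train_one_class(color_dict, class_name):
    # hypothetical stand-in for the notebook's load/balance/fit/save cells
    print(f"training a separate model for {class_name} using {color_dict}")

color_dicts = [
    {(0, 0, 255): "yangeng"},  # blue mask pixels  -> tobacco-stem class
    {(255, 0, 0): "beijing"},  # red mask pixels   -> background class
    {(0, 255, 0): "zibian"},   # green mask pixels -> purple-edge class
]
for color_dict in color_dicts:
    train_one_class(color_dict, next(iter(color_dict.values())))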
@@ -57,26 +62,16 @@
     "## Load the data"
    ],
    "metadata": {
-    "collapsed": false
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
    }
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
-   "outputs": [
-    {
-     "ename": "FileNotFoundError",
-     "evalue": "[Errno 2] No such file or directory: 'data/dataset/label'",
-     "output_type": "error",
-     "traceback": [
-      "---------------------------------------------------------------------------",
-      "FileNotFoundError                         Traceback (most recent call last)",
-      "/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/1942905945.py in <module>\n----> 1 dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n      2 if show_samples:\n      3     from utils import lab_scatter\n      4     lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)\n",
-      "~/PycharmProjects/tobacco_color/utils.py in read_labeled_img(dataset_dir, color_dict, ext, is_ps_color_space)\n     37     :return: dataset as a dict {label: vector(n x 3)}, vectors in Lab color space\n     38     \"\"\"\n---> 39     img_names = [img_name for img_name in os.listdir(os.path.join(dataset_dir, 'label'))\n     40                  if img_name.endswith(ext)]\n     41     total_dataset = MergeDict()\n",
-      "FileNotFoundError: [Errno 2] No such file or directory: 'data/dataset/label'"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "outputs": [],
    "source": [
     "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
     "if show_samples:\n",
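The output removed above is informative: `read_labeled_img` begins with `os.listdir(os.path.join(dataset_dir, 'label'))`, so the cell fails unless `data/dataset` actually contains the `label` (and `img`) subfolders described in the config comment. A small sketch of the expected layout with a pre-flight check; the check itself is illustrative, not code from the repo:

# Expected layout, per the data_dir comment and the removed traceback:
#   data/dataset/
#       img/    001.png ...   # source images
#       label/  001.png ...   # color-coded masks with the same file names
import os

data_dir = "data/dataset"
for sub in ("img", "label"):
    path = os.path.join(data_dir, sub)
    if not os.path.isdir(path):
        raise FileNotFoundError(f"expected subfolder missing: {path}")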
@@ -96,25 +91,16 @@
     "## Balance the data"
    ],
    "metadata": {
-    "collapsed": false
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
    }
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'dataset' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "---------------------------------------------------------------------------",
-      "NameError                                 Traceback (most recent call last)",
-      "/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/603974095.py in <module>\n      1 rus = RandomUnderSampler(random_state=0)\n----> 2 x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n      3     np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n      4 x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n      5 dataset = {\"inside\": np.array(x_resampled)}\n",
-      "NameError: name 'dataset' is not defined"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "outputs": [],
    "source": [
     "rus = RandomUnderSampler(random_state=0)\n",
     "x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
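For reference, the balancing cell builds one `(n, 3)` array of labelled pixels plus a parallel label vector, then lets `RandomUnderSampler` cut every class down to the size of the smallest one. A self-contained sketch with toy arrays standing in for the labelled-pixel dataset:

import numpy as np
from imblearn.under_sampling import RandomUnderSampler

label_index = {"yangeng": 1, "beijing": 0, "zibian": 2}
dataset = {
    "yangeng": np.random.rand(100, 3),  # stand-ins for (n, 3) Lab pixel arrays
    "beijing": np.random.rand(400, 3),  # deliberately the majority class
    "zibian": np.random.rand(60, 3),
}
x = np.concatenate(list(dataset.values()), axis=0)
y = np.concatenate([np.full(v.shape[0], label_index[k]) for k, v in dataset.items()])
x_res, y_res = RandomUnderSampler(random_state=0).fit_resample(x, y)
# the default strategy down-samples every class to the smallest class size (60 here)
print({int(c): int((y_res == c).sum()) for c in np.unique(y_res)})  # {0: 60, 1: 60, 2: 60}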
@@ -135,25 +121,16 @@
     "## Train the model"
    ],
    "metadata": {
-    "collapsed": false
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
    }
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
-   "outputs": [
-    {
-     "ename": "NameError",
-     "evalue": "name 'dataset' is not defined",
-     "output_type": "error",
-     "traceback": [
-      "---------------------------------------------------------------------------",
-      "NameError                                 Traceback (most recent call last)",
-      "/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/1828483636.py in <module>\n      1 # preprocess the data\n----> 2 x = np.concatenate([v for k, v in dataset.items()], axis=0)\n      3 negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n      4 model = AnonymousColorDetector()\n",
-      "NameError: name 'dataset' is not defined"
-     ]
-    }
-   ],
+   "execution_count": null,
+   "outputs": [],
    "source": [
     "# preprocess the data\n",
     "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
@@ -167,6 +144,28 @@
     }
    }
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "if train_from_existed:\n",
+    "    data = scipy.io.loadmat(dataset_file)\n",
+    "    x, y = data['x'], data['y'].ravel()\n",
+    "    model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
+    "else:\n",
+    "    world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
+    "    model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
+    "              is_save_dataset=True, model_selection='dt')\n",
+    "model.save()"
+   ],
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   }
+  },
   {
    "cell_type": "code",
    "execution_count": 14,
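The new cell branches on `train_from_existed`: either reload a previously saved `.mat` dataset and fit directly, or generate negative samples inside `world_boundary` first (with the default, `negative_sample_num` becomes `int(x.shape[0] * 1.2)`, e.g. 100 positives yield 120 negatives). On the reload path, `scipy.io.loadmat` returns a dict of 2-D arrays following the MATLAB convention, which is why `y` needs `.ravel()`. A self-contained round-trip sketch; the file name and toy data are illustrative, while the 'x'/'y' keys match what the cell reads:

import numpy as np
import scipy.io  # explicit submodule import; plain `import scipy` may not expose scipy.io on older versions

scipy.io.savemat("demo.mat", {"x": np.random.rand(5, 3), "y": np.arange(5)})
data = scipy.io.loadmat("demo.mat")
x, y = data["x"], data["y"].ravel()  # loadmat yields y with shape (1, 5); ravel() flattens it to (5,)
print(x.shape, y.shape)  # (5, 3) (5,)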
@@ -1,53 +1,66 @@
-#!/usr/bin/env python
-# coding: utf-8
-
-# # Training the model
-
-# In[16]:
-
-
+# %%
 import numpy as np
 import scipy
 from imblearn.under_sampling import RandomUnderSampler
 from models import AnonymousColorDetector
 from utils import read_labeled_img
 
-# ## Load data and build the dataset
-
-# In[17]:
-
-data_dir = "data/dataset"
-color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'}
-label_index = {"yangeng": 1, "beijing": 0}
+# %%
+train_from_existed = True  # whether to train from existing data: if True, train from dataset_file, otherwise use the data under data_dir
+data_dir = "data/dataset"  # dataset folder; it must contain `img` and `label` subfolders holding identically named images and labels
+dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat"
+
+# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'}  # class for each color
+color_dict = {(0, 0, 255): "yangeng"}
+color_dict = {(255, 0, 0): 'beijing'}
+color_dict = {(0, 255, 0): "zibian"}
+label_index = {"yangeng": 1, "beijing": 0, "zibian": 2}  # index for each class
+show_samples = False  # whether to display the samples
 
+# define some training parameters
+threshold = 5  # how far around a positive sample still counts as positive
+node_num = 20  # number of nodes if an ELM is used as the classifier
+negative_sample_num = None  # None or a number: how many negative samples to generate
+
+# %% md
+## Load the data
+# %%
 dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)
+if show_samples:
+    from utils import lab_scatter
+
+    lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)
+
+# %% md
+## Balance the data
+# %%
 rus = RandomUnderSampler(random_state=0)
 x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \
     np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()
 
 x_resampled, y_resampled = rus.fit_resample(x_list, y_list)
 dataset = {"inside": np.array(x_resampled)}
 
-# ## Train the model
-
-# In[18]:
-
-
-# define some constants
-threshold = 5
-node_num = 20
-negative_sample_num = None  # None or a number
-world_boundary = np.array([0, 0, 0, 255, 255, 255])
+# %% md
+## Train the model
+# %%
 # preprocess the data
 x = np.concatenate([v for k, v in dataset.items()], axis=0)
 negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num
 
 model = AnonymousColorDetector()
 
-model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
-          is_save_dataset=True, model_selection='dt')
-# data = scipy.io.loadmat('dataset_2022-07-19_15-07.mat')
-# x, y = data['x'], data['y'].ravel()
-# model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
+# %%
+if train_from_existed:
+    data = scipy.io.loadmat(dataset_file)
+    x, y = data['x'], data['y'].ravel()
+    model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
+else:
+    world_boundary = np.array([0, 0, 0, 255, 255, 255])
+    model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
+              is_save_dataset=True, model_selection='dt')
+model.save()
+
+# %%
+if train_from_existed:
+    data = scipy.io.loadmat(dataset_file)
+    x, y = data['x'], data['y'].ravel()
+    model.fit(x, y=y, is_generate_negative=False, model_selection='dt')
+else:
+    world_boundary = np.array([0, 0, 0, 255, 255, 255])
+    model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,
+              is_save_dataset=True, model_selection='dt')
 model.save()
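On the generate-negatives path, `world_boundary = np.array([0, 0, 0, 255, 255, 255])` packs the min and max corners of the 8-bit color cube. `AnonymousColorDetector.fit` itself is not part of this diff, so the following is only a sketch of the idea those parameters suggest: draw uniform samples in the cube and keep the ones farther than `threshold` from every positive sample. It is not the repo's implementation.

import numpy as np

def sample_negatives(x_pos, world_boundary, threshold, n):
    """Illustrative negative-sample generator, not the repo's implementation."""
    lo, hi = world_boundary[:3], world_boundary[3:]
    kept = []
    while len(kept) < n:
        cand = np.random.uniform(lo, hi, size=(n, 3))
        # distance from each candidate to its nearest positive sample
        d = np.linalg.norm(cand[:, None, :] - x_pos[None, :, :], axis=-1).min(axis=1)
        kept.extend(cand[d > threshold])
    return np.asarray(kept[:n])

x_pos = np.random.rand(100, 3) * 255  # stand-in for the positive pixel vectors
negatives = sample_negatives(x_pos, np.array([0, 0, 0, 255, 255, 255]), threshold=5, n=120)
print(negatives.shape)  # (120, 3)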