mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 06:13:53 +00:00
222 lines
5.6 KiB
Plaintext
222 lines
5.6 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 7,
|
||
"outputs": [],
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"import scipy\n",
|
||
"from imblearn.under_sampling import RandomUnderSampler\n",
|
||
"from models import AnonymousColorDetector\n",
|
||
"from utils import read_labeled_img"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 8,
|
||
"outputs": [],
|
||
"source": [
|
||
"train_from_existed = False # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据\n",
|
||
"data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label\n",
|
||
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
|
||
"\n",
|
||
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n",
|
||
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
|
||
"# color_dict = {(255, 0, 0): 'beijing'}\n",
|
||
"color_dict = {(0, 255, 0): \"zibian\"}\n",
|
||
"label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
|
||
"show_samples = False # 是否展示样本\n",
|
||
"\n",
|
||
"# 定义一些训练量\n",
|
||
"threshold = 5 # 正样本周围多大范围内的还算是正样本\n",
|
||
"node_num = 20 # 如果使用ELM作为分类器物,有多少的节点\n",
|
||
"negative_sample_num = None # None或者一个数字,对应生成的负样本数量"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 读取数据"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 9,
|
||
"outputs": [],
|
||
"source": [
|
||
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
|
||
"if show_samples:\n",
|
||
" from utils import lab_scatter\n",
|
||
" lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 数据平衡化"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"outputs": [],
|
||
"source": [
|
||
"if len(dataset) > 1:\n",
|
||
" rus = RandomUnderSampler(random_state=0)\n",
|
||
" x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
|
||
" np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
|
||
" x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
|
||
" dataset = {\"inside\": np.array(x_resampled)}"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "markdown",
|
||
"source": [
|
||
"## 模型训练"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%% md\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 11,
|
||
"outputs": [],
|
||
"source": [
|
||
"# 对数据进行预处理\n",
|
||
"x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
|
||
"negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n",
|
||
"model = AnonymousColorDetector()"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"1923it [00:00, 5114.42it/s] "
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
" precision recall f1-score support\n",
|
||
"\n",
|
||
" 0 0.99 0.99 0.99 577\n",
|
||
" 1 0.99 0.99 0.99 481\n",
|
||
"\n",
|
||
" accuracy 0.99 1058\n",
|
||
" macro avg 0.99 0.99 0.99 1058\n",
|
||
"weighted avg 0.99 0.99 0.99 1058\n",
|
||
"\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"if train_from_existed:\n",
|
||
" data = scipy.io.loadmat(dataset_file)\n",
|
||
" x, y = data['x'], data['y'].ravel()\n",
|
||
" model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
|
||
"else:\n",
|
||
" world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
|
||
" model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
|
||
" is_save_dataset=True, model_selection='dt')\n",
|
||
"model.save()"
|
||
],
|
||
"metadata": {
|
||
"collapsed": false,
|
||
"pycharm": {
|
||
"name": "#%%\n"
|
||
}
|
||
}
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3 (ipykernel)",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 3
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython3",
|
||
"version": "3.10.0"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 1
|
||
} |