diff --git a/02_classification.ipynb b/02_classification.ipynb index 9055ffc..d4728e9 100644 --- a/02_classification.ipynb +++ b/02_classification.ipynb @@ -2,21 +2,17 @@ "cells": [ { "cell_type": "markdown", + "source": [], "metadata": { + "collapsed": false, "pycharm": { "name": "#%% md\n" } - }, - "source": [] + } }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "pycharm": { - "name": "#%%\n" - } - }, + "execution_count": null, "outputs": [], "source": [ "import numpy as np\n", @@ -24,7 +20,13 @@ "from imblearn.under_sampling import RandomUnderSampler\n", "from models import AnonymousColorDetector\n", "from utils import read_labeled_img" - ] + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } }, { "cell_type": "code", @@ -35,8 +37,11 @@ "data_dir = \"data/dataset\" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label\n", "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n", "\n", - "color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n", - "label_index = {\"yangeng\": 1, \"beijing\": 0} # 类别对应的序号\n", + "# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing'} # 颜色对应的类别\n", + "color_dict = {(0, 0, 255): \"yangeng\"}\n", + "color_dict = {(255, 0, 0): 'beijing'}\n", + "color_dict = {(0, 255, 0): \"zibian\"}\n", + "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n", "show_samples = False # 是否展示样本\n", "\n", "# 定义一些训练量\n", @@ -57,26 +62,16 @@ "## 读取数据" ], "metadata": { - "collapsed": false + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } } }, { "cell_type": "code", - "execution_count": 10, - "outputs": [ - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory: 'data/dataset/label'", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mFileNotFoundError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/1942905945.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mdataset\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mread_labeled_img\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdata_dir\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mcolor_dict\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0mcolor_dict\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mis_ps_color_space\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mFalse\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mshow_samples\u001B[0m\u001B[0;34m:\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;32mfrom\u001B[0m \u001B[0mutils\u001B[0m \u001B[0;32mimport\u001B[0m \u001B[0mlab_scatter\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mlab_scatter\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdataset\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mclass_max_num\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m30000\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mis_3d\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mTrue\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mis_ps_color_space\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;32mFalse\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;32m~/PycharmProjects/tobacco_color/utils.py\u001B[0m in \u001B[0;36mread_labeled_img\u001B[0;34m(dataset_dir, color_dict, ext, is_ps_color_space)\u001B[0m\n\u001B[1;32m 37\u001B[0m \u001B[0;34m:\u001B[0m\u001B[0;32mreturn\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0m字典形式的数据集\u001B[0m\u001B[0;34m{\u001B[0m\u001B[0mlabel\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mvector\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mn\u001B[0m \u001B[0mx\u001B[0m \u001B[0;36m3\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m}\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0mvector为lab色彩空间\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 38\u001B[0m \"\"\"\n\u001B[0;32m---> 39\u001B[0;31m img_names = [img_name for img_name in os.listdir(os.path.join(dataset_dir, 'label'))\n\u001B[0m\u001B[1;32m 40\u001B[0m if img_name.endswith(ext)]\n\u001B[1;32m 41\u001B[0m \u001B[0mtotal_dataset\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mMergeDict\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'data/dataset/label'" - ] - } - ], + "execution_count": null, + "outputs": [], "source": [ "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n", "if show_samples:\n", @@ -96,25 +91,16 @@ "## 数据平衡化" ], "metadata": { - "collapsed": false + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } } }, { "cell_type": "code", - "execution_count": 11, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'dataset' is not defined", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/603974095.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0mrus\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mRandomUnderSampler\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mrandom_state\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mx_list\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my_list\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mconcatenate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mv\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mv\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mdataset\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtolist\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;31m \u001B[0m\u001B[0;31m\\\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mconcatenate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mones\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mv\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;34m*\u001B[0m \u001B[0mlabel_index\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mk\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mv\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mdataset\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mtolist\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mx_resampled\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my_resampled\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mrus\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mfit_resample\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx_list\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0my_list\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mdataset\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m{\u001B[0m\u001B[0;34m\"inside\"\u001B[0m\u001B[0;34m:\u001B[0m \u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0marray\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx_resampled\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m}\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mNameError\u001B[0m: name 'dataset' is not defined" - ] - } - ], + "execution_count": null, + "outputs": [], "source": [ "rus = RandomUnderSampler(random_state=0)\n", "x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n", @@ -135,25 +121,16 @@ "## 模型训练" ], "metadata": { - "collapsed": false + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } } }, { "cell_type": "code", - "execution_count": 12, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'dataset' is not defined", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m/var/folders/wh/kr5c3dr12834pfk3j7yqnrq40000gn/T/ipykernel_30867/1828483636.py\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[1;32m 1\u001B[0m \u001B[0;31m# 对数据进行预处理\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0;32m----> 2\u001B[0;31m \u001B[0mx\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mnp\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mconcatenate\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0mv\u001B[0m \u001B[0;32mfor\u001B[0m \u001B[0mk\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0mv\u001B[0m \u001B[0;32min\u001B[0m \u001B[0mdataset\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mitems\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0maxis\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 3\u001B[0m \u001B[0mnegative_sample_num\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mx\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m \u001B[0;34m*\u001B[0m \u001B[0;36m1.2\u001B[0m\u001B[0;34m)\u001B[0m \u001B[0;32mif\u001B[0m \u001B[0mnegative_sample_num\u001B[0m \u001B[0;32mis\u001B[0m \u001B[0;32mNone\u001B[0m \u001B[0;32melse\u001B[0m \u001B[0mnegative_sample_num\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mmodel\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mAnonymousColorDetector\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mNameError\u001B[0m: name 'dataset' is not defined" - ] - } - ], + "execution_count": null, + "outputs": [], "source": [ "# 对数据进行预处理\n", "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n", @@ -167,6 +144,28 @@ } } }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [ + "if train_from_existed:\n", + " data = scipy.io.loadmat(dataset_file)\n", + " x, y = data['x'], data['y'].ravel()\n", + " model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n", + "else:\n", + " world_boundary = np.array([0, 0, 0, 255, 255, 255])\n", + " model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n", + " is_save_dataset=True, model_selection='dt')\n", + "model.save()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, { "cell_type": "code", "execution_count": 14, diff --git a/02_classification.py b/02_classification.py index 7e75930..0fb08ff 100644 --- a/02_classification.py +++ b/02_classification.py @@ -1,53 +1,66 @@ -#!/usr/bin/env python -# coding: utf-8 - -# # 模型的训练 - -# In[16]: - - +# %% import numpy as np import scipy from imblearn.under_sampling import RandomUnderSampler from models import AnonymousColorDetector from utils import read_labeled_img -# ## 读取数据与构建数据集 +# %% +train_from_existed = True # 是否从现有数据训练,如果是的话,那就从dataset_file训练,否则就用data_dir里头的数据 +data_dir = "data/dataset" # 数据集,文件夹下必须包含`img`和`label`两个文件夹,放置相同文件名的图片和label +dataset_file = "data/dataset/dataset_2022-07-20_10-04.mat" -# In[17]: +# color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} # 颜色对应的类别 +color_dict = {(0, 0, 255): "yangeng"} +color_dict = {(255, 0, 0): 'beijing'} +color_dict = {(0, 255, 0): "zibian"} +label_index = {"yangeng": 1, "beijing": 0, "zibian": 2} # 类别对应的序号 +show_samples = False # 是否展示样本 - -data_dir = "data/dataset" -color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): 'beijing'} -label_index = {"yangeng": 1, "beijing": 0} +# 定义一些训练量 +threshold = 5 # 正样本周围多大范围内的还算是正样本 +node_num = 20 # 如果使用ELM作为分类器物,有多少的节点 +negative_sample_num = None # None或者一个数字,对应生成的负样本数量 +# %% md +## 读取数据 +# %% dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False) +if show_samples: + from utils import lab_scatter + + lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False) +# %% md +## 数据平衡化 +# %% rus = RandomUnderSampler(random_state=0) x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \ np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist() - x_resampled, y_resampled = rus.fit_resample(x_list, y_list) dataset = {"inside": np.array(x_resampled)} - -# ## 模型训练 - -# In[18]: - - -# 定义一些常量 -threshold = 5 -node_num = 20 -negative_sample_num = None # None或者一个数字 -world_boundary = np.array([0, 0, 0, 255, 255, 255]) +# %% md +## 模型训练 +# %% # 对数据进行预处理 x = np.concatenate([v for k, v in dataset.items()], axis=0) negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num - model = AnonymousColorDetector() - -model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7, - is_save_dataset=True, model_selection='dt') -# data = scipy.io.loadmat('dataset_2022-07-19_15-07.mat') -# x, y = data['x'], data['y'].ravel() -# model.fit(x, y=y, is_generate_negative=False, model_selection='dt') - +# %% +if train_from_existed: + data = scipy.io.loadmat(dataset_file) + x, y = data['x'], data['y'].ravel() + model.fit(x, y=y, is_generate_negative=False, model_selection='dt') +else: + world_boundary = np.array([0, 0, 0, 255, 255, 255]) + model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7, + is_save_dataset=True, model_selection='dt') +model.save() +# %% +if train_from_existed: + data = scipy.io.loadmat(dataset_file) + x, y = data['x'], data['y'].ravel() + model.fit(x, y=y, is_generate_negative=False, model_selection='dt') +else: + world_boundary = np.array([0, 0, 0, 255, 255, 255]) + model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7, + is_save_dataset=True, model_selection='dt') model.save()