supermachine-tobacco/02_classification.ipynb

321 lines
14 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 1,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"C:\\Users\\FEIJINTI\\miniconda3\\envs\\deepo\\lib\\site-packages\\tqdm\\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Training env\n"
]
}
],
"source": [
"import numpy as np\n",
"import scipy\n",
"import scipy.io\n",
"from imblearn.under_sampling import RandomUnderSampler\n",
"\n",
"from models import AnonymousColorDetector\n",
"from utils import lab_scatter\n",
"from utils import read_labeled_img"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 2,
"outputs": [],
"source": [
"# --- Data source ---\n",
"# True: train from the saved dataset in `dataset_file`.\n",
"# False: build the dataset from the labelled images under `data_dir`.\n",
"train_from_existed = False\n",
"# Dataset folder; it must contain an `img` and a `label` sub-folder holding images and\n",
"# labels with matching file names.\n",
"data_dir = \"data/dataset\"\n",
"dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
"\n",
"# --- Classes ---\n",
"# Label colour -> class name. Alternative selections are kept below for reference:\n",
"# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"}\n",
"# color_dict = {(0, 0, 255): \"yangeng\"}\n",
"# color_dict = {(0, 255, 0): \"zibian\"}\n",
"color_dict = {(255, 0, 0): \"beijing\"}\n",
"# Class name -> numeric label index.\n",
"label_index = {\n",
"    \"yangeng\": 1,\n",
"    \"beijing\": 0,\n",
"    \"zibian\": 2,\n",
"}\n",
"# Whether to plot the loaded samples.\n",
"show_samples = False\n",
"\n",
"# --- Training parameters ---\n",
"# How large a neighbourhood around a positive sample still counts as positive.\n",
"threshold = 2\n",
"# Number of nodes when ELM is used as the classifier.\n",
"node_num = 20\n",
"# None, or an explicit number of negative samples to generate.\n",
"negative_sample_num = None"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 读取数据"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 3,
"outputs": [],
"source": [
"# Load the labelled samples for the colours listed in `color_dict`. The result is used as a\n",
"# mapping by the following cells (they iterate `dataset.items()`).\n",
"dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
"\n",
"# Optional visual check of the loaded samples; `lab_scatter` is imported in the imports cell.\n",
"# NOTE(review): class_max_num presumably caps the points plotted per class - confirm in utils.\n",
"if show_samples:\n",
"    lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 数据平衡化"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 4,
"outputs": [],
"source": [
"# Balance the classes by random under-sampling; skipped when `dataset` holds a single class.\n",
"if len(dataset) > 1:\n",
"    # Stack the samples of every class and build a parallel label vector from `label_index`.\n",
"    x_list = np.concatenate(list(dataset.values()), axis=0).tolist()\n",
"    y_list = np.concatenate(\n",
"        [np.full((samples.shape[0],), float(label_index[name])) for name, samples in dataset.items()],\n",
"        axis=0,\n",
"    ).tolist()\n",
"\n",
"    # Fixed random_state so the same subset is drawn on every run.\n",
"    rus = RandomUnderSampler(random_state=0)\n",
"    x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
"\n",
"    # The retained samples are merged under one \"inside\" key; the resampled labels are not kept.\n",
"    dataset = {\"inside\": np.array(x_resampled)}"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 模型训练"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 5,
"outputs": [],
"source": [
"# Preprocess: stack every sample array in `dataset` into a single matrix.\n",
"x = np.concatenate(list(dataset.values()), axis=0)\n",
"\n",
"# If no negative sample count was configured, default to 1.2x the number of rows in `x`.\n",
"if negative_sample_num is None:\n",
"    negative_sample_num = int(x.shape[0] * 1.2)\n",
"\n",
"model = AnonymousColorDetector()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 6,
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"147099it [31:23, 78.09it/s] \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 44129\n",
" 1 1.00 1.00 1.00 36775\n",
"\n",
" accuracy 1.00 80904\n",
" macro avg 1.00 1.00 1.00 80904\n",
"weighted avg 1.00 1.00 1.00 80904\n",
"\n"
]
}
],
"source": [
"if not train_from_existed:\n",
"    # Fit on the samples loaded above. The boundary array and `threshold` are passed\n",
"    # positionally, exactly as `AnonymousColorDetector.fit` is called elsewhere in this notebook.\n",
"    # NOTE(review): the boundary looks like per-channel minima followed by maxima - confirm in models.\n",
"    world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
"    model.fit(\n",
"        x,\n",
"        world_boundary,\n",
"        threshold,\n",
"        negative_sample_size=negative_sample_num,\n",
"        train_size=0.7,\n",
"        is_save_dataset=True,\n",
"        model_selection='dt',\n",
"    )\n",
"else:\n",
"    # Train from the saved .mat dataset: `x` holds the samples and `y` the flattened labels.\n",
"    data = scipy.io.loadmat(dataset_file)\n",
"    x, y = data['x'], data['y'].ravel()\n",
"    model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
"\n",
"# Persist the fitted model in either case.\n",
"model.save()"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "markdown",
"source": [
"## 模型的可视化"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%% md\n"
}
}
},
{
"cell_type": "code",
"execution_count": 21,
"outputs": [],
"source": [
"# Files used by the visualisation cells below; both paths are relative to the working directory.\n",
"# NOTE(review): the stored output of the next cell shows the dataset file was not found at this\n",
"# location - point these at the files produced by your own training run.\n",
"dataset_path = \"dataset_2022-07-21_17-19.mat\"\n",
"model_path = \"dt_2022-07-21_17-19.model\""
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 22,
"outputs": [
{
"ename": "FileNotFoundError",
"evalue": "[Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'",
"output_type": "error",
"traceback": [
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
"\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:39\u001B[0m, in \u001B[0;36m_open_file\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m 38\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m---> 39\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 40\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIOError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m e:\n\u001B[0;32m 41\u001B[0m \u001B[38;5;66;03m# Probably \"not found\"\u001B[39;00m\n",
"\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'",
"\nDuring handling of the above exception, another exception occurred:\n",
"\u001B[1;31mFileNotFoundError\u001B[0m Traceback (most recent call last)",
"Input \u001B[1;32mIn [22]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[43mscipy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mio\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mloadmat\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdataset_path\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 2\u001B[0m ground_truth \u001B[38;5;241m=\u001B[39m data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mx\u001B[39m\u001B[38;5;124m\"\u001B[39m][data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124my\u001B[39m\u001B[38;5;124m\"\u001B[39m]\u001B[38;5;241m.\u001B[39mravel()\u001B[38;5;241m==\u001B[39m\u001B[38;5;241m1\u001B[39m]\n",
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:224\u001B[0m, in \u001B[0;36mloadmat\u001B[1;34m(file_name, mdict, appendmat, **kwargs)\u001B[0m\n\u001B[0;32m 87\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 88\u001B[0m \u001B[38;5;124;03mLoad MATLAB file.\u001B[39;00m\n\u001B[0;32m 89\u001B[0m \n\u001B[1;32m (...)\u001B[0m\n\u001B[0;32m 221\u001B[0m \u001B[38;5;124;03m 3.14159265+3.14159265j])\u001B[39;00m\n\u001B[0;32m 222\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m 223\u001B[0m variable_names \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mvariable_names\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n\u001B[1;32m--> 224\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m _open_file_context(file_name, appendmat) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m 225\u001B[0m MR, _ \u001B[38;5;241m=\u001B[39m mat_reader_factory(f, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 226\u001B[0m matfile_dict \u001B[38;5;241m=\u001B[39m MR\u001B[38;5;241m.\u001B[39mget_variables(variable_names)\n",
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\contextlib.py:135\u001B[0m, in \u001B[0;36m_GeneratorContextManager.__enter__\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m 133\u001B[0m \u001B[38;5;28;01mdel\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39margs, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mkwds, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfunc\n\u001B[0;32m 134\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m--> 135\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mnext\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgen\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 136\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mStopIteration\u001B[39;00m:\n\u001B[0;32m 137\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mgenerator didn\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mt yield\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;28mNone\u001B[39m\n",
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:17\u001B[0m, in \u001B[0;36m_open_file_context\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m 15\u001B[0m \u001B[38;5;129m@contextmanager\u001B[39m\n\u001B[0;32m 16\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_open_file_context\u001B[39m(file_like, appendmat, mode\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mrb\u001B[39m\u001B[38;5;124m'\u001B[39m):\n\u001B[1;32m---> 17\u001B[0m f, opened \u001B[38;5;241m=\u001B[39m \u001B[43m_open_file\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mappendmat\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 18\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m 19\u001B[0m \u001B[38;5;28;01myield\u001B[39;00m f\n",
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:45\u001B[0m, in \u001B[0;36m_open_file\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m 43\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m appendmat \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m file_like\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.mat\u001B[39m\u001B[38;5;124m'\u001B[39m):\n\u001B[0;32m 44\u001B[0m file_like \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.mat\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m---> 45\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m 46\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m 47\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mIOError\u001B[39;00m(\n\u001B[0;32m 48\u001B[0m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mReader needs file name or open file-like object\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[0;32m 49\u001B[0m ) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01me\u001B[39;00m\n",
"\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'"
]
}
],
"source": [
"# Load the saved dataset and keep, as ground truth, the rows of `x` whose label in `y` equals 1.\n",
"data = scipy.io.loadmat(dataset_path)\n",
"positive_mask = data[\"y\"].ravel() == 1\n",
"ground_truth = data[\"x\"][positive_mask]"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [
"%matplotlib notebook\n",
"# Reload the detector from `model_path` and visualise it together with the ground-truth samples.\n",
"model = AnonymousColorDetector(model_path)\n",
"model.visualize(\n",
"    world_boundary=np.array([0, 0, 0, 255, 255, 255]),\n",
"    sample_size=5000,\n",
"    ground_truth=ground_truth,\n",
")"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": null,
"outputs": [],
"source": [],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 1
}