{ "cells": [ { "cell_type": "markdown", "metadata": { "collapsed": true, "pycharm": { "name": "#%% md\n" } }, "source": [ "# 数据集扩充" ] }, { "cell_type": "markdown", "source": [ "虽然当前的模型已经能够达到较好的效果,但是还不够好,对于一些较老的烟梗不能够做到有效的判别,我们为此增加数据集。" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 49, "outputs": [], "source": [ "import os\n", "\n", "import cv2\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from utils import read_raw_file, split_xy, generate_tobacco_label, generate_impurity_label\n", "from models import SpecDetector\n", "import pickle" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 50, "outputs": [], "source": [ "# some parameters\n", "new_spectra_file = r\"F:\\zhouchao\\615\\calibrated0.raw\"\n", "new_label_file = r\"F:\\zhouchao\\615\\label0.bmp\"\n", "\n", "target_class = 0\n", "target_class_left, target_class_right = 5, 4\n", "light_threshold = 0.5\n", "add_background = False\n", "\n", "split_line = 500\n", "\n", "\n", "blk_sz, sensitivity = 8, 32\n", "selected_bands = [127, 201, 202, 294]\n", "tree_num = 185\n", "\n", "pic_row, pic_col= 600, 1024\n", "\n", "color_dict = {(0, 0, 255): 1, (255, 255, 255): 0, (0, 255, 0): 2, (255, 0, 0): 1, (0, 255, 255): 4,\n", " (255, 255, 0): 5, (255, 0, 255): 6}\n", "\n", "new_dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_4.p'\n", "dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_3.p'\n", "\n", "model_file = f'./models/rf_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_sen{sensitivity}_3.model'\n", "# selected_bands = None" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 51, "outputs": [], "source": [ "data = read_raw_file(new_spectra_file, selected_bands)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 烟梗标签生成\n", "这会将纯烟梗图片中识别为杂质的部分提取出来" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 52, "outputs": [], "source": [ "x_list, y_list = [], []\n", "if (new_label_file is None) and (target_class == 1):\n", " x_list, y_list = generate_tobacco_label(data, model_file, blk_sz, selected_bands)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 其他类别杂质阈值分割\n", "通过阈值分割的形式获取其他类别的杂质" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 53, "outputs": [], "source": [ "if (new_label_file is None) and (target_class != 1):\n", " img = generate_impurity_label(data, light_threshold, color_dict,\n", " target_class_right=target_class_right,\n", " target_class_left=target_class_left,\n", " split_line=split_line)\n", " root, _ = os.path.splitext(new_dataset_file)\n", " cv2.imwrite(root+\"_generated.bmp\", img)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 读取标签" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 54, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(600, 1024, 3)\n" ] } ], "source": [ "if new_label_file is not None:\n", " label = cv2.imread(new_label_file)\n", " print(label.shape)\n", " x_list, y_list = split_xy(data, label, blk_sz, sensitivity=sensitivity, color_dict=color_dict, add_background=add_background)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "code", "execution_count": 55, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "301 301\n" ] } ], "source": [ "print(len(x_list), len(y_list))" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 读取旧数据合并" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 56, "outputs": [], "source": [ "with open(dataset_file, 'rb') as f:\n", " x, y = pickle.load(f)\n", "x.extend(x_list)\n", "y.extend(y_list)\n", "with open(new_dataset_file, 'wb') as f:\n", " pickle.dump((x, y), f)" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } }, { "cell_type": "markdown", "source": [ "## 批量数据的处理" ], "metadata": { "collapsed": false, "pycharm": { "name": "#%% md\n" } } }, { "cell_type": "code", "execution_count": 56, "outputs": [], "source": [], "metadata": { "collapsed": false, "pycharm": { "name": "#%%\n" } } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }