mirror of
https://github.com/NanjingForestryUniversity/tobacoo-industry.git
synced 2025-11-08 22:33:52 +00:00
213 lines
22 KiB
Plaintext
213 lines
22 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 29,
|
|
"metadata": {
|
|
"collapsed": true
|
|
},
|
|
"outputs": [],
|
|
"source": [
|
|
"import pickle\n",
|
|
"\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"import matplotlib.pyplot as plt\n",
|
|
"import numpy as np\n",
|
|
"\n",
|
|
"from utils import visualization_evaluation,visualization_y\n",
|
|
"%matplotlib inline"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 30,
|
|
"outputs": [],
|
|
"source": [
"# --- run configuration ---\n",
"# number of training samples kept by the stratified subsample taken after oversampling\n",
"train_size = 6000\n",
"\n",
"# block edge length and sensitivity; both are encoded in the dataset file name below\n",
"blk_sz = 8\n",
"sensitivity = 8\n",
"# selected band indices; their count appears as c<N> in both file names\n",
"selected_bands = [127, 201, 202, 294]\n",
"# presumably the random-forest tree count; here it is only used to build the model file name\n",
"tree_num = 185\n",
"# picture size (rows, cols); not referenced elsewhere in this notebook\n",
"pic_row = 600\n",
"pic_col = 1024\n",
"\n",
"dataset_file = f'./dataset/data_{blk_sz}x{blk_sz}_c{len(selected_bands)}_sen{sensitivity}_test.p'\n",
"model_file = f'./models/rf_pca_{blk_sz}x{blk_sz}_c{len(selected_bands)}_{tree_num}_1.model'"
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"source": [
|
|
"# 数据集与样本平衡"
|
|
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%% md\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 31,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"数据量: 9600\n",
|
|
"x train (6720, 8, 8, 4), y train (6720,)\n",
|
|
"x test (2880, 8, 8, 4), y test (2880,)\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
"# Load the pickled (samples, labels) pair.\n",
"# NOTE: pickle.load can execute arbitrary code - only open dataset files you generated yourself.\n",
"with open(dataset_file, 'rb') as f:\n",
"    x_list, y_list = pickle.load(f)\n",
"# every sample must come with exactly one label\n",
"assert len(y_list) == len(x_list)\n",
"print(\"数据量: \", len(x_list))\n",
"x = np.asarray(x_list)\n",
"y = np.asarray(y_list, dtype=int)\n",
"# 70/30 split, stratified on the label so both parts keep the same class proportions\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    x, y,\n",
"    test_size=0.3,\n",
"    stratify=y,\n",
"    shuffle=True,\n",
"    random_state=5,\n",
")\n",
"print(f\"x train {x_train.shape}, y train {y_train.shape}\\n\"\n",
"      f\"x test {x_test.shape}, y test {y_test.shape}\")"
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 32,
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"total (array([ 0., 0., 0., 0., 0., 9600., 0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>) \n",
|
|
"train (array([ 0., 0., 0., 0., 0., 6720., 0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>) \n",
|
|
"test (array([ 0., 0., 0., 0., 0., 2880., 0.]), array([0, 1, 2, 3, 4, 5, 6, 7]), <BarContainer object of 7 artists>)\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": "<Figure size 864x432 with 3 Axes>",
|
|
"image/png": "\n"
|
|
},
|
|
"metadata": {
|
|
"needs_background": "light"
|
|
},
|
|
"output_type": "display_data"
|
|
}
|
|
],
|
|
"source": [
"# Class histograms of the full set and of both splits, one panel each, using identical integer bin edges\n",
"fig, axs = plt.subplots(1, 3, figsize=(12, 6))\n",
"hist_res_total, hist_res_train, hist_res_test = (\n",
"    ax.hist(labels, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
"    for ax, labels in zip(axs, (y, y_train, y_test))\n",
")\n",
"print(f'total {hist_res_total} \\n'\n",
"      f'train {hist_res_train} \\n'\n",
"      f'test {hist_res_test}')\n",
"plt.show()"
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 33,
|
|
"outputs": [],
"source": [
"# The classes are heavily imbalanced, so oversample the minority classes before training.\n",
"# NOTE: this cell overwrites x_train / y_train; re-run the train/test split cell before re-running it.\n",
"from imblearn.over_sampling import RandomOverSampler\n",
"\n",
"x_train_shape = x_train.shape\n",
"# Samplers expect 2-D input (n_samples, n_features). Flatten a copy so x_train is left untouched if anything below fails.\n",
"x_train_flat = x_train.reshape((x_train_shape[0], -1))\n",
"train_classes = np.unique(y_train)\n",
"if train_classes.size > 1:\n",
"    ros = RandomOverSampler(random_state=0)\n",
"    x_resampled, y_resampled = ros.fit_resample(x_train_flat, y_train)\n",
"else:\n",
"    # RandomOverSampler raises ValueError for a single class; there is nothing to balance.\n",
"    print(f'only one class in y_train ({train_classes}), skipping oversampling')\n",
"    x_resampled, y_resampled = x_train_flat, y_train\n",
"# Plot the class distribution after oversampling\n",
"fig, axs = plt.subplots(figsize=(10, 6))\n",
"hist_res_train = axs.hist(y_resampled, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
"print(f'train {hist_res_train} \\n')\n",
"plt.show()\n",
"# Draw a stratified subsample of train_size samples; an int train_size must be smaller than the sample count.\n",
"if len(y_resampled) > train_size:\n",
"    x_sampled, _, y_sampled, _ = train_test_split(x_resampled, y_resampled, train_size=train_size, random_state=0,\n",
"                                                  shuffle=True, stratify=y_resampled)\n",
"else:\n",
"    x_sampled, y_sampled = x_resampled, y_resampled\n",
"# Plot the class distribution of the subsample\n",
"fig, axs = plt.subplots(figsize=(10, 6))\n",
"hist_res_train = axs.hist(y_sampled, [0, 1, 2, 3, 4, 5, 6, 7], align='mid')\n",
"print(f'train {hist_res_train} \\n')\n",
"plt.show()\n",
"# Restore the original per-sample shape and publish the balanced training set\n",
"x_train = x_sampled.reshape((x_sampled.shape[0],) + x_train_shape[1:])\n",
"y_train = y_sampled\n",
"print(len(x_train))"
],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"outputs": [],
|
|
"source": [],
|
|
"metadata": {
|
|
"collapsed": false,
|
|
"pycharm": {
|
|
"name": "#%%\n"
|
|
}
|
|
}
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 2
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython2",
|
|
"version": "2.7.6"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 0
|
|
} |