mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
数据分析已经完成
This commit is contained in:
parent
b37114bed7
commit
a4652427c7
354
01_dataset.ipynb
354
01_dataset.ipynb
File diff suppressed because one or more lines are too long
@ -1,13 +1,23 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 模型的训练"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import cv2\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt"
|
||||
"from model import AnonymousColorDetector"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@ -20,9 +30,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img_path = r\"data/tobacco/Image_2022_0716_1409_08_547-001482.bmp\""
|
||||
],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
@ -34,162 +42,7 @@
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"img = cv2.imread(img_path)\n",
|
||||
"img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 257,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"l_bmax = 43\n",
|
||||
"l_bmin = 1\n",
|
||||
"a_bmax = 148\n",
|
||||
"a_bmin = 122\n",
|
||||
"b_bmax = 137\n",
|
||||
"b_bmin = 126"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 258,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"l_ymax = 157\n",
|
||||
"l_ymin = 6\n",
|
||||
"# l_ymin = 44\n",
|
||||
"a_ymax = 170\n",
|
||||
"a_ymin = 124\n",
|
||||
"b_ymax = 163\n",
|
||||
"b_ymin = 129"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 259,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# item_mask = np.zeros((1024, 4096))\n",
|
||||
"item_mask = ~((img[:,:,0]>=l_bmin)&(img[:,:,0]<=l_bmax)&(img[:,:,1]>=a_bmin)&(img[:,:,1]<=a_bmax)&(img[:,:,2]>=b_bmin)&(img[:,:,2]<=b_bmax))+0"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 260,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# tobacco_mask = np.zeros((1024, 4096))\n",
|
||||
"tobacco_mask = ((img[:,:,0]>=l_ymin)&(img[:,:,0]<=l_ymax)&(img[:,:,1]>=a_ymin)&(img[:,:,1]<=a_ymax)&(img[:,:,2]>=b_ymin)&(img[:,:,2]<=b_ymax))+0"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 262,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"mask = (item_mask.all() & tobacco_mask.all())\n"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 265,
|
||||
"outputs": [
|
||||
{
|
||||
"ename": "TypeError",
|
||||
"evalue": "Invalid shape () for image data",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
|
||||
"\u001B[1;31mTypeError\u001B[0m Traceback (most recent call last)",
|
||||
"Input \u001B[1;32mIn [265]\u001B[0m, in \u001B[0;36m<cell line: 2>\u001B[1;34m()\u001B[0m\n\u001B[0;32m 1\u001B[0m fig,ax \u001B[38;5;241m=\u001B[39m plt\u001B[38;5;241m.\u001B[39msubplots()\n\u001B[1;32m----> 2\u001B[0m ax \u001B[38;5;241m=\u001B[39m \u001B[43mplt\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mimshow\u001B[49m\u001B[43m(\u001B[49m\u001B[43mmask\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 3\u001B[0m plt\u001B[38;5;241m.\u001B[39mshow()\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\_api\\deprecation.py:456\u001B[0m, in \u001B[0;36mmake_keyword_only.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m 450\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(args) \u001B[38;5;241m>\u001B[39m name_idx:\n\u001B[0;32m 451\u001B[0m warn_deprecated(\n\u001B[0;32m 452\u001B[0m since, message\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPassing the \u001B[39m\u001B[38;5;132;01m%(name)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;132;01m%(obj_type)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 453\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpositionally is deprecated since Matplotlib \u001B[39m\u001B[38;5;132;01m%(since)s\u001B[39;00m\u001B[38;5;124m; the \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 454\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter will become keyword-only \u001B[39m\u001B[38;5;132;01m%(removal)s\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 455\u001B[0m name\u001B[38;5;241m=\u001B[39mname, obj_type\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfunc\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m()\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m--> 456\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m func(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\pyplot.py:2640\u001B[0m, in \u001B[0;36mimshow\u001B[1;34m(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, data, **kwargs)\u001B[0m\n\u001B[0;32m 2634\u001B[0m \u001B[38;5;129m@_copy_docstring_and_deprecators\u001B[39m(Axes\u001B[38;5;241m.\u001B[39mimshow)\n\u001B[0;32m 2635\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21mimshow\u001B[39m(\n\u001B[0;32m 2636\u001B[0m X, cmap\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, norm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, aspect\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, interpolation\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m,\n\u001B[0;32m 2637\u001B[0m alpha\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, vmax\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, origin\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, extent\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m,\n\u001B[0;32m 2638\u001B[0m interpolation_stage\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, filternorm\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m, filterrad\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m4.0\u001B[39m,\n\u001B[0;32m 2639\u001B[0m resample\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, url\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[1;32m-> 2640\u001B[0m __ret \u001B[38;5;241m=\u001B[39m gca()\u001B[38;5;241m.\u001B[39mimshow(\n\u001B[0;32m 2641\u001B[0m X, cmap\u001B[38;5;241m=\u001B[39mcmap, norm\u001B[38;5;241m=\u001B[39mnorm, aspect\u001B[38;5;241m=\u001B[39maspect,\n\u001B[0;32m 2642\u001B[0m interpolation\u001B[38;5;241m=\u001B[39minterpolation, alpha\u001B[38;5;241m=\u001B[39malpha, vmin\u001B[38;5;241m=\u001B[39mvmin,\n\u001B[0;32m 2643\u001B[0m vmax\u001B[38;5;241m=\u001B[39mvmax, origin\u001B[38;5;241m=\u001B[39morigin, extent\u001B[38;5;241m=\u001B[39mextent,\n\u001B[0;32m 2644\u001B[0m interpolation_stage\u001B[38;5;241m=\u001B[39minterpolation_stage,\n\u001B[0;32m 2645\u001B[0m filternorm\u001B[38;5;241m=\u001B[39mfilternorm, filterrad\u001B[38;5;241m=\u001B[39mfilterrad, resample\u001B[38;5;241m=\u001B[39mresample,\n\u001B[0;32m 2646\u001B[0m url\u001B[38;5;241m=\u001B[39murl, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39m({\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mdata\u001B[39m\u001B[38;5;124m\"\u001B[39m: data} \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m \u001B[38;5;28;01melse\u001B[39;00m {}),\n\u001B[0;32m 2647\u001B[0m \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 2648\u001B[0m sci(__ret)\n\u001B[0;32m 2649\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m __ret\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\_api\\deprecation.py:456\u001B[0m, in \u001B[0;36mmake_keyword_only.<locals>.wrapper\u001B[1;34m(*args, **kwargs)\u001B[0m\n\u001B[0;32m 450\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mlen\u001B[39m(args) \u001B[38;5;241m>\u001B[39m name_idx:\n\u001B[0;32m 451\u001B[0m warn_deprecated(\n\u001B[0;32m 452\u001B[0m since, message\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mPassing the \u001B[39m\u001B[38;5;132;01m%(name)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;132;01m%(obj_type)s\u001B[39;00m\u001B[38;5;124m \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 453\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mpositionally is deprecated since Matplotlib \u001B[39m\u001B[38;5;132;01m%(since)s\u001B[39;00m\u001B[38;5;124m; the \u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 454\u001B[0m \u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter will become keyword-only \u001B[39m\u001B[38;5;132;01m%(removal)s\u001B[39;00m\u001B[38;5;124m.\u001B[39m\u001B[38;5;124m\"\u001B[39m,\n\u001B[0;32m 455\u001B[0m name\u001B[38;5;241m=\u001B[39mname, obj_type\u001B[38;5;241m=\u001B[39m\u001B[38;5;124mf\u001B[39m\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mparameter of \u001B[39m\u001B[38;5;132;01m{\u001B[39;00mfunc\u001B[38;5;241m.\u001B[39m\u001B[38;5;18m__name__\u001B[39m\u001B[38;5;132;01m}\u001B[39;00m\u001B[38;5;124m()\u001B[39m\u001B[38;5;124m\"\u001B[39m)\n\u001B[1;32m--> 456\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m func(\u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\__init__.py:1412\u001B[0m, in \u001B[0;36m_preprocess_data.<locals>.inner\u001B[1;34m(ax, data, *args, **kwargs)\u001B[0m\n\u001B[0;32m 1409\u001B[0m \u001B[38;5;129m@functools\u001B[39m\u001B[38;5;241m.\u001B[39mwraps(func)\n\u001B[0;32m 1410\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21minner\u001B[39m(ax, \u001B[38;5;241m*\u001B[39margs, data\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mNone\u001B[39;00m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs):\n\u001B[0;32m 1411\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m data \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[1;32m-> 1412\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m func(ax, \u001B[38;5;241m*\u001B[39m\u001B[38;5;28mmap\u001B[39m(sanitize_sequence, args), \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1414\u001B[0m bound \u001B[38;5;241m=\u001B[39m new_sig\u001B[38;5;241m.\u001B[39mbind(ax, \u001B[38;5;241m*\u001B[39margs, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m 1415\u001B[0m auto_label \u001B[38;5;241m=\u001B[39m (bound\u001B[38;5;241m.\u001B[39marguments\u001B[38;5;241m.\u001B[39mget(label_namer)\n\u001B[0;32m 1416\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m bound\u001B[38;5;241m.\u001B[39mkwargs\u001B[38;5;241m.\u001B[39mget(label_namer))\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\axes\\_axes.py:5488\u001B[0m, in \u001B[0;36mAxes.imshow\u001B[1;34m(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, interpolation_stage, filternorm, filterrad, resample, url, **kwargs)\u001B[0m\n\u001B[0;32m 5481\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mset_aspect(aspect)\n\u001B[0;32m 5482\u001B[0m im \u001B[38;5;241m=\u001B[39m mimage\u001B[38;5;241m.\u001B[39mAxesImage(\u001B[38;5;28mself\u001B[39m, cmap, norm, interpolation,\n\u001B[0;32m 5483\u001B[0m origin, extent, filternorm\u001B[38;5;241m=\u001B[39mfilternorm,\n\u001B[0;32m 5484\u001B[0m filterrad\u001B[38;5;241m=\u001B[39mfilterrad, resample\u001B[38;5;241m=\u001B[39mresample,\n\u001B[0;32m 5485\u001B[0m interpolation_stage\u001B[38;5;241m=\u001B[39minterpolation_stage,\n\u001B[0;32m 5486\u001B[0m \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m-> 5488\u001B[0m \u001B[43mim\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mset_data\u001B[49m\u001B[43m(\u001B[49m\u001B[43mX\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 5489\u001B[0m im\u001B[38;5;241m.\u001B[39mset_alpha(alpha)\n\u001B[0;32m 5490\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m im\u001B[38;5;241m.\u001B[39mget_clip_path() \u001B[38;5;129;01mis\u001B[39;00m \u001B[38;5;28;01mNone\u001B[39;00m:\n\u001B[0;32m 5491\u001B[0m \u001B[38;5;66;03m# image does not already have clipping set, clip to axes patch\u001B[39;00m\n",
|
||||
"File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\matplotlib\\image.py:715\u001B[0m, in \u001B[0;36m_ImageBase.set_data\u001B[1;34m(self, A)\u001B[0m\n\u001B[0;32m 711\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A \u001B[38;5;241m=\u001B[39m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A[:, :, \u001B[38;5;241m0\u001B[39m]\n\u001B[0;32m 713\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m (\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m2\u001B[39m\n\u001B[0;32m 714\u001B[0m \u001B[38;5;129;01mor\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m-\u001B[39m\u001B[38;5;241m1\u001B[39m] \u001B[38;5;129;01min\u001B[39;00m [\u001B[38;5;241m3\u001B[39m, \u001B[38;5;241m4\u001B[39m]):\n\u001B[1;32m--> 715\u001B[0m \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mTypeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mInvalid shape \u001B[39m\u001B[38;5;132;01m{}\u001B[39;00m\u001B[38;5;124m for image data\u001B[39m\u001B[38;5;124m\"\u001B[39m\n\u001B[0;32m 716\u001B[0m \u001B[38;5;241m.\u001B[39mformat(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mshape))\n\u001B[0;32m 718\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mndim \u001B[38;5;241m==\u001B[39m \u001B[38;5;241m3\u001B[39m:\n\u001B[0;32m 719\u001B[0m \u001B[38;5;66;03m# If the input data has values outside the valid range (after\u001B[39;00m\n\u001B[0;32m 720\u001B[0m \u001B[38;5;66;03m# normalisation), we issue a warning and then clip X to the bounds\u001B[39;00m\n\u001B[0;32m 721\u001B[0m \u001B[38;5;66;03m# - otherwise casting wraps extreme values, hiding outliers and\u001B[39;00m\n\u001B[0;32m 722\u001B[0m \u001B[38;5;66;03m# making reliable interpretation impossible.\u001B[39;00m\n\u001B[0;32m 723\u001B[0m high \u001B[38;5;241m=\u001B[39m \u001B[38;5;241m255\u001B[39m \u001B[38;5;28;01mif\u001B[39;00m np\u001B[38;5;241m.\u001B[39missubdtype(\u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39m_A\u001B[38;5;241m.\u001B[39mdtype, np\u001B[38;5;241m.\u001B[39minteger) \u001B[38;5;28;01melse\u001B[39;00m \u001B[38;5;241m1\u001B[39m\n",
|
||||
"\u001B[1;31mTypeError\u001B[0m: Invalid shape () for image data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": "<Figure size 432x288 with 1 Axes>",
|
||||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAQYAAAD8CAYAAACVSwr3AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAAMX0lEQVR4nO3bX4il9X3H8fenuxEak0aJk5DuKt2WNbotWnRiJPSPaWizay6WgBdqqFQCixBDLpVCk4I3zUUhBP8siyySm+xNJN0UEyktiQVr4yz4bxVlulKdrOAaQwoGKqvfXsxpc3q+szvPrGfO2cH3CwbmeZ7fOefLMOc9zzzzTKoKSRr3G/MeQNL5xzBIagyDpMYwSGoMg6TGMEhq1g1DksNJXk/y3BmOJ8m3kywneSbJNdMfU9IsDTljeAjYe5bj+4Ddo48DwAPvfSxJ87RuGKrqMeDNsyzZD3ynVj0BXJTkE9MaUNLsbZ/Cc+wAXh3bXhnte21yYZIDrJ5VcOGFF157xRVXTOHlJZ3JsWPH3qiqhY0+bhphyBr71rzPuqoOAYcAFhcXa2lpaQovL+lMkvznuTxuGn+VWAEuHdveCZycwvNKmpNphOEocNvorxPXA7+sqvZrhKStY91fJZJ8F7gBuCTJCvAN4AMAVXUQeAS4EVgGfgXcvlnDSpqNdcNQVbesc7yAr0xtIklz552PkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySmkFhSLI3yYtJlpPcvcbxjyT5QZKnkxxPcvv0R5U0K+uGIck24D5gH7AHuCXJnollXwGer6qrgRuAv09ywZRnlTQjQ84YrgOWq+pEVb0NHAH2T6wp4MNJAnwIeBM4PdVJJc3MkDDsAF4d214Z7Rt3L3AlcBJ4FvhaVb07+URJDiRZSrJ06tSpcxxZ0mYbEoassa8mtj8PPAX8NvCHwL1Jfqs9qOpQVS1W1eLCwsIGR5U0K0PCsAJcOra9k9Uzg3G3Aw/XqmXgZeCK6YwoadaGhOFJYHeSXaMLijcDRyfWvAJ8DiDJx4FPAiemOaik2dm+3oKqOp3kTuBRYBtwuKqOJ7ljdPwgcA/wUJJnWf3V466qemMT55a0idYNA0BVPQI8MrHv4NjnJ4G/mO5okubFOx8lNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUmMYJDWGQVIzKAxJ9iZ5MclykrvPsOaGJE8lOZ7kJ9MdU9IsbV9vQZJtwH3AnwMrwJNJjlbV82NrLgLuB/ZW1StJPrZJ80qagSFnDNcBy1V1oqreBo4A+yfW3Ao8XFWvAFTV69MdU9IsDQnDDuDVse2V0b5xlwMXJ/lxkmNJblvriZIcSLKUZOnUqVPnNrGkTTckDFljX01sbweuBb4AfB74mySXtwdVHaqqxapaXFhY2PCwkmZj3WsMrJ4hXDq2vRM4ucaaN6rqLeCtJI8BVwMvTWVKSTM15IzhSWB3kl1JLgBuBo5OrPkH4I+TbE/yQeDTwAvTHVXSrKx7xlBVp5PcCTwKbAMOV9XxJHeMjh+sqheS/Ah4BngXeLCqntvMwSVtnlRNXi6YjcXFxVpaWprLa0vvF0mOVdXiRh/nnY+SGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJKaQWFIsjfJi0mWk9x9lnWfSvJOkpumN6KkWVs3DEm2AfcB+4A9wC1J9pxh3TeBR6c9pKTZGnLGcB2wXFUnqupt4Aiwf411XwW+B7w+xfkkzcGQMOwAXh3bXhnt+z9JdgBfBA6e7YmSHEiylGTp1KlTG51V0owMCUPW2FcT298C7qqqd872RFV1qKoWq2pxYWFh4IiSZm37gDUrwKVj2zuBkxNrFoEjSQAuAW5Mcrqqvj+NISXN1pAwPAnsTrIL+BlwM3Dr+IKq2vW/nyd5CPhHoyBtXeuGoapOJ7mT1b82bAMOV9XxJHeMjp/1uoKkrWfIGQNV9QjwyMS+NYNQVX/13seSNE/e+SipMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkhrDIKkxDJIawyCpMQySGsMgqTEMkppBYUiyN8mLSZaT3L3G8S8leWb08XiSq6c/qqRZWTcMSbYB9wH7gD3ALUn2TCx7GfjTqroKuAc4NO1BJc3OkDOG64DlqjpRVW8DR4D94wuq6vGq+sVo8wlg53THlDRLQ8KwA3h1bHtltO9Mvgz8cK0DSQ4kWUqydOrUqeFTSpqpIWHIGvtqzYXJZ1kNw11rHa+qQ1W1WFWLCwsLw6eUNFPbB6xZAS4d294JnJxclOQq4EFgX1X9fDrjSZqHIWcMTwK7k+xKcgFwM3B0fEGSy4CHgb+sqpemP6akWVr3jKGqTie5E3gU2AYcrqrjSe4YHT8IfB34KHB/EoDTVbW4eWNL2kypWvNywaZbXFyspaWluby29H6R5Ni5/JD2zkdJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBkmNYZDUGAZJjWGQ1BgGSY1hkNQYBknNoDAk2ZvkxSTLSe5e43iSfHt0/Jkk10x/VEmzsm4YkmwD7gP2AXuAW5LsmVi2D9g9+jgAPDDlOSXN0JAzhuuA5ao6UVVvA0eA/RNr9gPfqVVPABcl+cSUZ5U0I9sHrNkBvDq2vQJ8esCaHcBr44uSHGD1jALgv5M8t6Fp5+sS4I15DzHQVpoVtta8W2lWgE+ey4OGhCFr7KtzWENVHQIOASRZqqrFAa9/XthK826lWWFrzbuVZoXVec/lcUN+lVgBLh3b3gmcPIc1kraIIWF4EtidZFeSC4CbgaMTa44Ct43+OnE98Muqem3yiSRtDev+KlFVp5PcCTwKbAMOV9XxJHeMjh8EHgFuBJaBXwG3D3jtQ+c89XxspXm30qywtebdSrPCOc6bqnYpQNL7nHc+SmoMg6Rm08OwlW6nHjDrl0YzPpPk8SRXz2POsXnOOu/Yuk8leSfJTbOcb2KGdWdNckOSp5IcT/KTWc84Mct63wsfSfKDJE+P5h1yXW1TJDmc5PUz3Rd0Tu+xqtq0D1YvVv4H8LvABcDTwJ6JNTcCP2T1XojrgX/fzJne46yfAS4efb5vXrMOnXds3b+weoH4pvN1VuAi4HngstH2x87nry3w18A3R58vAG8CF8xp3j8BrgGeO8PxDb/HNvuMYSvdTr3urFX1eFX9YrT5BKv3a8zLkK8twFeB7wGvz3K4CUNmvRV4uKpeAaiq833eAj6cJMCHWA3D6dmOORqk6rHR65/Jht9jmx2GM90qvdE1s7DROb7MaoXnZd15k+wAvggcnOFcaxnytb0cuDjJj5McS3LbzKbrhsx7L3AlqzfyPQt8rarenc14G7bh99iQW6Lfi6ndTj0Dg+dI8llWw/BHmzrR2Q2Z91vAXVX1zuoPtrkZMut24Frgc8BvAv+W5Imqemmzh1vDkHk/DzwF/Bnwe8A/JfnXqvqvTZ7tXGz4PbbZYdhKt1MPmiPJVcCDwL6q+vmMZlvLkHkXgSOjKFwC3JjkdFV9fyYT/trQ74M3quot4K0kjwFXA/MIw5B5bwf+rlZ/iV9O8jJwBfDT2Yy4IRt/j23yRZHtwAlgF7++iPP7E2u+wP+/MPLTOV3AGTLrZaze3fmZecy40Xkn1j/E/C4+DvnaXgn882jtB4HngD84j+d9APjb0ecfB34GXDLH74ff4cwXHzf8HtvUM4bavNup5zXr14GPAvePfgqfrjn9p93Aec8LQ2atqheS/Ah4BngXeLCq5vJv+QO/tvcADyV5ltU33F1VNZd/x07yXeAG4JIkK8A3gA+Mzbrh95i3REtqvPNRUmMYJDWGQVJjGCQ1hkFSYxgkNYZBUvM/YA1djYGMYyEAAAAASUVORK5CYII=\n"
|
||||
},
|
||||
"metadata": {
|
||||
"needs_background": "light"
|
||||
},
|
||||
"output_type": "display_data"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"fig,ax = plt.subplots()\n",
|
||||
"ax = plt.imshow(mask)\n",
|
||||
"plt.show()"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 264,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"c = mask.max(axis=0).max(axis=0)\n",
|
||||
"d = mask.min(axis=0).max(axis=0)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
|
||||
144
elm.py
Normal file
144
elm.py
Normal file
@ -0,0 +1,144 @@
|
||||
# -*- codeing = utf-8 -*-
|
||||
# Time : 2019/7/18 14:11
|
||||
# @Auther : lzy
|
||||
# @File: elm.py
|
||||
# @Software:PyCharm
|
||||
import numpy as np
|
||||
import scipy.io
|
||||
import datetime
|
||||
|
||||
|
||||
class ELM:
|
||||
"""
|
||||
假设输入的数据 x 为一个 n个特征的数据 (1, input_size)
|
||||
|
||||
ELM 的结构包括三部分
|
||||
1.一个隐含层权重 w (input_size, node_num)
|
||||
2.该隐含层的偏置值 b (1, node_num)
|
||||
3.后续加权值 beta (node_num, output_num)
|
||||
|
||||
则进行预测时, h = w_(n, node_num).T . b
|
||||
"""
|
||||
|
||||
def __init__(self, input_size=10, node_num=14, output_num=4,
|
||||
weight=None, bias=None, beta=None, rand_seed=None, **kwargs):
|
||||
"""
|
||||
Initialize a ELM with 3 parameters: input_size, node_num, output_num
|
||||
the structure of an ELM is very simple:
|
||||
|
||||
X_(nxN) W_(NxL)+bias(1xL)->f=H(nxL).beta(Lxt)=T(nxt)
|
||||
x_data ---------> Neurons --------------> Output
|
||||
|
||||
:param input_size: [int] the feature numbers of input data e.g. [x1, x2, x3] ---> 3
|
||||
:param node_num: [int] the number of neurons in the hidden layer
|
||||
:param output_num: [int] the number of output
|
||||
:param weight: [array] shape: input_size x node_num
|
||||
:param bias: [array] shape: 1 x node_num
|
||||
:param beta: [array] shape: node_num, output_num
|
||||
:param rand_seed: [int] the random seed
|
||||
"""
|
||||
if rand_seed is not None:
|
||||
np.random.seed(rand_seed)
|
||||
|
||||
if weight is not None:
|
||||
self.w = weight
|
||||
else:
|
||||
# 随机生产权重(从-1,到1,生成(num_feature行,num_hidden列))
|
||||
self.w = np.random.uniform(-1, 1, (input_size, node_num))
|
||||
|
||||
if bias is not None:
|
||||
self.b = bias
|
||||
else:
|
||||
self.b = np.random.uniform(0, 1, (1, node_num))
|
||||
|
||||
if beta is not None:
|
||||
self.beta = beta
|
||||
else:
|
||||
self.beta = np.random.uniform(0, 1, (node_num, output_num))
|
||||
|
||||
def sigmoid(self, x):
|
||||
"""
|
||||
sigmoid activation function
|
||||
:param x: input x
|
||||
:return: sigmoid output
|
||||
"""
|
||||
return 1.0 / (1 + np.exp(-x))
|
||||
|
||||
def fit(self, x_train, y_train, c=None):
|
||||
"""
|
||||
fit the data
|
||||
:param x_train: [array] train data x
|
||||
:param y_train: [array] train data y
|
||||
:param c: regular c
|
||||
:return: self
|
||||
"""
|
||||
y_train = np.eye(int(y_train.max()) + 1)[y_train, :]
|
||||
mul = np.dot(x_train, self.w) # 输入乘以权重
|
||||
add = mul + self.b # 加偏置
|
||||
H = self.sigmoid(add) # 激活函数
|
||||
HH = H.T.dot(H)
|
||||
HT = H.T.dot(y_train)
|
||||
node_num = self.w.shape[1]
|
||||
if c is None:
|
||||
self.beta = np.linalg.pinv(HH).dot(HT)
|
||||
else:
|
||||
self.beta = np.linalg.pinv(HH + np.identity(node_num) / c).dot(HT)
|
||||
return self
|
||||
|
||||
def predict(self, x_data):
|
||||
"""
|
||||
make prediction
|
||||
:param x_data: data to predict
|
||||
:return: predicted data
|
||||
"""
|
||||
mul = np.dot(x_data, self.w) # 输入乘以权重
|
||||
add = mul + self.b # 加偏置
|
||||
H = self.sigmoid(add) # 激活函数
|
||||
result = H.dot(self.beta)
|
||||
result = np.argmax(result, axis=1)
|
||||
return result
|
||||
|
||||
def save(self, path=None):
|
||||
"""
|
||||
save the model
|
||||
:param path: model path
|
||||
:return: saved model name
|
||||
"""
|
||||
if path is None:
|
||||
path = datetime.datetime.now().strftime("ELM_%Y-%m-%d_%H-%M.mat")
|
||||
model_dic = {"w": self.w, "b": self.b, "beta": self.beta}
|
||||
print(self.w.shape, self.b.shape, self.beta.shape)
|
||||
scipy.io.savemat(path, model_dic)
|
||||
return path
|
||||
|
||||
def load(self, model_path):
|
||||
"""
|
||||
load from saved model ([.mat] contain w, b, beta)
|
||||
:param model_path: load from where
|
||||
:return: self
|
||||
"""
|
||||
data = scipy.io.loadmat(model_path)
|
||||
self.w, self.b, self.beta = data['w'], data['b'], data['beta'].T
|
||||
print(self.w.shape, self.b.shape, self.beta.shape)
|
||||
return self
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
x1 = np.linspace(1, 20, 100)
|
||||
x2 = np.linspace(-5, 5, 100)
|
||||
X = np.vstack([x1, x2]).T
|
||||
T = np.sin(x1 * x2 / (2 * np.pi)) * 30 + np.random.normal(0, 0.2, 100) + np.random.normal(0, 1, 100)
|
||||
|
||||
elm = ELM(input_size=2, node_num=100, output_num=1)
|
||||
# elm.load('ELM_2020-01-10_00-18.mat')
|
||||
elm.fit(X, T)
|
||||
|
||||
y = elm.predict(X)
|
||||
elm.save()
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
plt.plot(x1, T, lw=1.5, label='Training goal')
|
||||
plt.plot(x1, y, lw=3, label='ELM output')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
79
model.py
Normal file
79
model.py
Normal file
@ -0,0 +1,79 @@
|
||||
# -*- codeing = utf-8 -*-
|
||||
# Time : 2022/7/18 14:03
|
||||
# @Auther : zhouchao
|
||||
# @File: model.py
|
||||
# @Software:PyCharm、
|
||||
import numpy as np
|
||||
from sklearn.metrics import classification_report
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
from elm import ELM
|
||||
|
||||
|
||||
class AnonymousColorDetector(object):
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
|
||||
def fit(self, x: np.ndarray, world_boundary: np.ndarray, threshold: float, model_selected: str = 'elm',
|
||||
negative_sample_size: int = 1000, train_size: float = 0.8, **kwargs):
|
||||
"""
|
||||
拟合到指定的样本分布情况下,根据x进行分布的变化。
|
||||
|
||||
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
||||
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限
|
||||
:param threshold: 与正样本之间的距离阈值大于多少则不认为是指定的样本类别
|
||||
:param model_selected: 选择模型,默认为elm
|
||||
:param negative_sample_size: 负样本的数量
|
||||
:param kwargs: 与模型相对应的参数
|
||||
:return:
|
||||
"""
|
||||
assert model_selected in ['elm']
|
||||
if model_selected == 'elm':
|
||||
node_num = kwargs.get('node_num', 10)
|
||||
self.model = ELM(input_size=x.shape[1], node_num=node_num, output_num=2, **kwargs)
|
||||
negative_samples = self.generate_negative_samples(x, world_boundary, threshold,
|
||||
sample_size=negative_sample_size)
|
||||
data_x, data_y = np.concatenate([x, negative_samples], axis=0), \
|
||||
np.concatenate([np.ones(x.shape[0], dtype=int),
|
||||
np.zeros(negative_samples.shape[0], dtype=int)], axis=0)
|
||||
x_train, x_val, y_train, y_val = train_test_split(data_x, data_y, train_size=train_size, shuffle=True)
|
||||
|
||||
self.model.fit(x_train, y_train)
|
||||
y_predict = self.model.predict(x_val)
|
||||
print(classification_report(y_true=y_val, y_pred=y_predict))
|
||||
|
||||
def predict(self, x):
|
||||
return self.model.predict(x)
|
||||
|
||||
@staticmethod
|
||||
def generate_negative_samples(x: np.ndarray, world_boundary: np.ndarray, threshold: float, sample_size: int):
|
||||
"""
|
||||
根据正样本和世界边界生成负样本
|
||||
|
||||
:param x: ndarray类型的正样本数据,给出的正样本形状为 n x feature_num
|
||||
:param world_boundary: 整个世界的边界,边界形状为 feature_num个下限, feature_num个上限, array like
|
||||
:param threshold: 与正样本x之间的距离限制
|
||||
:return: 负样本形状为:(sample_size, feature_num)
|
||||
"""
|
||||
feature_num = x.shape[1]
|
||||
negative_samples = np.zeros((sample_size, feature_num), dtype=x.dtype)
|
||||
generated_sample_num = 0
|
||||
while generated_sample_num <= sample_size:
|
||||
generated_data = np.random.uniform(world_boundary[:feature_num], world_boundary[feature_num:],
|
||||
size=(sample_size, feature_num))
|
||||
for sample_idx in range(generated_data.shape[0]):
|
||||
sample = generated_data[sample_idx, :]
|
||||
in_threshold = np.any(np.sum(np.power(sample - x, 2), axis=1) < threshold)
|
||||
if not in_threshold:
|
||||
negative_samples[sample_idx, :] = sample
|
||||
generated_sample_num += 1
|
||||
if generated_sample_num >= sample_size:
|
||||
break
|
||||
return negative_samples
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
detector = AnonymousColorDetector()
|
||||
x = np.array([[10, 30, 20], [10, 35, 25], [10, 35, 36]])
|
||||
world_boundary = np.array([0, -127, -127, 100, 127, 127])
|
||||
detector.fit(x, world_boundary, threshold=5, negative_sample_size=2000)
|
||||
32
utils.py
32
utils.py
@ -26,13 +26,14 @@ class MergeDict(dict):
|
||||
return self
|
||||
|
||||
|
||||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp') -> dict:
|
||||
def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp', is_ps_color_space=True) -> dict:
|
||||
"""
|
||||
根据dataset_dir下的文件创建数据集
|
||||
|
||||
:param dataset_dir: 文件夹名称,文件夹内必须包含'label'和'label'两个文件夹,并分别存放同名的图像与标签
|
||||
:param color_dict: 进行标签图像的颜色查找
|
||||
:param ext: 图片后缀名,默认为.bmp
|
||||
:param is_ps_color_space: 是否使用ps的标准lab色彩空间,默认True
|
||||
:return: 字典形式的数据集{label: vector(n x 3)},vector为lab色彩空间
|
||||
"""
|
||||
img_names = [img_name for img_name in os.listdir(os.path.join(dataset_dir, 'label'))
|
||||
@ -45,6 +46,7 @@ def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp') -> dict:
|
||||
label_img = cv2.imread(label_path)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
|
||||
# 从opencv的色彩空间到Photoshop的色彩空间
|
||||
if is_ps_color_space:
|
||||
alpha, beta = np.array([100 / 255, 1, 1]), np.array([0, -128, -128])
|
||||
img = img * alpha + beta
|
||||
img = np.asarray(np.round(img, 0), dtype=int)
|
||||
@ -53,17 +55,22 @@ def read_labeled_img(dataset_dir: str, color_dict: dict, ext='.bmp') -> dict:
|
||||
return total_dataset
|
||||
|
||||
|
||||
def lab_scatter(dataset: dict, class_max_num=None):
|
||||
def lab_scatter(dataset: dict, class_max_num=None, is_3d=False, is_ps_color_space=True):
|
||||
"""
|
||||
在lab色彩空间内绘制3维数据分布情况
|
||||
|
||||
:param dataset: 字典形式的数据集{label: vector(n x 3)},vector为lab色彩空间
|
||||
:param class_max_num: 每个类别最多画的样本数量,默认不限制
|
||||
:param is_3d: 进行lab三维绘制或者a,b两通道绘制
|
||||
:param is_ps_color_space: 是否使用ps的标准lab色彩空间,默认True
|
||||
:return: None
|
||||
"""
|
||||
# 观察色彩分布情况
|
||||
fig = plt.figure()
|
||||
if is_3d:
|
||||
ax = fig.add_subplot(projection='3d')
|
||||
else:
|
||||
ax = fig.add_subplot()
|
||||
for label, data in dataset.items():
|
||||
if class_max_num is not None:
|
||||
assert isinstance(class_max_num, int)
|
||||
@ -72,14 +79,25 @@ def lab_scatter(dataset: dict, class_max_num=None):
|
||||
sample_idx = np.random.choice(sample_idx, class_max_num)
|
||||
data = data[sample_idx, :]
|
||||
l, a, b = [data[:, i] for i in range(3)]
|
||||
if is_3d:
|
||||
ax.scatter(a, b, l, label=label, alpha=0.1)
|
||||
ax.set_xlim(-127, 127)
|
||||
ax.set_ylim(-127, 127)
|
||||
ax.set_zlim(0, 100)
|
||||
else:
|
||||
ax.scatter(a, b, label=label, alpha=0.1)
|
||||
x_max, x_min, y_max, y_min, z_max, z_min = [127, -127, 127, -127, 100, 0] if is_ps_color_space else \
|
||||
[255, 0, 255, 0, 255, 0]
|
||||
ax.set_xlim(x_min, x_max)
|
||||
ax.set_ylim(y_min, y_max)
|
||||
ax.set_xlabel('a*')
|
||||
ax.set_ylabel('b*')
|
||||
if is_3d:
|
||||
ax.set_zlim(z_min, z_max)
|
||||
ax.set_zlabel('L')
|
||||
plt.legend()
|
||||
plt.show()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
dataset = read_labeled_img("data/dataset", color_dict={(0, 0, 255): 1, (255, 0, 0): 2})
|
||||
lab_scatter(dataset, class_max_num=2000)
|
||||
color_dict = {(0, 0, 255): "yangeng", (255, 0, 0): "bejing", (0, 255, 0): "hongdianxian",
|
||||
(255, 0, 255): "chengsebangbangtang", (0, 255, 255): "lvdianxian"}
|
||||
dataset = read_labeled_img("data/dataset", color_dict=color_dict, is_ps_color_space=False)
|
||||
lab_scatter(dataset, class_max_num=20000, is_3d=False, is_ps_color_space=False)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user