diff --git a/02_classification.ipynb b/02_classification.ipynb index 9da7670..33b8d5b 100644 --- a/02_classification.ipynb +++ b/02_classification.ipynb @@ -14,10 +14,74 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 16, "outputs": [], "source": [ - "from model import AnonymousColorDetector" + "import numpy as np\n", + "\n", + "from models import AnonymousColorDetector\n", + "from utils import read_labeled_img" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 读取数据与构建数据集" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 17, + "outputs": [], + "source": [ + "data_dir = \"data/dataset\"\n", + "color_dict = {(0, 0, 255): \"yangeng\"}\n", + "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 模型训练" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 18, + "outputs": [], + "source": [ + "# 定义一些常量\n", + "threshold = 5\n", + "node_num = 20\n", + "negative_sample_num = None # None或者一个数字\n", + "world_boundary = np.array([0, 0, 0, 255, 255, 255])\n", + "# 对数据进行预处理\n", + "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n", + "negative_sample_num = x.shape[0] if negative_sample_num is None else negative_sample_num" ], "metadata": { "collapsed": false, @@ -28,9 +92,11 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 19, "outputs": [], - "source": [], + "source": [ + "model = AnonymousColorDetector()" + ], "metadata": { "collapsed": false, "pycharm": { @@ -40,21 +106,25 @@ }, { "cell_type": "code", - "execution_count": 6, - "outputs": [], - "source": [], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" + "execution_count": 20, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[1;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)", + "Input \u001B[1;32mIn [20]\u001B[0m, in \u001B[0;36m\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m \u001B[43mmodel\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mfit\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mnegative_sample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_num\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mtrain_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;241;43m0.7\u001B[39;49m\u001B[43m)\u001B[49m\n", + "File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:34\u001B[0m, in \u001B[0;36mAnonymousColorDetector.fit\u001B[1;34m(self, x, world_boundary, threshold, model_selected, negative_sample_size, train_size, **kwargs)\u001B[0m\n\u001B[0;32m 32\u001B[0m node_num \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mget(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mnode_num\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;241m10\u001B[39m)\n\u001B[0;32m 33\u001B[0m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mmodel \u001B[38;5;241m=\u001B[39m ELM(input_size\u001B[38;5;241m=\u001B[39mx\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m1\u001B[39m], node_num\u001B[38;5;241m=\u001B[39mnode_num, output_num\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m2\u001B[39m, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[1;32m---> 34\u001B[0m negative_samples \u001B[38;5;241m=\u001B[39m \u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgenerate_negative_samples\u001B[49m\u001B[43m(\u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mworld_boundary\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mthreshold\u001B[49m\u001B[43m,\u001B[49m\n\u001B[0;32m 35\u001B[0m \u001B[43m \u001B[49m\u001B[43msample_size\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43mnegative_sample_size\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m 36\u001B[0m data_x, data_y \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39mconcatenate([x, negative_samples], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m), \\\n\u001B[0;32m 37\u001B[0m np\u001B[38;5;241m.\u001B[39mconcatenate([np\u001B[38;5;241m.\u001B[39mones(x\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m),\n\u001B[0;32m 38\u001B[0m np\u001B[38;5;241m.\u001B[39mzeros(negative_samples\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m], dtype\u001B[38;5;241m=\u001B[39m\u001B[38;5;28mint\u001B[39m)], axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m0\u001B[39m)\n\u001B[0;32m 39\u001B[0m x_train, x_val, y_train, y_val \u001B[38;5;241m=\u001B[39m train_test_split(data_x, data_y, train_size\u001B[38;5;241m=\u001B[39mtrain_size, shuffle\u001B[38;5;241m=\u001B[39m\u001B[38;5;28;01mTrue\u001B[39;00m)\n", + "File \u001B[1;32m~\\PycharmProjects\\tobacco_color\\models.py:66\u001B[0m, in \u001B[0;36mAnonymousColorDetector.generate_negative_samples\u001B[1;34m(x, world_boundary, threshold, sample_size)\u001B[0m\n\u001B[0;32m 64\u001B[0m \u001B[38;5;28;01mfor\u001B[39;00m sample_idx \u001B[38;5;129;01min\u001B[39;00m \u001B[38;5;28mrange\u001B[39m(generated_data\u001B[38;5;241m.\u001B[39mshape[\u001B[38;5;241m0\u001B[39m]):\n\u001B[0;32m 65\u001B[0m sample \u001B[38;5;241m=\u001B[39m generated_data[sample_idx, :]\n\u001B[1;32m---> 66\u001B[0m in_threshold \u001B[38;5;241m=\u001B[39m np\u001B[38;5;241m.\u001B[39many(np\u001B[38;5;241m.\u001B[39msum(\u001B[43mnp\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mpower\u001B[49m\u001B[43m(\u001B[49m\u001B[43msample\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m-\u001B[39;49m\u001B[43m \u001B[49m\u001B[43mx\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;241;43m2\u001B[39;49m\u001B[43m)\u001B[49m, axis\u001B[38;5;241m=\u001B[39m\u001B[38;5;241m1\u001B[39m) \u001B[38;5;241m<\u001B[39m threshold)\n\u001B[0;32m 67\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m in_threshold:\n\u001B[0;32m 68\u001B[0m negative_samples[sample_idx, :] \u001B[38;5;241m=\u001B[39m sample\n", + "\u001B[1;31mKeyboardInterrupt\u001B[0m: " + ] } - } - }, - { - "cell_type": "code", - "execution_count": null, - "outputs": [], - "source": [], + ], + "source": [ + "model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7)" + ], "metadata": { "collapsed": false, "pycharm": { diff --git a/model.py b/models.py similarity index 99% rename from model.py rename to models.py index 4112e63..5318b2b 100644 --- a/model.py +++ b/models.py @@ -1,7 +1,7 @@ # -*- codeing = utf-8 -*- # Time : 2022/7/18 14:03 # @Auther : zhouchao -# @File: model.py +# @File: models.py # @Software:PyCharm、 import numpy as np from sklearn.metrics import classification_report