diff --git a/05_model_training.ipynb b/05_model_training.ipynb index 01aa777..57a56e9 100644 --- a/05_model_training.ipynb +++ b/05_model_training.ipynb @@ -15,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 1, "outputs": [], "source": [ "import numpy as np\n", @@ -45,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 2, "outputs": [], "source": [ "data_path = r'data/envi20220802.txt'\n", @@ -72,7 +72,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 3, "outputs": [], "source": [ "data = read_envi_ascii(data_path)" @@ -86,7 +86,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 4, "outputs": [ { "name": "stdout", @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 5, "outputs": [], "source": [ "data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n", @@ -142,7 +142,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 6, "outputs": [ { "name": "stdout", @@ -179,7 +179,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 7, "outputs": [ { "name": "stdout", @@ -205,7 +205,8 @@ { "cell_type": "markdown", "source": [ - "#" + "# 进行模型训练\n", + "分出一部分数据进行训练" ], "metadata": { "collapsed": false, @@ -213,6 +214,237 @@ "name": "#%% md\n" } } + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [], + "source": [ + "from models import DecisionTree\n", + "from sklearn.model_selection import train_test_split" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 11, + "outputs": [], + "source": [ + "tree = DecisionTree(class_weight={1:20})" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 12, + "outputs": [], + "source": [ + "tree = tree.fit(train_x, train_y)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 模型评估" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 多分类精度" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 13, + "outputs": [], + "source": [ + "pred_y = tree.predict(test_x)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 14, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0.0 1.00 1.00 1.00 304\n", + " 1.0 0.97 0.98 0.97 275\n", + " 2.0 0.98 1.00 0.99 297\n", + " 3.0 0.95 0.99 0.97 293\n", + " 4.0 0.98 0.91 0.95 316\n", + " 5.0 0.98 0.98 0.98 264\n", + "\n", + " accuracy 0.98 1749\n", + " macro avg 0.98 0.98 0.98 1749\n", + "weighted avg 0.98 0.98 0.98 1749\n", + "\n" + ] + } + ], + "source": [ + "from sklearn.metrics import classification_report\n", + "\n", + "print(classification_report(y_pred=pred_y, y_true=test_y))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "## 二分类精度" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 19, + "outputs": [], + "source": [ + "test_y[test_y <= 1] = 0\n", + "test_y[test_y > 1] = 1\n", + "pred_y = tree.predict_bin(test_x)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 20, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0.0 1.00 1.00 1.00 582\n", + " 1.0 1.00 1.00 1.00 1167\n", + "\n", + " accuracy 1.00 1749\n", + " macro avg 1.00 1.00 1.00 1749\n", + "weighted avg 1.00 1.00 1.00 1749\n", + "\n" + ] + } + ], + "source": [ + "print(classification_report(y_true=pred_y, y_pred=pred_y))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "markdown", + "source": [ + "# 模型保存" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%% md\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 22, + "outputs": [], + "source": [ + "import datetime\n", + "\n", + "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n", + "with open(path, 'wb') as f:\n", + " pickle.dump(tree, f)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": null, + "outputs": [], + "source": [], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } } ], "metadata": { diff --git a/README.md b/README.md index 1a742c8..7944366 100644 --- a/README.md +++ b/README.md @@ -175,3 +175,30 @@ 偏差的影响,也可从这幅图当中看到,这幅图的上下偏差达到了惊人的200像素,明显考虑是触发有问题了,不然偏差值至少是恒定的。 结论是考虑RGB相机的触发存在一定问题。 + +## 喷阀检查 + +为了能够有效的对喷阀进行检查,我写了一个用于测试的小socket,这个小socket的使用方式是这样的: + +开启服务端: + +```shel +python valve_test.py +``` + +然后按照要求进行输入就可以了,我还在里头藏了个彩蛋,你猜猜是啥。 + +如果想要开客户端,可以加个参数,就像这样: + +```shel +python valve_test.py -c +``` + +这个客户端啥也不会干,只会做去显示相应的收到的指令。 + +同时运行这两个可以在本地看到测试结果,不用看zynq那边的结果: + +![截屏2022-08-02 14.16.24](https://raw.githubusercontent.com/Karllzy/imagebed/main/img/%E6%88%AA%E5%B1%8F2022-08-02%2014.16.24.png) + + + diff --git a/models.py b/models.py index abb1a74..f0acf56 100755 --- a/models.py +++ b/models.py @@ -255,7 +255,7 @@ class ManualTree: # 机器学习像素模型类 class PixelModelML: - def __init__(self, pixel_model_path): + def __init__(self, pixel_model_path=None): with open(pixel_model_path, "rb") as f: self.dt = pickle.load(f) @@ -409,7 +409,7 @@ class SpecDetector(Detector): if x_yellow.shape[0] == 0: return non_yellow_things else: - tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5 + tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5 non_yellow_things[yellow_things] = ~tobacco # 杂质mask中将背景赋值为0,将杂质赋值为1 @@ -436,6 +436,14 @@ class SpecDetector(Detector): return blk_result_array +class DecisionTree(DecisionTreeClassifier): + def predict_bin(self, feature): + res = self.predict(feature) + res[res <= 1] = 0 + res[res > 1] = 1 + return res + + if __name__ == '__main__': data_dir = "data/dataset" color_dict = {(0, 0, 255): "yangeng"} diff --git a/valve_test.py b/valve_test.py index 554aac6..e74c89e 100644 --- a/valve_test.py +++ b/valve_test.py @@ -154,7 +154,7 @@ d. 阀板的脉冲分频系数,>=2即可 h. 发个da和 class VirtualValve: def __init__(self): self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 声明socket类型,同时生成链接对象 - self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的6969端口 + self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的13452端口 def run(self): while True: @@ -166,7 +166,6 @@ class VirtualValve: if __name__ == '__main__': import argparse - parser = argparse.ArgumentParser(description='阀门测程序') parser.add_argument('-c', default=False, action='store_true', help='是否是开个客户端', required=False) args = parser.parse_args()