From e4cedc2516e94282f2c188972e0d2a52d04cc6e9 Mon Sep 17 00:00:00 2001
From: "li.zhenye"
Date: Tue, 2 Aug 2022 15:17:11 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=83=8F=E7=B4=A0?=
=?UTF-8?q?=E7=82=B9=E6=A8=A1=E5=9E=8B=E7=9A=84=E8=AE=AD=E7=BB=83=E4=BB=A3?=
=?UTF-8?q?=E7=A0=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
05_model_training.ipynb | 248 ++++++++++++++++++++++++++++++++++++++--
README.md | 27 +++++
models.py | 12 +-
valve_test.py | 3 +-
4 files changed, 278 insertions(+), 12 deletions(-)
diff --git a/05_model_training.ipynb b/05_model_training.ipynb
index 01aa777..57a56e9 100644
--- a/05_model_training.ipynb
+++ b/05_model_training.ipynb
@@ -15,7 +15,7 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 1,
"outputs": [],
"source": [
"import numpy as np\n",
@@ -45,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 23,
+ "execution_count": 2,
"outputs": [],
"source": [
"data_path = r'data/envi20220802.txt'\n",
@@ -72,7 +72,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 3,
"outputs": [],
"source": [
"data = read_envi_ascii(data_path)"
@@ -86,7 +86,7 @@
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 4,
"outputs": [
{
"name": "stdout",
@@ -114,7 +114,7 @@
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 5,
"outputs": [],
"source": [
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
@@ -142,7 +142,7 @@
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 6,
"outputs": [
{
"name": "stdout",
@@ -179,7 +179,7 @@
},
{
"cell_type": "code",
- "execution_count": 30,
+ "execution_count": 7,
"outputs": [
{
"name": "stdout",
@@ -205,7 +205,8 @@
{
"cell_type": "markdown",
"source": [
- "#"
+ "# 进行模型训练\n",
+ "分出一部分数据进行训练"
],
"metadata": {
"collapsed": false,
@@ -213,6 +214,237 @@
"name": "#%% md\n"
}
}
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "outputs": [],
+ "source": [
+ "from models import DecisionTree\n",
+ "from sklearn.model_selection import train_test_split"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "outputs": [],
+ "source": [
+ "train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "outputs": [],
+ "source": [
+ "tree = DecisionTree(class_weight={1:20})"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "outputs": [],
+ "source": [
+ "tree = tree.fit(train_x, train_y)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 模型评估"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 多分类精度"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "outputs": [],
+ "source": [
+ "pred_y = tree.predict(test_x)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " precision recall f1-score support\n",
+ "\n",
+ " 0.0 1.00 1.00 1.00 304\n",
+ " 1.0 0.97 0.98 0.97 275\n",
+ " 2.0 0.98 1.00 0.99 297\n",
+ " 3.0 0.95 0.99 0.97 293\n",
+ " 4.0 0.98 0.91 0.95 316\n",
+ " 5.0 0.98 0.98 0.98 264\n",
+ "\n",
+ " accuracy 0.98 1749\n",
+ " macro avg 0.98 0.98 0.98 1749\n",
+ "weighted avg 0.98 0.98 0.98 1749\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "from sklearn.metrics import classification_report\n",
+ "\n",
+ "print(classification_report(y_pred=pred_y, y_true=test_y))"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## 二分类精度"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "outputs": [],
+ "source": [
+ "test_y[test_y <= 1] = 0\n",
+ "test_y[test_y > 1] = 1\n",
+ "pred_y = tree.predict_bin(test_x)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " precision recall f1-score support\n",
+ "\n",
+ " 0.0 1.00 1.00 1.00 582\n",
+ " 1.0 1.00 1.00 1.00 1167\n",
+ "\n",
+ " accuracy 1.00 1749\n",
+ " macro avg 1.00 1.00 1.00 1749\n",
+ "weighted avg 1.00 1.00 1.00 1749\n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(classification_report(y_true=pred_y, y_pred=pred_y))"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# 模型保存"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%% md\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "outputs": [],
+ "source": [
+ "import datetime\n",
+ "\n",
+ "path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
+ "with open(path, 'wb') as f:\n",
+ " pickle.dump(tree, f)"
+ ],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "outputs": [],
+ "source": [],
+ "metadata": {
+ "collapsed": false,
+ "pycharm": {
+ "name": "#%%\n"
+ }
+ }
}
],
"metadata": {
diff --git a/README.md b/README.md
index 1a742c8..7944366 100644
--- a/README.md
+++ b/README.md
@@ -175,3 +175,30 @@
偏差的影响,也可从这幅图当中看到,这幅图的上下偏差达到了惊人的200像素,明显考虑是触发有问题了,不然偏差值至少是恒定的。
结论是考虑RGB相机的触发存在一定问题。
+
+## 喷阀检查
+
+为了能够有效的对喷阀进行检查,我写了一个用于测试的小socket,这个小socket的使用方式是这样的:
+
+开启服务端:
+
+```shel
+python valve_test.py
+```
+
+然后按照要求进行输入就可以了,我还在里头藏了个彩蛋,你猜猜是啥。
+
+如果想要开客户端,可以加个参数,就像这样:
+
+```shel
+python valve_test.py -c
+```
+
+这个客户端啥也不会干,只会做去显示相应的收到的指令。
+
+同时运行这两个可以在本地看到测试结果,不用看zynq那边的结果:
+
+
+
+
+
diff --git a/models.py b/models.py
index abb1a74..f0acf56 100755
--- a/models.py
+++ b/models.py
@@ -255,7 +255,7 @@ class ManualTree:
# 机器学习像素模型类
class PixelModelML:
- def __init__(self, pixel_model_path):
+ def __init__(self, pixel_model_path=None):
with open(pixel_model_path, "rb") as f:
self.dt = pickle.load(f)
@@ -409,7 +409,7 @@ class SpecDetector(Detector):
if x_yellow.shape[0] == 0:
return non_yellow_things
else:
- tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
+ tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5
non_yellow_things[yellow_things] = ~tobacco
# 杂质mask中将背景赋值为0,将杂质赋值为1
@@ -436,6 +436,14 @@ class SpecDetector(Detector):
return blk_result_array
+class DecisionTree(DecisionTreeClassifier):
+ def predict_bin(self, feature):
+ res = self.predict(feature)
+ res[res <= 1] = 0
+ res[res > 1] = 1
+ return res
+
+
if __name__ == '__main__':
data_dir = "data/dataset"
color_dict = {(0, 0, 255): "yangeng"}
diff --git a/valve_test.py b/valve_test.py
index 554aac6..e74c89e 100644
--- a/valve_test.py
+++ b/valve_test.py
@@ -154,7 +154,7 @@ d. 阀板的脉冲分频系数,>=2即可 h. 发个da和
class VirtualValve:
def __init__(self):
self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 声明socket类型,同时生成链接对象
- self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的6969端口
+ self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的13452端口
def run(self):
while True:
@@ -166,7 +166,6 @@ class VirtualValve:
if __name__ == '__main__':
import argparse
-
parser = argparse.ArgumentParser(description='阀门测程序')
parser.add_argument('-c', default=False, action='store_true', help='是否是开个客户端', required=False)
args = parser.parse_args()