mirror of
https://github.com/NanjingForestryUniversity/supermachine-tobacco.git
synced 2025-11-08 14:23:55 +00:00
添加了像素点模型的训练代码
This commit is contained in:
parent
cd4e12c46c
commit
e4cedc2516
@ -15,7 +15,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"execution_count": 1,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
@ -45,7 +45,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 23,
|
||||
"execution_count": 2,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_path = r'data/envi20220802.txt'\n",
|
||||
@ -72,7 +72,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 24,
|
||||
"execution_count": 3,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data = read_envi_ascii(data_path)"
|
||||
@ -86,7 +86,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"execution_count": 4,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -114,7 +114,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 26,
|
||||
"execution_count": 5,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_x = [d for class_name, d in data.items() if class_name in name_dict.keys()]\n",
|
||||
@ -142,7 +142,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 27,
|
||||
"execution_count": 6,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -179,7 +179,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 30,
|
||||
"execution_count": 7,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
@ -205,7 +205,8 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"#"
|
||||
"# 进行模型训练\n",
|
||||
"分出一部分数据进行训练"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
@ -213,6 +214,237 @@
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 8,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from models import DecisionTree\n",
|
||||
"from sklearn.model_selection import train_test_split"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 9,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_x, test_x, train_y, test_y = train_test_split(x_resampled, y_resampled, test_size=0.2)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tree = DecisionTree(class_weight={1:20})"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 12,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"tree = tree.fit(train_x, train_y)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 模型评估"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 多分类精度"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 13,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pred_y = tree.predict(test_x)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 14,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 1.00 1.00 1.00 304\n",
|
||||
" 1.0 0.97 0.98 0.97 275\n",
|
||||
" 2.0 0.98 1.00 0.99 297\n",
|
||||
" 3.0 0.95 0.99 0.97 293\n",
|
||||
" 4.0 0.98 0.91 0.95 316\n",
|
||||
" 5.0 0.98 0.98 0.98 264\n",
|
||||
"\n",
|
||||
" accuracy 0.98 1749\n",
|
||||
" macro avg 0.98 0.98 0.98 1749\n",
|
||||
"weighted avg 0.98 0.98 0.98 1749\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from sklearn.metrics import classification_report\n",
|
||||
"\n",
|
||||
"print(classification_report(y_pred=pred_y, y_true=test_y))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"## 二分类精度"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 19,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"test_y[test_y <= 1] = 0\n",
|
||||
"test_y[test_y > 1] = 1\n",
|
||||
"pred_y = tree.predict_bin(test_x)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 20,
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
" precision recall f1-score support\n",
|
||||
"\n",
|
||||
" 0.0 1.00 1.00 1.00 582\n",
|
||||
" 1.0 1.00 1.00 1.00 1167\n",
|
||||
"\n",
|
||||
" accuracy 1.00 1749\n",
|
||||
" macro avg 1.00 1.00 1.00 1749\n",
|
||||
"weighted avg 1.00 1.00 1.00 1749\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"print(classification_report(y_true=pred_y, y_pred=pred_y))"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"source": [
|
||||
"# 模型保存"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%% md\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 22,
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import datetime\n",
|
||||
"\n",
|
||||
"path = datetime.datetime.now().strftime(f\"models/pixel_%Y-%m-%d_%H-%M.model\")\n",
|
||||
"with open(path, 'wb') as f:\n",
|
||||
" pickle.dump(tree, f)"
|
||||
],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"outputs": [],
|
||||
"source": [],
|
||||
"metadata": {
|
||||
"collapsed": false,
|
||||
"pycharm": {
|
||||
"name": "#%%\n"
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
27
README.md
27
README.md
@ -175,3 +175,30 @@
|
||||
偏差的影响,也可从这幅图当中看到,这幅图的上下偏差达到了惊人的200像素,明显考虑是触发有问题了,不然偏差值至少是恒定的。
|
||||
|
||||
结论是考虑RGB相机的触发存在一定问题。
|
||||
|
||||
## 喷阀检查
|
||||
|
||||
为了能够有效的对喷阀进行检查,我写了一个用于测试的小socket,这个小socket的使用方式是这样的:
|
||||
|
||||
开启服务端:
|
||||
|
||||
```shel
|
||||
python valve_test.py
|
||||
```
|
||||
|
||||
然后按照要求进行输入就可以了,我还在里头藏了个彩蛋,你猜猜是啥。
|
||||
|
||||
如果想要开客户端,可以加个参数,就像这样:
|
||||
|
||||
```shel
|
||||
python valve_test.py -c
|
||||
```
|
||||
|
||||
这个客户端啥也不会干,只会做去显示相应的收到的指令。
|
||||
|
||||
同时运行这两个可以在本地看到测试结果,不用看zynq那边的结果:
|
||||
|
||||

|
||||
|
||||
|
||||
|
||||
|
||||
12
models.py
12
models.py
@ -255,7 +255,7 @@ class ManualTree:
|
||||
|
||||
# 机器学习像素模型类
|
||||
class PixelModelML:
|
||||
def __init__(self, pixel_model_path):
|
||||
def __init__(self, pixel_model_path=None):
|
||||
with open(pixel_model_path, "rb") as f:
|
||||
self.dt = pickle.load(f)
|
||||
|
||||
@ -409,7 +409,7 @@ class SpecDetector(Detector):
|
||||
if x_yellow.shape[0] == 0:
|
||||
return non_yellow_things
|
||||
else:
|
||||
tobacco = self.pixel_model_ml.predict(x_yellow[..., Config.green_bands]) > 0.5
|
||||
tobacco = self.pixel_model_ml.predict_bin(x_yellow) < 0.5
|
||||
|
||||
non_yellow_things[yellow_things] = ~tobacco
|
||||
# 杂质mask中将背景赋值为0,将杂质赋值为1
|
||||
@ -436,6 +436,14 @@ class SpecDetector(Detector):
|
||||
return blk_result_array
|
||||
|
||||
|
||||
class DecisionTree(DecisionTreeClassifier):
|
||||
def predict_bin(self, feature):
|
||||
res = self.predict(feature)
|
||||
res[res <= 1] = 0
|
||||
res[res > 1] = 1
|
||||
return res
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
data_dir = "data/dataset"
|
||||
color_dict = {(0, 0, 255): "yangeng"}
|
||||
|
||||
@ -154,7 +154,7 @@ d. 阀板的脉冲分频系数,>=2即可 h. 发个da和
|
||||
class VirtualValve:
|
||||
def __init__(self):
|
||||
self.client = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 声明socket类型,同时生成链接对象
|
||||
self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的6969端口
|
||||
self.client.connect(('localhost', 13452)) # 建立一个链接,连接到本地的13452端口
|
||||
|
||||
def run(self):
|
||||
while True:
|
||||
@ -166,7 +166,6 @@ class VirtualValve:
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description='阀门测程序')
|
||||
parser.add_argument('-c', default=False, action='store_true', help='是否是开个客户端', required=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user