mirror of
https://github.com/NanjingForestryUniversity/supermachine-woodwenli.git
synced 2025-11-08 11:53:56 +00:00
第一次添加
This commit is contained in:
parent
a12c554118
commit
48fc1479d5
19
.gitignore
vendored
Normal file
19
.gitignore
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
data/correct/*.bmp
|
||||||
|
data/dark/*.bmp
|
||||||
|
data/light/*.bmp
|
||||||
|
data/middle/*.bmp
|
||||||
|
data1
|
||||||
|
data2
|
||||||
|
data3
|
||||||
|
data4
|
||||||
|
data5
|
||||||
|
.idea
|
||||||
|
__pycache__
|
||||||
|
*.pyc
|
||||||
|
test.py
|
||||||
|
*.log
|
||||||
|
.models
|
||||||
|
data
|
||||||
|
*.png
|
||||||
|
models
|
||||||
|
*.json
|
||||||
162
QT_test.py
Normal file
162
QT_test.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/9/17 15:05
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: QT_test.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import logging
|
||||||
|
import socket
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def rec_socket(recv_sock: socket.socket, cmd_type: str, ack: bool) -> bool: # 未使用
|
||||||
|
if ack:
|
||||||
|
cmd = 'A' + cmd_type
|
||||||
|
else:
|
||||||
|
cmd = 'D' + cmd_type
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
temp = recv_sock.recv(1)
|
||||||
|
except ConnectionError as e:
|
||||||
|
logging.error(f'连接出错, 错误代码:\n{e}')
|
||||||
|
return False
|
||||||
|
except TimeoutError as e:
|
||||||
|
logging.error(f'超时了,错误代码: \n{e}')
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'遇见未知错误,错误代码: \n{e}')
|
||||||
|
return False
|
||||||
|
if temp == b'\xaa':
|
||||||
|
break
|
||||||
|
|
||||||
|
# 获取报文长度
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < 4:
|
||||||
|
try:
|
||||||
|
temp += recv_sock.recv(1)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文长度失败, 错误代码: \n{e}')
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
data_len = int.from_bytes(temp, byteorder='big')
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'转换失败,错误代码 \n{e}, \n报文内容\n{temp}')
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 读取报文内容
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < data_len:
|
||||||
|
try:
|
||||||
|
temp += recv_sock.recv(data_len)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文内容失败, 错误代码: \n{e},\n报文内容\n{temp}')
|
||||||
|
return False
|
||||||
|
data = temp
|
||||||
|
if cmd.strip().upper() != data[:4].decode('ascii').strip().upper():
|
||||||
|
logging.error(f'客户端接收指令错误,\n指令内容\n{data}')
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
if cmd == 'DIM':
|
||||||
|
print(data)
|
||||||
|
|
||||||
|
# 进行数据校验
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < 3:
|
||||||
|
try:
|
||||||
|
temp += recv_sock.recv(1)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文校验失败, 错误代码: \n{e}')
|
||||||
|
return False
|
||||||
|
if temp == b'\xff\xff\xbb':
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
logging.error(f"接收了一个完美的只错了校验位的报文,\n data: {data}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def main(): # 在一个端口上接收文件,在另一个端口上接收控制命令
|
||||||
|
socket_receive = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 创建一个socket对象,AF_INET是地址簇,SOCK_STREAM是socket类型,表示TCP连接
|
||||||
|
socket_receive.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
socket_receive.bind(('127.0.0.1', 21123)) # 127.0.0.1是本机的回环地址,意味着socket仅接受从同一台机器上发起的连接请求。21123是端口号,它是计算机上用于区分不同服务的数字标签。
|
||||||
|
socket_receive.listen(5) # 开始监听传入的连接,5是指在拒绝连接之前,操作系统可以挂起的最大连接数量
|
||||||
|
socket_send = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
socket_send.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||||
|
socket_send.bind(('127.0.0.1', 21122))
|
||||||
|
socket_send.listen(5)
|
||||||
|
print('等待连接')
|
||||||
|
socket_send_1, receive_addr_1 = socket_send.accept()
|
||||||
|
print("连接成功:", receive_addr_1)
|
||||||
|
# socket_send_2 = socket_send_1
|
||||||
|
socket_send_2, receive_addr_2 = socket_receive.accept()
|
||||||
|
print("连接成功:", receive_addr_2)
|
||||||
|
while True:
|
||||||
|
cmd = input().strip().upper()
|
||||||
|
if cmd == 'IM':
|
||||||
|
# img = cv2.imread(r"/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919/dark/rgb60.png")
|
||||||
|
img = cv2.imread(r"D:\Projects\PycharmProjects\xiangmu_wenli\data\wenli\Huawen\huaweisecha (132).png") # 读取图片,返回的img对象是一个NumPy数组,包含图像的像素数据。
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 将BGR格式的图像转换为RGB格式
|
||||||
|
img = np.asarray(img, dtype=np.uint8) # 通过np.asarray()确保图像数据是NumPy数组格式,dtype=np.uint8表示使用8位无符号整数格式存储每个颜色通道,这是图像处理中常用的数据类型。
|
||||||
|
height = img.shape[0] # 获取图像的高度
|
||||||
|
width = img.shape[1] # 获取图像的宽度
|
||||||
|
img_bytes = img.tobytes() # 将图像数据转换为字节流,以便通过网络传输
|
||||||
|
length = len(img_bytes) + 8 # 计算报文长度,包括命令部分、宽度、高度和图像数据,以及结束符。 + 8:这个加法操作包括额外的协议或消息格式所需的字节长度 4 字节用于表示命令类型(例如 'IM')。在某些实现中可能已经固定包含在消息的开始部分。2 字节用于图像的宽度。2 字节用于图像的高度。
|
||||||
|
length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序
|
||||||
|
height = height.to_bytes(2, byteorder='big') # 将图像高度转换为2字节的大端字节序,这样可以确保图像的宽度和高度在网络传输中的顺序是正确的。
|
||||||
|
width = width.to_bytes(2, byteorder='big') # 将图像宽度转换为2字节的大端字节序
|
||||||
|
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + height + width + img_bytes + b'\xff\xff\xbb' # 消息以'\xaa'开始,包含消息长度、命令代码、图像宽度和高度、图像数据本身,以及结束符'\xff\xff\xbb'来标记消息结束。
|
||||||
|
socket_send_1.send(send_message)
|
||||||
|
print('发送成功')
|
||||||
|
result = socket_send_2.recv(1)
|
||||||
|
print(result)
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
|
||||||
|
# print('接收指令成功')
|
||||||
|
# else:
|
||||||
|
# print('接收指令失败')
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
|
||||||
|
# print('指令执行完毕')
|
||||||
|
# else:
|
||||||
|
# print('指令执行失败')
|
||||||
|
elif cmd == 'TR':
|
||||||
|
# model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919"
|
||||||
|
model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\data\xiangmu_photos_wenli" # 数据路径
|
||||||
|
model = model.encode('ascii') # 将字符串转换为字节流
|
||||||
|
length = len(model) + 4 # 计算报文长度 + 4:这个加法操作通常包括额外的协议或消息格式所需的字节长度,特别是:4 字节用于存储整个消息长度的数值本身,表示消息的起始部分。
|
||||||
|
length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序
|
||||||
|
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb'
|
||||||
|
socket_send_1.send(send_message)
|
||||||
|
print('发送成功')
|
||||||
|
result = socket_send_2.recv(1)
|
||||||
|
print(result)
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
|
||||||
|
# print('接收指令成功')
|
||||||
|
# else:
|
||||||
|
# print('接收指令失败')
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
|
||||||
|
# print('指令执行完毕')
|
||||||
|
# else:
|
||||||
|
# print('指令执行失败')
|
||||||
|
elif cmd == 'MD':
|
||||||
|
# model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/models/model_2020-11-08_20-49.p"
|
||||||
|
# model = r"C:\Users\FEIJINTI\OneDrive\PycharmProjects\wood_color\models\model_2023-03-27_16-32.p"
|
||||||
|
model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\models\model_2024-05-07_13-58.p" # 模型路径
|
||||||
|
model = model.encode('ascii') # 将字符串转换为字节流
|
||||||
|
length = len(model) + 4
|
||||||
|
length = length.to_bytes(4, byteorder='big')
|
||||||
|
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb'
|
||||||
|
socket_send_1.send(send_message)
|
||||||
|
print('发送成功')
|
||||||
|
result = socket_send_2.recv(1)
|
||||||
|
print(result)
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
|
||||||
|
# print('接收指令成功')
|
||||||
|
# else:
|
||||||
|
# print('接收指令失败')
|
||||||
|
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
|
||||||
|
# print('指令执行完毕')
|
||||||
|
# else:
|
||||||
|
# print('指令执行失败')
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
|
|
||||||
329
classifer.py
Normal file
329
classifer.py
Normal file
@ -0,0 +1,329 @@
|
|||||||
|
# -*- coding: UTF-8 -*-
|
||||||
|
# @Time : 2024/4/15 19:03
|
||||||
|
# @Auther : DUANMU
|
||||||
|
# @File : wenli_classifer.py
|
||||||
|
# @Software : PyCharm
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.metrics import accuracy_score, confusion_matrix
|
||||||
|
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, ExtraTreesClassifier
|
||||||
|
from skimage.filters import sobel
|
||||||
|
from skimage import filters
|
||||||
|
from skimage.morphology import opening, square
|
||||||
|
from skimage.color import rgb2gray
|
||||||
|
from skimage.filters import threshold_otsu
|
||||||
|
from skimage.feature import local_binary_pattern, graycomatrix, graycoprops
|
||||||
|
from scipy.stats import kurtosis, skew
|
||||||
|
from skimage import transform, util
|
||||||
|
from skimage.feature import hog
|
||||||
|
from skimage import color, exposure
|
||||||
|
import lightgbm as lgb
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
import utils
|
||||||
|
|
||||||
|
class TextureClass:
|
||||||
|
def __init__(self, load_from=None, w=4096, h=1200, debug_mode=False):
|
||||||
|
"""
|
||||||
|
初始化纹理分类器
|
||||||
|
:param load_from:
|
||||||
|
:param w:
|
||||||
|
:param h:
|
||||||
|
:param debug_mode:
|
||||||
|
"""
|
||||||
|
if load_from is None:
|
||||||
|
if w is None or h is None:
|
||||||
|
print("It will damage your performance if you don't set w and h")
|
||||||
|
raise ValueError("w or h is None")
|
||||||
|
self.w, self.h= w, h
|
||||||
|
# self.model = RandomForestClassifier(n_estimators=100)
|
||||||
|
# self.model = lgb.LGBMClassifier(
|
||||||
|
# n_estimators=100,
|
||||||
|
# max_depth=5,
|
||||||
|
# num_leaves=16, # 减小叶节点数
|
||||||
|
# min_data_in_leaf=40, # 增加每个叶子的最小数据量
|
||||||
|
# lambda_l1=0.1, # 增强L1正则化
|
||||||
|
# lambda_l2=0.1, # 增强L2正则化
|
||||||
|
# boosting_type='gbdt', # 使用传统的梯度提升决策树
|
||||||
|
# objective='binary', # 二分类问题
|
||||||
|
# learning_rate=0.05, # 可以尝试降低学习速率
|
||||||
|
# subsample=0.8, # 子样本比率
|
||||||
|
# colsample_bytree=0.8, # 基于树的列采样
|
||||||
|
# metric='binary_logloss', # 评价指标
|
||||||
|
# verbose=-1 # 不输出任何东西,包括警告
|
||||||
|
# )
|
||||||
|
self.model = lgb.LGBMClassifier(n_estimators=100 ,verbose=-1)
|
||||||
|
# self.model = RandomForestClassifier(n_estimators=150, random_state=42, max_depth=10,min_samples_leaf=2,
|
||||||
|
# min_samples_split=4,max_features='sqrt')
|
||||||
|
# self.model = AdaBoostClassifier()
|
||||||
|
# self.model = GradientBoostingClassifier()
|
||||||
|
# self.model = ExtraTreesClassifier()
|
||||||
|
# self.model = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
|
||||||
|
else:
|
||||||
|
self.load(load_from)
|
||||||
|
self.debug_mode = debug_mode
|
||||||
|
self.log = utils.Logger(is_to_file=debug_mode)
|
||||||
|
self.image_num = 0
|
||||||
|
|
||||||
|
def extract_features(self, img):
|
||||||
|
# """使用局部二值模式(LBP)提取图像的纹理特征"""
|
||||||
|
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||||
|
# if gray.dtype != np.uint8:
|
||||||
|
# gray = (gray * 255).astype(np.uint8)
|
||||||
|
#
|
||||||
|
# # 设置LBP参数
|
||||||
|
# radius = 1 # LBP算法中圆的半径
|
||||||
|
# n_points = 8 * radius # 统一模式用的点数
|
||||||
|
# lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
|
||||||
|
#
|
||||||
|
# # 计算LBP的直方图
|
||||||
|
# n_bins = int(lbp.max() + 1) # 加1因为直方图的区间是开区间
|
||||||
|
# hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)
|
||||||
|
# festures = hist
|
||||||
|
#
|
||||||
|
# return festures
|
||||||
|
|
||||||
|
"""使用局部二值模式(LBP)提取图像的纹理特征,优化版"""
|
||||||
|
# 图像下采样
|
||||||
|
img_resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))
|
||||||
|
gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
|
||||||
|
# gray = self.preprocess_image(img_resized)
|
||||||
|
if gray.dtype != np.uint8:
|
||||||
|
gray = (gray * 255).astype(np.uint8)
|
||||||
|
|
||||||
|
# 设置LBP参数
|
||||||
|
radius = 1
|
||||||
|
n_points = 8 * radius
|
||||||
|
lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
|
||||||
|
|
||||||
|
# 计算LBP的直方图
|
||||||
|
n_bins = int(lbp.max() + 1)
|
||||||
|
hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)
|
||||||
|
features = hist
|
||||||
|
return features
|
||||||
|
|
||||||
|
def augment_image(self, img):
|
||||||
|
"""对输入图像应用随机数据增强"""
|
||||||
|
# 随机旋转 ±30 度
|
||||||
|
angle = np.random.uniform(-30, 30)
|
||||||
|
M = cv2.getRotationMatrix2D((img.shape[1] / 2, img.shape[0] / 2), angle, 1)
|
||||||
|
img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
|
||||||
|
|
||||||
|
# 随机水平翻转
|
||||||
|
if np.random.rand() > 0.5:
|
||||||
|
img = cv2.flip(img, 1)
|
||||||
|
|
||||||
|
# 随机调整亮度 ±50
|
||||||
|
value = int(np.random.uniform(-50, 50))
|
||||||
|
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
|
||||||
|
h, s, v = cv2.split(hsv)
|
||||||
|
v = cv2.add(v, value)
|
||||||
|
v[v > 255] = 255
|
||||||
|
v[v < 0] = 0
|
||||||
|
final_hsv = cv2.merge((h, s, v))
|
||||||
|
img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
|
||||||
|
|
||||||
|
return img
|
||||||
|
|
||||||
|
def get_image_data(self, img_dir="./data/wenli/Huawen",augment=False):
|
||||||
|
"""
|
||||||
|
:param img_dir: 图像文件的路径
|
||||||
|
:return: 图像数据
|
||||||
|
"""
|
||||||
|
img_data = []
|
||||||
|
img_name = []
|
||||||
|
utils.mkdir_if_not_exist(img_dir)
|
||||||
|
files = os.listdir(img_dir)
|
||||||
|
if len(files) == 0:
|
||||||
|
return False
|
||||||
|
for file in files:
|
||||||
|
path = os.path.join(img_dir, file)
|
||||||
|
if self.debug_mode:
|
||||||
|
self.log.log(path)
|
||||||
|
train_img = cv2.imread(path)
|
||||||
|
if augment:
|
||||||
|
train_img = self.augment_image(train_img) # 应用数据增强
|
||||||
|
data = self.extract_features(train_img)
|
||||||
|
img_data.append(data)
|
||||||
|
img_name.append(file)
|
||||||
|
img_data = np.array(img_data)
|
||||||
|
|
||||||
|
return img_data, img_name
|
||||||
|
|
||||||
|
def get_train_data(self, data_dir=None, plot_2d=False, save_data=False, augment=False):
|
||||||
|
"""
|
||||||
|
获取图像数据
|
||||||
|
:return: x_data, y_data
|
||||||
|
"""
|
||||||
|
print("开始加载训练数据...")
|
||||||
|
data_dir = os.path.join(ROOT_DIR, "data", "wenli") if data_dir is None else data_dir
|
||||||
|
hw_data, hw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Huawen"), augment=augment)
|
||||||
|
zw_data, zw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Zhiwen"), augment=augment)
|
||||||
|
if (hw_data is False) or (zw_data is False) :
|
||||||
|
return False
|
||||||
|
x_data = np.vstack((hw_data, zw_data))
|
||||||
|
|
||||||
|
hw_label = np.zeros(len(hw_data)).T # 为"Huawen"类图像赋值标签0
|
||||||
|
zw_label = np.ones(len(zw_data)).T # 为"Zhiwen"类图像赋值标签1
|
||||||
|
y_data = np.hstack((hw_label, zw_label))
|
||||||
|
|
||||||
|
img_name = hw_name + zw_name
|
||||||
|
|
||||||
|
# if plot_2d:
|
||||||
|
# plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data, cmap='viridis')
|
||||||
|
# plt.show()
|
||||||
|
if save_data:
|
||||||
|
with open(os.path.join("data", "data.p"), "rb") as f:
|
||||||
|
pass
|
||||||
|
if (hw_data is False) or (zw_data is False):
|
||||||
|
print("未找到有效的训练数据")
|
||||||
|
return False
|
||||||
|
print("训练数据加载完成")
|
||||||
|
return x_data, y_data, img_name
|
||||||
|
|
||||||
|
def fit_pictures(self, data_path=ROOT_DIR, file_name=None, augment=False):
|
||||||
|
"""
|
||||||
|
根据给出的data_path 进行 fit.如果没有给出data目录,那么将会使用当前文件夹
|
||||||
|
:param data_path:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
print("开始训练模型...")
|
||||||
|
# 训练数据文件位置
|
||||||
|
result = self.get_train_data(data_path, plot_2d=True, augment=augment)
|
||||||
|
if result is False:
|
||||||
|
print("训练数据加载失败,中止训练")
|
||||||
|
return 0
|
||||||
|
x, y, name = result
|
||||||
|
print("训练数据加载成功,开始模型训练")
|
||||||
|
score = self.fit(x, y)
|
||||||
|
print('model score', score)
|
||||||
|
model_name = self.save(file_name)
|
||||||
|
return model_name
|
||||||
|
|
||||||
|
def fit(self, X, y):
|
||||||
|
"""训练模型"""
|
||||||
|
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
|
||||||
|
self.model.fit(X_train, y_train)
|
||||||
|
y_pred = self.model.predict(X_test)
|
||||||
|
|
||||||
|
print(confusion_matrix(y_test, y_pred))
|
||||||
|
|
||||||
|
pre_score = accuracy_score(y_test, y_pred) # 计算测试集的准确率
|
||||||
|
self.log.log("Test accuracy is:" + str(pre_score * 100) + "%.")
|
||||||
|
|
||||||
|
y_pred = self.model.predict(X_train)
|
||||||
|
pre_score = accuracy_score(y_train, y_pred)
|
||||||
|
self.log.log("Train accuracy is:" + str(pre_score * 100) + "%.")
|
||||||
|
|
||||||
|
y_pred = self.model.predict(X)
|
||||||
|
pre_score = accuracy_score(y, y_pred)
|
||||||
|
self.log.log("Total accuracy is:" + str(pre_score * 100) + "%.")
|
||||||
|
|
||||||
|
return int(pre_score * 100)
|
||||||
|
|
||||||
|
def predict(self, img):
|
||||||
|
"""预测图像纹理"""
|
||||||
|
if self.debug_mode:
|
||||||
|
cv2.imwrite(str(self.image_num) + ".bmp", img)
|
||||||
|
self.image_num += 1
|
||||||
|
features = self.extract_features(img) # 提取图像特征
|
||||||
|
features = np.array(features) # 将列表转换为 Numpy 数组
|
||||||
|
features = features.reshape(1, -1) # 使用 reshape(1, -1) 将特征数组转换成适合模型预测的形状。这里的 1 表示样本数为一(单个图像),-1 表示自动计算特征数量。
|
||||||
|
pred_wenli = self.model.predict(features) # 使用模型预测图像的纹理
|
||||||
|
if self.debug_mode:
|
||||||
|
self.log.log(features)
|
||||||
|
return int(pred_wenli[0]) # 从模型预测结果中提取第一个元素(假设预测结果是一个数组),并将其转换为整数。这通常是分类任务的类别标签。
|
||||||
|
|
||||||
|
|
||||||
|
def save(self, file_name):
|
||||||
|
"""保存模型到文件"""
|
||||||
|
if file_name is None:
|
||||||
|
file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p"
|
||||||
|
file_name = os.path.join(ROOT_DIR, "models", file_name)
|
||||||
|
model_dic = { "model": self.model,"w": self.w, "h": self.h}
|
||||||
|
with open(file_name, "wb") as f:
|
||||||
|
pickle.dump(model_dic, f)
|
||||||
|
self.log.log("Save file to '" + str(file_name) + "'")
|
||||||
|
return file_name
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, path=None):
|
||||||
|
if path is None:
|
||||||
|
path = os.path.join(ROOT_DIR, "models")
|
||||||
|
utils.mkdir_if_not_exist(path)
|
||||||
|
model_files = os.listdir(path)
|
||||||
|
if len(model_files) == 0:
|
||||||
|
self.log.log("No model found!")
|
||||||
|
return 1
|
||||||
|
self.log.log("./ Models Found:")
|
||||||
|
_ = [self.log.log("├--" + str(model_file)) for model_file in model_files]
|
||||||
|
file_times = [model_file[6:-2] for model_file in model_files]
|
||||||
|
latest_model = model_files[int(np.argmax(file_times))]
|
||||||
|
self.log.log("└--Using the latest model: " + str(latest_model))
|
||||||
|
path = os.path.join(ROOT_DIR, "models", str(latest_model))
|
||||||
|
if not os.path.isabs(path):
|
||||||
|
logging.warning('给的是相对路径')
|
||||||
|
return -1
|
||||||
|
if not os.path.exists(path):
|
||||||
|
logging.warning('文件不存在')
|
||||||
|
return -1
|
||||||
|
with open(path, "rb") as f:
|
||||||
|
model_dic = pickle.load(f)
|
||||||
|
|
||||||
|
self.model = model_dic["model"]
|
||||||
|
self.w = model_dic["w"]
|
||||||
|
self.h = model_dic["h"]
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from config import Config
|
||||||
|
|
||||||
|
# 加载配置设置
|
||||||
|
settings = Config()
|
||||||
|
|
||||||
|
# 初始化 TextureClass 实例
|
||||||
|
texture = TextureClass(w=4096, h=1200, debug_mode=False)
|
||||||
|
print("初始化纹理分类器完成")
|
||||||
|
|
||||||
|
# 获取数据路径
|
||||||
|
data_path = settings.data_path
|
||||||
|
# texture.correct() # 如果有色彩校正的步骤可以添加
|
||||||
|
# texture.load() # 如果需要加载先前的模型可以调用
|
||||||
|
|
||||||
|
# 训练模型并保存
|
||||||
|
model_path = texture.fit_pictures(data_path=data_path)
|
||||||
|
print(f"模型保存在 {model_path}")
|
||||||
|
|
||||||
|
# # 加载K-means数据并进行数据调整
|
||||||
|
# x_data, y_data, labels, img_names = texture.get_kmeans_data(data_path, plot_2d=True)
|
||||||
|
# send_data = texture.data_adjustments(x_data, y_data, labels, img_names)
|
||||||
|
# print(f"调整后的数据已发送/保存")
|
||||||
|
|
||||||
|
# 测试单张图片的预测性能
|
||||||
|
pic = cv2.imread(r"data/wenli/Zhiwen/rgb98.png")
|
||||||
|
start_time = time.time()
|
||||||
|
texture_type = texture.predict(pic)
|
||||||
|
end_time = time.time()
|
||||||
|
print("单次预测耗时:", (end_time - start_time) * 1000, "毫秒")
|
||||||
|
print("预测的纹理类型:", texture_type)
|
||||||
|
|
||||||
|
# 如果有批量处理或性能测试需求
|
||||||
|
# 总时间 = 0
|
||||||
|
# for i in range(100):
|
||||||
|
# start_time = time.time()
|
||||||
|
# _ = texture.predict(pic)
|
||||||
|
# end_time = time.time()
|
||||||
|
# total_time += (end_time - start_time)
|
||||||
|
# print("平均预测时间:", (total_time / 100) * 1000, "毫秒")
|
||||||
60
config.py
Normal file
60
config.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/10/17 11:07
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: config.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import WindowsPath
|
||||||
|
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
class Config(object):
|
||||||
|
model_path = ROOT_DIR / 'config.json' # “/”运算符被重载,用于连接两个路径,表示 ROOT_DIR 目录下的 config.json 文件
|
||||||
|
|
||||||
|
def __init__(self): # 初始化方法
|
||||||
|
self._param_dict = {} # 创建一个空字典
|
||||||
|
if os.path.exists(Config.model_path):
|
||||||
|
self._read() # 如果 config.json 文件存在,调用 _read 方法
|
||||||
|
else:
|
||||||
|
self.model_path = str(ROOT_DIR / 'models/model_2024-04-18_10-16.p') # 模型路径
|
||||||
|
self.data_path = str(ROOT_DIR / 'data/xiangmu_photos_wenli') # 数据路径
|
||||||
|
self.database_addr = str("mysql+pymysql://root:@localhost:3306/orm_test") # 测试用数据库地址
|
||||||
|
self._param_dict['model_path'] = self.model_path # 将模型路径写入 _param_dict 属性中
|
||||||
|
self._param_dict['data_path'] = self.data_path
|
||||||
|
self._param_dict['database_addr'] = self.database_addr
|
||||||
|
|
||||||
|
def __setitem__(self, key, value): # 重载 __setitem__ 方法
|
||||||
|
if key in self._param_dict: # 如果 key 在 _param_dict 中 key是键,value是值 用于设置值
|
||||||
|
self._param_dict[key] = value # 将键值对写入 _param_dict 属性中
|
||||||
|
self._write() # 将 _param_dict 属性写入 config.json 文件
|
||||||
|
|
||||||
|
def __getitem__(self, item): # 重载 __getitem__ 方法
|
||||||
|
if item in self._param_dict: # 如果 item 在 _param_dict 中 item是键 value是值 用于获取值
|
||||||
|
return self._param_dict[item] # 返回 _param_dict 中 item 键的值
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
self.__dict__[key] = value # 直接在对象的 __dict__ 属性(这是一个存储对象所有属性的字典)中设置键(属性名)和值
|
||||||
|
if '_param_dict' in self.__dict__ and key != '_param_dict': # 如果 _param_dict 属性存在,且不是 _param_dict 属性本身
|
||||||
|
if isinstance(value, WindowsPath): # 如果 value 是 WindowsPath 对象,将其转换为字符串
|
||||||
|
value = str(value) # WindowsPath 对象不能被 json 序列化,需要转换为字符串
|
||||||
|
self.__dict__['_param_dict'][key] = value # 将键值对写入 _param_dict 属性中
|
||||||
|
self._write() # 将 _param_dict 属性写入 config.json 文件
|
||||||
|
|
||||||
|
def _read(self): # 读取 config.json 文件
|
||||||
|
with open(Config.model_path, 'r') as f: # 打开 config.json 文件
|
||||||
|
self._param_dict = json.load(f) # 读取文件内容,将其转换为字典
|
||||||
|
self.data_path = self._param_dict['data_path'] # 从字典中读取 data_path 键的值
|
||||||
|
self.model_path = self._param_dict['model_path'] # 从字典中读取 model_path 键的值
|
||||||
|
self.database_addr = self._param_dict['database_addr']
|
||||||
|
|
||||||
|
def _write(self): # 将 _param_dict 属性写入 config.json 文件
|
||||||
|
with open(Config.model_path, 'w') as f: # 打开 config.json 文件
|
||||||
|
json.dump(self._param_dict, f) # 将 _param_dict 写入文件
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config()
|
||||||
|
print(config.model_path)
|
||||||
|
print(config.data_path)
|
||||||
65
database.py
Normal file
65
database.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/10/20 13:54
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: database.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import datetime
|
||||||
|
import time
|
||||||
|
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlalchemy.dialects.mysql import INTEGER, VARCHAR
|
||||||
|
from sqlalchemy import Column, TIMESTAMP, DATETIME
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class Wood(Base):
|
||||||
|
__tablename__ = 'color'
|
||||||
|
|
||||||
|
id = Column(INTEGER, primary_key=True)
|
||||||
|
color = Column(VARCHAR(256), nullable=False)
|
||||||
|
time = Column(DATETIME, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
def __init__(self, color):
|
||||||
|
self.time = datetime.datetime.now()
|
||||||
|
self.color = color
|
||||||
|
|
||||||
|
|
||||||
|
class Database(object):
|
||||||
|
def __init__(self, database_addr):
|
||||||
|
self.database_addr = database_addr
|
||||||
|
|
||||||
|
def init_db(self):
|
||||||
|
engine = create_engine(self.database_addr, encoding="utf-8", echo=True)
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
print('Create table successfully!')
|
||||||
|
|
||||||
|
def add_data(self, color):
|
||||||
|
# 初始化数据库连接
|
||||||
|
engine = create_engine(self.database_addr, encoding="utf-8")
|
||||||
|
# 创建DBSession类型
|
||||||
|
DBSession = sessionmaker(bind=engine)
|
||||||
|
|
||||||
|
# 创建session对象
|
||||||
|
session = DBSession()
|
||||||
|
# 插入单条数据
|
||||||
|
# 创建新User对象
|
||||||
|
new_wood = Wood(color)
|
||||||
|
# 添加到session
|
||||||
|
session.add(new_wood)
|
||||||
|
# 提交即保存到数据库
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_addr = "mysql+pymysql://root:@localhost:3306/color"
|
||||||
|
database = Database(test_addr)
|
||||||
|
# database.init_db()
|
||||||
|
t1 = time.time()
|
||||||
|
for i in range(100):
|
||||||
|
database.add_data('middle')
|
||||||
|
t2 = time.time()
|
||||||
|
print((t2-t1)/100)
|
||||||
10
requirements.txt
Normal file
10
requirements.txt
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
lightgbm==4.3.0
|
||||||
|
matplotlib==3.5.2
|
||||||
|
numpy==1.23.1
|
||||||
|
numpy==1.25.1
|
||||||
|
opencv_contrib_python==4.7.0.72
|
||||||
|
opencv_python==4.6.0.66
|
||||||
|
scikit_learn==1.1.1
|
||||||
|
scipy==1.9.0
|
||||||
|
skimage==0.0
|
||||||
|
SQLAlchemy==2.0.19
|
||||||
13
root_dir.py
Normal file
13
root_dir.py
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Created on Nov 3 21:18:26 2020
|
||||||
|
|
||||||
|
@author: l.z.y
|
||||||
|
@e-mail: li.zhenye@qq.com
|
||||||
|
"""
|
||||||
|
import pathlib
|
||||||
|
|
||||||
|
file_path = pathlib.Path(__file__) # 获得当前文件路径
|
||||||
|
ROOT_DIR = file_path.parent # 获得当前文件的父目录
|
||||||
|
|
||||||
|
# pathlib的作用是将文件路径转换为操作系统的路径格式,这样可以避免不同操作系统的路径格式不同的问题
|
||||||
110
socket_detector.py
Normal file
110
socket_detector.py
Normal file
@ -0,0 +1,110 @@
|
|||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
import root_dir
|
||||||
|
from classifer import TextureClass
|
||||||
|
import time
|
||||||
|
import os
|
||||||
|
|
||||||
|
# from database import Database
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
from utils import PreSocket, receive_sock, parse_protocol, ack_sock, done_sock, DualSock, simple_sock
|
||||||
|
import logging
|
||||||
|
from config import Config
|
||||||
|
|
||||||
|
|
||||||
|
def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: TextureClass, settings: Config) -> tuple:
|
||||||
|
"""
|
||||||
|
处理指令
|
||||||
|
|
||||||
|
:param cmd: 指令类型
|
||||||
|
:param data: 指令内容
|
||||||
|
:param connected_sock: socket
|
||||||
|
:param detector: 模型
|
||||||
|
:return: 是否处理成功
|
||||||
|
"""
|
||||||
|
result = ''
|
||||||
|
if cmd == 'IM':
|
||||||
|
data = np.clip(data, 0, 255).astype(dtype=np.uint8) # 将 data 数组中的每个元素限制在范围 [0, 255] 内,将数组的数据类型转换为无符号 8 位整数
|
||||||
|
wood_wenli = detector.predict(data)
|
||||||
|
result = {0: 'Huawen', 1: 'Zhiwen'}[wood_wenli]
|
||||||
|
response = simple_sock(connected_sock, cmd_type=cmd, result=wood_wenli)
|
||||||
|
elif cmd == 'TR':
|
||||||
|
detector = TextureClass(w=4096, h=1200, debug_mode=False)
|
||||||
|
model_name = None
|
||||||
|
if "$" in data:
|
||||||
|
data, model_name = data.split("$", 1)
|
||||||
|
model_name = model_name + ".p"
|
||||||
|
settings.data_path = data
|
||||||
|
settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path, file_name=model_name)
|
||||||
|
response = simple_sock(connected_sock, cmd_type=cmd, result=result)
|
||||||
|
elif cmd == 'MD':
|
||||||
|
settings.model_path = data
|
||||||
|
detector.load(path=settings.model_path)
|
||||||
|
response = simple_sock(connected_sock, cmd_type=cmd, result=result)
|
||||||
|
|
||||||
|
|
||||||
|
else:
|
||||||
|
logging.error(f'错误指令,指令为{cmd}')
|
||||||
|
response = False
|
||||||
|
return response, result
|
||||||
|
|
||||||
|
|
||||||
|
def main(is_debug=False):
|
||||||
|
settings = Config() # 创建一个配置类的实例
|
||||||
|
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'report.log')) # 创建一个文件处理器,将日志信息写入到指定的 report.log 文件
|
||||||
|
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别, DEBUG 级别最低,可以输出所有级别的日志信息; WARNING 级别最高,只能输出 WARNING 级别及以上的日志信息
|
||||||
|
console_handler = logging.StreamHandler(sys.stdout) # 创建一个控制台处理器,将日志信息输出到控制台
|
||||||
|
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别
|
||||||
|
logging.basicConfig(format='%(asctime)s %(filename)s[line:%(lineno)d] - %(levelname)s - %(message)s',
|
||||||
|
handlers=[file_handler, console_handler], level=logging.DEBUG)
|
||||||
|
dual_sock = DualSock(connect_ip='127.0.0.1')
|
||||||
|
|
||||||
|
# database = Database(settings.database_addr)
|
||||||
|
|
||||||
|
|
||||||
|
while not dual_sock.status:
|
||||||
|
dual_sock.reconnect()
|
||||||
|
detector = TextureClass(w=4096, h=1200, debug_mode=False)
|
||||||
|
detector.load(path=settings.model_path)
|
||||||
|
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
|
||||||
|
while True:
|
||||||
|
pack, next_pack = receive_sock(dual_sock)
|
||||||
|
if pack == b"":
|
||||||
|
time.sleep(5)
|
||||||
|
dual_sock.reconnect()
|
||||||
|
continue
|
||||||
|
|
||||||
|
cmd, data = parse_protocol(pack)
|
||||||
|
# ack_sock(received_sock, cmd_type=cmd)
|
||||||
|
response, result = process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector, settings=settings)
|
||||||
|
|
||||||
|
# if result != "":
|
||||||
|
# database.add_data(result)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
# 2个端口
|
||||||
|
# 接受端口21122
|
||||||
|
# 发送端口21123
|
||||||
|
# 接收到图片 n_rows * n_bands * n_cols, float32
|
||||||
|
# 发送图片 n_rows * n_cols, uint8
|
||||||
|
main(is_debug=False)
|
||||||
|
# test(r"D:\build-tobacco-Desktop_Qt_5_9_0_MSVC2015_64bit-Release\calibrated15.raw")
|
||||||
|
# main()
|
||||||
|
# debug_main()
|
||||||
|
# test_run(all_data_dir=r'D:\数据')
|
||||||
|
# with open(r'D:\数据\虫子\valid2.raw', 'rb') as f:
|
||||||
|
# data = np.frombuffer(f.read(), dtype=np.float32).reshape(600, 29, 1024).transpose(0, 2, 1)
|
||||||
|
# plt.matshow(data[:, :, 10])
|
||||||
|
# plt.show()
|
||||||
|
# detector = SpecDetector('model_spec/model_29.p')
|
||||||
|
# result = detector.predict(data)
|
||||||
|
#
|
||||||
|
# plt.matshow(result)
|
||||||
|
# plt.show()
|
||||||
|
# result = result.reshape((600, 1024))
|
||||||
347
utils.py
Normal file
347
utils.py
Normal file
@ -0,0 +1,347 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Created on Nov 3 21:18:26 2020
|
||||||
|
|
||||||
|
@author: l.z.y
|
||||||
|
@e-mail: li.zhenye@qq.com
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import time
|
||||||
|
import socket
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
|
def mkdir_if_not_exist(dir_name, is_delete=False):
|
||||||
|
"""
|
||||||
|
创建文件夹
|
||||||
|
:param dir_name: 文件夹
|
||||||
|
:param is_delete: 是否删除
|
||||||
|
:return: 是否成功
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
if is_delete:
|
||||||
|
if os.path.exists(dir_name): # 如果 is_delete 为 True,函数会检查 dir_name 是否已经存在。
|
||||||
|
shutil.rmtree(dir_name) # 如果存在,它会使用 shutil.rmtree 删除整个目录结构,并输出一个信息消息。
|
||||||
|
print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
|
||||||
|
|
||||||
|
if not os.path.exists(dir_name): # 使用 os.path.exists 来检查 dir_name 是否存在。
|
||||||
|
os.makedirs(dir_name) # 如果不存在,它会使用 os.makedirs 创建整个目录树。
|
||||||
|
print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print('[Exception] %s' % e)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def create_file(file_name):
|
||||||
|
"""
|
||||||
|
创建文件
|
||||||
|
:param file_name: 文件名
|
||||||
|
:return: None
|
||||||
|
"""
|
||||||
|
if os.path.exists(file_name):
|
||||||
|
print("文件存在:%s" % file_name)
|
||||||
|
return False
|
||||||
|
# os.remove(file_name) # 删除已有文件
|
||||||
|
if not os.path.exists(file_name):
|
||||||
|
print("文件不存在,创建文件:%s" % file_name)
|
||||||
|
open(file_name, 'a').close()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class Logger(object): # 日志类
|
||||||
|
def __init__(self, is_to_file=False, path=None):
|
||||||
|
self.is_to_file = is_to_file # 布尔值,用于决定日志是保存到文件(True)还是输出到控制台(False)
|
||||||
|
if path is None: # 如果没有指定日志文件的路径,则默认为当前目录下的 wood.log
|
||||||
|
path = "wood.log"
|
||||||
|
self.path = path # 指定保存日志的文件路径。
|
||||||
|
create_file(path) # 确保日志文件存在,如果不存在则创建它
|
||||||
|
|
||||||
|
def log(self, content):
|
||||||
|
if self.is_to_file: # 如果指定了日志文件的路径,则将日志信息写入到日志文件中
|
||||||
|
with open(self.path, "a") as f: # 以追加的方式打开日志文件,'a'(Append 模式): 在这种模式下,如果文件已经存在,新的数据会被写入到文件的末尾。如果文件不存在,它会被创建。
|
||||||
|
print(time.strftime("[%Y-%m-%d_%H-%M-%S]:"), file=f) # 将当前时间戳和实际日志内容写入到文件,显示为[年-月-日_时-分-秒]: 的格式。
|
||||||
|
print(content, file=f)
|
||||||
|
else:
|
||||||
|
print(content) # 如果没有指定日志文件的路径,则将日志信息输出到控制台
|
||||||
|
|
||||||
|
|
||||||
|
def try_connect(connect_ip: str, port_number: int, is_repeat: bool = False, max_reconnect_times: int = 50) -> (
|
||||||
|
bool, socket.socket):
|
||||||
|
"""
|
||||||
|
尝试连接.
|
||||||
|
|
||||||
|
:param is_repeat: 是否是重新连接
|
||||||
|
:param max_reconnect_times:最大重连次数
|
||||||
|
:return: (连接状态True为成功, Socket / None)
|
||||||
|
"""
|
||||||
|
reconnect_time = 0
|
||||||
|
while reconnect_time < max_reconnect_times:
|
||||||
|
logging.warning(f'尝试{"重新" if is_repeat else ""}发起第{reconnect_time + 1}次连接...')
|
||||||
|
try:
|
||||||
|
connected_sock = PreSocket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
connected_sock.connect((connect_ip, port_number))
|
||||||
|
except Exception as e:
|
||||||
|
reconnect_time += 1
|
||||||
|
logging.error(f'第{reconnect_time}次连接失败... 5秒后重新连接...\n {e}')
|
||||||
|
time.sleep(5)
|
||||||
|
continue
|
||||||
|
logging.warning(f'{"重新" if is_repeat else ""}连接成功')
|
||||||
|
return True, connected_sock
|
||||||
|
return False, None
|
||||||
|
|
||||||
|
|
||||||
|
class PreSocket(socket.socket):
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.pre_pack = b''
|
||||||
|
self.settimeout(5)
|
||||||
|
|
||||||
|
def receive(self, *args, **kwargs):
|
||||||
|
if self.pre_pack == b'':
|
||||||
|
return self.recv(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
data_len = args[0]
|
||||||
|
required, left = self.pre_pack[:data_len], self.pre_pack[data_len:]
|
||||||
|
self.pre_pack = left
|
||||||
|
return required
|
||||||
|
|
||||||
|
def set_prepack(self, pre_pack: bytes):
|
||||||
|
temp = self.pre_pack
|
||||||
|
self.pre_pack = temp + pre_pack
|
||||||
|
|
||||||
|
|
||||||
|
class DualSock(PreSocket):
|
||||||
|
def __init__(self, connect_ip='127.0.0.1', recv_port: int = 21122, send_port: int = 21123):
|
||||||
|
super().__init__()
|
||||||
|
received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port) # 这两行代码分别设置接收和发送的sockets。
|
||||||
|
send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
|
||||||
|
self.status = received_status and send_status
|
||||||
|
|
||||||
|
def send(self, *args, **kwargs) -> int:
|
||||||
|
return self.send_sock.send(*args, **kwargs)
|
||||||
|
|
||||||
|
def receive(self, *args, **kwargs) -> bytes:
|
||||||
|
return self.received_sock.receive(*args, **kwargs)
|
||||||
|
|
||||||
|
def set_prepack(self, pre_pack: bytes):
|
||||||
|
self.received_sock.set_prepack(pre_pack)
|
||||||
|
|
||||||
|
def reconnect(self, connect_ip='127.0.0.1', recv_port:int = 21122, send_port: int = 21123):
|
||||||
|
received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port)
|
||||||
|
send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
|
||||||
|
return received_status and send_status
|
||||||
|
|
||||||
|
|
||||||
|
def receive_sock(recv_sock: PreSocket, pre_pack: bytes = b'', time_out: float = -1.0, time_out_single=5e20) -> (
|
||||||
|
bytes, bytes):
|
||||||
|
"""
|
||||||
|
从指定的socket中读取数据.
|
||||||
|
|
||||||
|
:param recv_sock: 指定sock
|
||||||
|
:param pre_pack: 上一包的粘包内容
|
||||||
|
:param time_out: 每隔time_out至少要发来一次指令,否则认为出现问题进行重连,小于0则为一直等
|
||||||
|
:param time_out_single: 单次指令超时时间,单位是秒
|
||||||
|
:return: data, next_pack
|
||||||
|
"""
|
||||||
|
recv_sock.set_prepack(pre_pack)
|
||||||
|
# 开头校验
|
||||||
|
time_start_recv = time.time()
|
||||||
|
while True:
|
||||||
|
if time_out > 0:
|
||||||
|
if (time.time() - time_start_recv) > time_out:
|
||||||
|
logging.error(f'指令接收超时')
|
||||||
|
return b'', b''
|
||||||
|
try:
|
||||||
|
temp = recv_sock.receive(1)
|
||||||
|
except ConnectionError as e:
|
||||||
|
logging.error(f'连接出错, 错误代码:\n{e}')
|
||||||
|
return b'', b''
|
||||||
|
except TimeoutError as e:
|
||||||
|
# logging.error(f'超时了,错误代码: \n{e}')
|
||||||
|
logging.info('运行中,等待指令..')
|
||||||
|
continue
|
||||||
|
except socket.timeout as e:
|
||||||
|
logging.info('运行中,等待指令..')
|
||||||
|
continue
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'遇见未知错误,错误代码: \n{e}')
|
||||||
|
return b'', b''
|
||||||
|
if temp == b'\xaa':
|
||||||
|
break
|
||||||
|
|
||||||
|
# 接收开头后,开始进行时间记录
|
||||||
|
time_start_recv = time.time()
|
||||||
|
|
||||||
|
# 获取报文长度
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < 4:
|
||||||
|
if (time.time() - time_start_recv) > time_out_single:
|
||||||
|
logging.error(f'单次指令接收超时')
|
||||||
|
return b'', b''
|
||||||
|
try:
|
||||||
|
temp += recv_sock.receive(1)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文的长度不正确, 错误代码: \n{e}')
|
||||||
|
return b'', b''
|
||||||
|
try:
|
||||||
|
data_len = int.from_bytes(temp, byteorder='big')
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'转换失败,错误代码 \n{e}')
|
||||||
|
return b'', b''
|
||||||
|
|
||||||
|
# 读取报文内容
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < data_len:
|
||||||
|
if (time.time() - time_start_recv) > time_out_single:
|
||||||
|
logging.error(f'单次指令接收超时')
|
||||||
|
return b'', b''
|
||||||
|
try:
|
||||||
|
temp += recv_sock.receive(data_len)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文内容失败, 错误代码: \n{e}')
|
||||||
|
return b'', b''
|
||||||
|
data, next_pack = temp[:data_len], temp[data_len:]
|
||||||
|
recv_sock.set_prepack(next_pack)
|
||||||
|
next_pack = b''
|
||||||
|
|
||||||
|
# 进行数据校验
|
||||||
|
temp = b''
|
||||||
|
while len(temp) < 3:
|
||||||
|
if (time.time() - time_start_recv) > time_out_single:
|
||||||
|
logging.error(f'单次指令接收超时')
|
||||||
|
return b'', b''
|
||||||
|
try:
|
||||||
|
temp += recv_sock.receive(1)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'接收报文校验失败, 错误代码: \n{e}')
|
||||||
|
return b'', b''
|
||||||
|
if temp == b'\xff\xff\xbb':
|
||||||
|
return data, next_pack
|
||||||
|
else:
|
||||||
|
logging.error(f"接收了一个完美的只错了校验位的报文")
|
||||||
|
return b'', b''
|
||||||
|
|
||||||
|
|
||||||
|
def parse_protocol(data: bytes) -> (str, any): # data: 参数类型为 bytes,表示接收到的报文数据。返回类型 (str, any) 表示函数返回一个元组,其中第一个元素是指令类型的字符串,第二个元素是指令对应的内容,内容的类型可以是任何类型(由指令决定)。
|
||||||
|
"""
|
||||||
|
指令转换.
|
||||||
|
|
||||||
|
:param data:接收到的报文
|
||||||
|
:return: 指令类型和内容
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
assert len(data) > 4
|
||||||
|
except AssertionError:
|
||||||
|
logging.error('指令转换失败,长度不足5')
|
||||||
|
return '', None
|
||||||
|
cmd, data = data[:4], data[4:] # 从 data 中取出前 4 个字节作为指令类型,剩下的部分作为指令内容。
|
||||||
|
cmd = cmd.decode('ascii').strip().upper() # 将指令类型转换为字符串,并去除首尾空格,然后转换为大写。
|
||||||
|
if cmd == 'IM':
|
||||||
|
n_rows, n_cols, img = data[:2], data[2:4], data[4:] # 按照协议,先是两个字节的行数(高),两个字节的列数(宽),后面是图像数据
|
||||||
|
try:
|
||||||
|
n_rows, n_cols = [int.from_bytes(x, byteorder='big') for x in [n_rows, n_cols]]
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'长宽转换失败, 错误代码{e}, 报文大小: n_rows:{n_rows}, n_cols: {n_cols}')
|
||||||
|
return '', None
|
||||||
|
try:
|
||||||
|
assert n_rows * n_cols * 3 == len(img)
|
||||||
|
# 因为是float32类型 所以长度要乘12 ,如果是uint8则乘3
|
||||||
|
except AssertionError:
|
||||||
|
logging.error('图像指令IM转换失败,数据长度错误')
|
||||||
|
return '', None
|
||||||
|
img = np.frombuffer(img, dtype=np.uint8).reshape((n_rows, n_cols, -1))
|
||||||
|
return cmd, img
|
||||||
|
elif cmd == 'TR':
|
||||||
|
data = data.decode('ascii')
|
||||||
|
return cmd, data
|
||||||
|
elif cmd == 'MD':
|
||||||
|
data = data.decode('ascii')
|
||||||
|
return cmd, data
|
||||||
|
|
||||||
|
|
||||||
|
def ack_sock(send_sock: socket.socket, cmd_type: str) -> bool: # 未使用
|
||||||
|
'''
|
||||||
|
发送应答
|
||||||
|
:param cmd_type:指令类型
|
||||||
|
:param send_sock:指定sock
|
||||||
|
:return:是否发送成功
|
||||||
|
'''
|
||||||
|
msg = b'\xaa\x00\x00\x00\x05' + (' A' + cmd_type).upper().encode('ascii') + b'\xff\xff\xff\xbb'
|
||||||
|
try:
|
||||||
|
send_sock.send(msg)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'发送应答失败,错误类型:{e}')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def done_sock(send_sock: socket.socket, cmd_type: str, result = '') -> bool: # 未使用
|
||||||
|
'''
|
||||||
|
发送任务完成指令
|
||||||
|
:param cmd_type:指令类型
|
||||||
|
:param send_sock:指定sock
|
||||||
|
:param result:数据
|
||||||
|
:return:是否发送成功
|
||||||
|
'''
|
||||||
|
cmd_type = cmd_type.strip().upper()
|
||||||
|
if (cmd_type == "TR") or (cmd_type == "MD"):
|
||||||
|
if result != '':
|
||||||
|
logging.error('结果在这种指令里很没必要')
|
||||||
|
result = b'\xff'
|
||||||
|
elif cmd_type == 'IM':
|
||||||
|
if result == 0:
|
||||||
|
result = b'H'
|
||||||
|
elif result == 1:
|
||||||
|
result = b'Z'
|
||||||
|
length = len(result) + 4
|
||||||
|
length = length.to_bytes(4, byteorder='big')
|
||||||
|
msg = b'\xaa' +length + (' D' + cmd_type).upper().encode('ascii') + result + b'\xff\xff\xbb'
|
||||||
|
try:
|
||||||
|
send_sock.send(msg)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'发送完成指令失败,错误类型:{e}')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def simple_sock(send_sock: socket.socket, cmd_type: str, result) -> bool:
|
||||||
|
'''
|
||||||
|
发送任务完成指令
|
||||||
|
:param cmd_type:指令类型
|
||||||
|
:param send_sock:指定sock
|
||||||
|
:param result:数据
|
||||||
|
:return:是否发送成功
|
||||||
|
'''
|
||||||
|
cmd_type = cmd_type.strip().upper() # 去除空格并转换为大写
|
||||||
|
if cmd_type == 'IM':
|
||||||
|
if result == 0:
|
||||||
|
msg = b'H'
|
||||||
|
elif result == 1:
|
||||||
|
msg = b'Z'
|
||||||
|
elif cmd_type == 'TR':
|
||||||
|
msg = b'A'
|
||||||
|
elif cmd_type == 'MD':
|
||||||
|
msg = b'D'
|
||||||
|
result = result.encode('ascii')
|
||||||
|
result = b',' + result
|
||||||
|
length = len(result)
|
||||||
|
msg = msg + length.to_bytes(4, 'big') + result
|
||||||
|
try:
|
||||||
|
send_sock.send(msg)
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f'发送完成指令失败,错误类型:{e}')
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
log = Logger(is_to_file=True)
|
||||||
|
log.log("nihao")
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
a = np.ones((100, 100, 3))
|
||||||
|
log.log(a.shape)
|
||||||
Loading…
Reference in New Issue
Block a user