第一次添加

This commit is contained in:
duanmu 2024-05-10 19:08:26 +08:00
parent a12c554118
commit 48fc1479d5
9 changed files with 1115 additions and 0 deletions

19
.gitignore vendored Normal file
View File

@ -0,0 +1,19 @@
data/correct/*.bmp
data/dark/*.bmp
data/light/*.bmp
data/middle/*.bmp
data1
data2
data3
data4
data5
.idea
__pycache__
*.pyc
test.py
*.log
.models
data
*.png
models
*.json

162
QT_test.py Normal file
View File

@ -0,0 +1,162 @@
# -*- codeing = utf-8 -*-
# Time : 2022/9/17 15:05
# @Auther : zhouchao
# @File: QT_test.py
# @Software:PyCharm
import logging
import socket
import numpy as np
import cv2
def rec_socket(recv_sock: socket.socket, cmd_type: str, ack: bool) -> bool: # 未使用
if ack:
cmd = 'A' + cmd_type
else:
cmd = 'D' + cmd_type
while True:
try:
temp = recv_sock.recv(1)
except ConnectionError as e:
logging.error(f'连接出错, 错误代码:\n{e}')
return False
except TimeoutError as e:
logging.error(f'超时了,错误代码: \n{e}')
return False
except Exception as e:
logging.error(f'遇见未知错误,错误代码: \n{e}')
return False
if temp == b'\xaa':
break
# 获取报文长度
temp = b''
while len(temp) < 4:
try:
temp += recv_sock.recv(1)
except Exception as e:
logging.error(f'接收报文长度失败, 错误代码: \n{e}')
return False
try:
data_len = int.from_bytes(temp, byteorder='big')
except Exception as e:
logging.error(f'转换失败,错误代码 \n{e}, \n报文内容\n{temp}')
return False
# 读取报文内容
temp = b''
while len(temp) < data_len:
try:
temp += recv_sock.recv(data_len)
except Exception as e:
logging.error(f'接收报文内容失败, 错误代码: \n{e}\n报文内容\n{temp}')
return False
data = temp
if cmd.strip().upper() != data[:4].decode('ascii').strip().upper():
logging.error(f'客户端接收指令错误,\n指令内容\n{data}')
return False
else:
if cmd == 'DIM':
print(data)
# 进行数据校验
temp = b''
while len(temp) < 3:
try:
temp += recv_sock.recv(1)
except Exception as e:
logging.error(f'接收报文校验失败, 错误代码: \n{e}')
return False
if temp == b'\xff\xff\xbb':
return True
else:
logging.error(f"接收了一个完美的只错了校验位的报文,\n data: {data}")
return False
def main(): # 在一个端口上接收文件,在另一个端口上接收控制命令
socket_receive = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 创建一个socket对象AF_INET是地址簇SOCK_STREAM是socket类型表示TCP连接
socket_receive.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
socket_receive.bind(('127.0.0.1', 21123)) # 127.0.0.1是本机的回环地址意味着socket仅接受从同一台机器上发起的连接请求。21123是端口号它是计算机上用于区分不同服务的数字标签。
socket_receive.listen(5) # 开始监听传入的连接5是指在拒绝连接之前操作系统可以挂起的最大连接数量
socket_send = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
socket_send.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
socket_send.bind(('127.0.0.1', 21122))
socket_send.listen(5)
print('等待连接')
socket_send_1, receive_addr_1 = socket_send.accept()
print("连接成功:", receive_addr_1)
# socket_send_2 = socket_send_1
socket_send_2, receive_addr_2 = socket_receive.accept()
print("连接成功:", receive_addr_2)
while True:
cmd = input().strip().upper()
if cmd == 'IM':
# img = cv2.imread(r"/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919/dark/rgb60.png")
img = cv2.imread(r"D:\Projects\PycharmProjects\xiangmu_wenli\data\wenli\Huawen\huaweisecha (132).png") # 读取图片返回的img对象是一个NumPy数组包含图像的像素数据。
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 将BGR格式的图像转换为RGB格式
img = np.asarray(img, dtype=np.uint8) # 通过np.asarray()确保图像数据是NumPy数组格式dtype=np.uint8表示使用8位无符号整数格式存储每个颜色通道这是图像处理中常用的数据类型。
height = img.shape[0] # 获取图像的高度
width = img.shape[1] # 获取图像的宽度
img_bytes = img.tobytes() # 将图像数据转换为字节流,以便通过网络传输
length = len(img_bytes) + 8 # 计算报文长度,包括命令部分、宽度、高度和图像数据,以及结束符。 + 8这个加法操作包括额外的协议或消息格式所需的字节长度 4 字节用于表示命令类型(例如 'IM'。在某些实现中可能已经固定包含在消息的开始部分。2 字节用于图像的宽度。2 字节用于图像的高度。
length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序
height = height.to_bytes(2, byteorder='big') # 将图像高度转换为2字节的大端字节序这样可以确保图像的宽度和高度在网络传输中的顺序是正确的。
width = width.to_bytes(2, byteorder='big') # 将图像宽度转换为2字节的大端字节序
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + height + width + img_bytes + b'\xff\xff\xbb' # 消息以'\xaa'开始,包含消息长度、命令代码、图像宽度和高度、图像数据本身,以及结束符'\xff\xff\xbb'来标记消息结束。
socket_send_1.send(send_message)
print('发送成功')
result = socket_send_2.recv(1)
print(result)
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
# print('接收指令成功')
# else:
# print('接收指令失败')
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
# print('指令执行完毕')
# else:
# print('指令执行失败')
elif cmd == 'TR':
# model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919"
model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\data\xiangmu_photos_wenli" # 数据路径
model = model.encode('ascii') # 将字符串转换为字节流
length = len(model) + 4 # 计算报文长度 + 4这个加法操作通常包括额外的协议或消息格式所需的字节长度特别是4 字节用于存储整个消息长度的数值本身,表示消息的起始部分。
length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb'
socket_send_1.send(send_message)
print('发送成功')
result = socket_send_2.recv(1)
print(result)
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
# print('接收指令成功')
# else:
# print('接收指令失败')
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
# print('指令执行完毕')
# else:
# print('指令执行失败')
elif cmd == 'MD':
# model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/models/model_2020-11-08_20-49.p"
# model = r"C:\Users\FEIJINTI\OneDrive\PycharmProjects\wood_color\models\model_2023-03-27_16-32.p"
model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\models\model_2024-05-07_13-58.p" # 模型路径
model = model.encode('ascii') # 将字符串转换为字节流
length = len(model) + 4
length = length.to_bytes(4, byteorder='big')
send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb'
socket_send_1.send(send_message)
print('发送成功')
result = socket_send_2.recv(1)
print(result)
# if rec_socket(socket_send_2, cmd_type=cmd, ack=True):
# print('接收指令成功')
# else:
# print('接收指令失败')
# if rec_socket(socket_send_2, cmd_type=cmd, ack=False):
# print('指令执行完毕')
# else:
# print('指令执行失败')
if __name__ == '__main__':
main()

329
classifer.py Normal file
View File

@ -0,0 +1,329 @@
# -*- coding: UTF-8 -*-
# @Time : 2024/4/15 19:03
# @Auther : DUANMU
# @File : wenli_classifer.py
# @Software : PyCharm
import logging
import numpy as np
import cv2
import os
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, ExtraTreesClassifier
from skimage.filters import sobel
from skimage import filters
from skimage.morphology import opening, square
from skimage.color import rgb2gray
from skimage.filters import threshold_otsu
from skimage.feature import local_binary_pattern, graycomatrix, graycoprops
from scipy.stats import kurtosis, skew
from skimage import transform, util
from skimage.feature import hog
from skimage import color, exposure
import lightgbm as lgb
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pickle
import time
from root_dir import ROOT_DIR
import utils
class TextureClass:
def __init__(self, load_from=None, w=4096, h=1200, debug_mode=False):
"""
初始化纹理分类器
:param load_from:
:param w:
:param h:
:param debug_mode:
"""
if load_from is None:
if w is None or h is None:
print("It will damage your performance if you don't set w and h")
raise ValueError("w or h is None")
self.w, self.h= w, h
# self.model = RandomForestClassifier(n_estimators=100)
# self.model = lgb.LGBMClassifier(
# n_estimators=100,
# max_depth=5,
# num_leaves=16, # 减小叶节点数
# min_data_in_leaf=40, # 增加每个叶子的最小数据量
# lambda_l1=0.1, # 增强L1正则化
# lambda_l2=0.1, # 增强L2正则化
# boosting_type='gbdt', # 使用传统的梯度提升决策树
# objective='binary', # 二分类问题
# learning_rate=0.05, # 可以尝试降低学习速率
# subsample=0.8, # 子样本比率
# colsample_bytree=0.8, # 基于树的列采样
# metric='binary_logloss', # 评价指标
# verbose=-1 # 不输出任何东西,包括警告
# )
self.model = lgb.LGBMClassifier(n_estimators=100 ,verbose=-1)
# self.model = RandomForestClassifier(n_estimators=150, random_state=42, max_depth=10,min_samples_leaf=2,
# min_samples_split=4,max_features='sqrt')
# self.model = AdaBoostClassifier()
# self.model = GradientBoostingClassifier()
# self.model = ExtraTreesClassifier()
# self.model = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42)
else:
self.load(load_from)
self.debug_mode = debug_mode
self.log = utils.Logger(is_to_file=debug_mode)
self.image_num = 0
def extract_features(self, img):
# """使用局部二值模式LBP提取图像的纹理特征"""
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# if gray.dtype != np.uint8:
# gray = (gray * 255).astype(np.uint8)
#
# # 设置LBP参数
# radius = 1 # LBP算法中圆的半径
# n_points = 8 * radius # 统一模式用的点数
# lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
#
# # 计算LBP的直方图
# n_bins = int(lbp.max() + 1) # 加1因为直方图的区间是开区间
# hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)
# festures = hist
#
# return festures
"""使用局部二值模式LBP提取图像的纹理特征优化版"""
# 图像下采样
img_resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2))
gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY)
# gray = self.preprocess_image(img_resized)
if gray.dtype != np.uint8:
gray = (gray * 255).astype(np.uint8)
# 设置LBP参数
radius = 1
n_points = 8 * radius
lbp = local_binary_pattern(gray, n_points, radius, method='uniform')
# 计算LBP的直方图
n_bins = int(lbp.max() + 1)
hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True)
features = hist
return features
def augment_image(self, img):
"""对输入图像应用随机数据增强"""
# 随机旋转 ±30 度
angle = np.random.uniform(-30, 30)
M = cv2.getRotationMatrix2D((img.shape[1] / 2, img.shape[0] / 2), angle, 1)
img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))
# 随机水平翻转
if np.random.rand() > 0.5:
img = cv2.flip(img, 1)
# 随机调整亮度 ±50
value = int(np.random.uniform(-50, 50))
hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
h, s, v = cv2.split(hsv)
v = cv2.add(v, value)
v[v > 255] = 255
v[v < 0] = 0
final_hsv = cv2.merge((h, s, v))
img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR)
return img
def get_image_data(self, img_dir="./data/wenli/Huawen",augment=False):
"""
:param img_dir: 图像文件的路径
:return: 图像数据
"""
img_data = []
img_name = []
utils.mkdir_if_not_exist(img_dir)
files = os.listdir(img_dir)
if len(files) == 0:
return False
for file in files:
path = os.path.join(img_dir, file)
if self.debug_mode:
self.log.log(path)
train_img = cv2.imread(path)
if augment:
train_img = self.augment_image(train_img) # 应用数据增强
data = self.extract_features(train_img)
img_data.append(data)
img_name.append(file)
img_data = np.array(img_data)
return img_data, img_name
def get_train_data(self, data_dir=None, plot_2d=False, save_data=False, augment=False):
"""
获取图像数据
:return: x_data, y_data
"""
print("开始加载训练数据...")
data_dir = os.path.join(ROOT_DIR, "data", "wenli") if data_dir is None else data_dir
hw_data, hw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Huawen"), augment=augment)
zw_data, zw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Zhiwen"), augment=augment)
if (hw_data is False) or (zw_data is False) :
return False
x_data = np.vstack((hw_data, zw_data))
hw_label = np.zeros(len(hw_data)).T # 为"Huawen"类图像赋值标签0
zw_label = np.ones(len(zw_data)).T # 为"Zhiwen"类图像赋值标签1
y_data = np.hstack((hw_label, zw_label))
img_name = hw_name + zw_name
# if plot_2d:
# plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data, cmap='viridis')
# plt.show()
if save_data:
with open(os.path.join("data", "data.p"), "rb") as f:
pass
if (hw_data is False) or (zw_data is False):
print("未找到有效的训练数据")
return False
print("训练数据加载完成")
return x_data, y_data, img_name
def fit_pictures(self, data_path=ROOT_DIR, file_name=None, augment=False):
"""
根据给出的data_path 进行 fit.如果没有给出data目录那么将会使用当前文件夹
:param data_path:
:return:
"""
print("开始训练模型...")
# 训练数据文件位置
result = self.get_train_data(data_path, plot_2d=True, augment=augment)
if result is False:
print("训练数据加载失败,中止训练")
return 0
x, y, name = result
print("训练数据加载成功,开始模型训练")
score = self.fit(x, y)
print('model score', score)
model_name = self.save(file_name)
return model_name
def fit(self, X, y):
"""训练模型"""
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
self.model.fit(X_train, y_train)
y_pred = self.model.predict(X_test)
print(confusion_matrix(y_test, y_pred))
pre_score = accuracy_score(y_test, y_pred) # 计算测试集的准确率
self.log.log("Test accuracy is:" + str(pre_score * 100) + "%.")
y_pred = self.model.predict(X_train)
pre_score = accuracy_score(y_train, y_pred)
self.log.log("Train accuracy is:" + str(pre_score * 100) + "%.")
y_pred = self.model.predict(X)
pre_score = accuracy_score(y, y_pred)
self.log.log("Total accuracy is:" + str(pre_score * 100) + "%.")
return int(pre_score * 100)
def predict(self, img):
"""预测图像纹理"""
if self.debug_mode:
cv2.imwrite(str(self.image_num) + ".bmp", img)
self.image_num += 1
features = self.extract_features(img) # 提取图像特征
features = np.array(features) # 将列表转换为 Numpy 数组
features = features.reshape(1, -1) # 使用 reshape(1, -1) 将特征数组转换成适合模型预测的形状。这里的 1 表示样本数为一(单个图像),-1 表示自动计算特征数量。
pred_wenli = self.model.predict(features) # 使用模型预测图像的纹理
if self.debug_mode:
self.log.log(features)
return int(pred_wenli[0]) # 从模型预测结果中提取第一个元素(假设预测结果是一个数组),并将其转换为整数。这通常是分类任务的类别标签。
def save(self, file_name):
"""保存模型到文件"""
if file_name is None:
file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p"
file_name = os.path.join(ROOT_DIR, "models", file_name)
model_dic = { "model": self.model,"w": self.w, "h": self.h}
with open(file_name, "wb") as f:
pickle.dump(model_dic, f)
self.log.log("Save file to '" + str(file_name) + "'")
return file_name
def load(self, path=None):
if path is None:
path = os.path.join(ROOT_DIR, "models")
utils.mkdir_if_not_exist(path)
model_files = os.listdir(path)
if len(model_files) == 0:
self.log.log("No model found!")
return 1
self.log.log("./ Models Found:")
_ = [self.log.log("├--" + str(model_file)) for model_file in model_files]
file_times = [model_file[6:-2] for model_file in model_files]
latest_model = model_files[int(np.argmax(file_times))]
self.log.log("└--Using the latest model: " + str(latest_model))
path = os.path.join(ROOT_DIR, "models", str(latest_model))
if not os.path.isabs(path):
logging.warning('给的是相对路径')
return -1
if not os.path.exists(path):
logging.warning('文件不存在')
return -1
with open(path, "rb") as f:
model_dic = pickle.load(f)
self.model = model_dic["model"]
self.w = model_dic["w"]
self.h = model_dic["h"]
return 0
if __name__ == '__main__':
from config import Config
# 加载配置设置
settings = Config()
# 初始化 TextureClass 实例
texture = TextureClass(w=4096, h=1200, debug_mode=False)
print("初始化纹理分类器完成")
# 获取数据路径
data_path = settings.data_path
# texture.correct() # 如果有色彩校正的步骤可以添加
# texture.load() # 如果需要加载先前的模型可以调用
# 训练模型并保存
model_path = texture.fit_pictures(data_path=data_path)
print(f"模型保存在 {model_path}")
# # 加载K-means数据并进行数据调整
# x_data, y_data, labels, img_names = texture.get_kmeans_data(data_path, plot_2d=True)
# send_data = texture.data_adjustments(x_data, y_data, labels, img_names)
# print(f"调整后的数据已发送/保存")
# 测试单张图片的预测性能
pic = cv2.imread(r"data/wenli/Zhiwen/rgb98.png")
start_time = time.time()
texture_type = texture.predict(pic)
end_time = time.time()
print("单次预测耗时:", (end_time - start_time) * 1000, "毫秒")
print("预测的纹理类型:", texture_type)
# 如果有批量处理或性能测试需求
# 总时间 = 0
# for i in range(100):
# start_time = time.time()
# _ = texture.predict(pic)
# end_time = time.time()
# total_time += (end_time - start_time)
# print("平均预测时间:", (total_time / 100) * 1000, "毫秒")

60
config.py Normal file
View File

@ -0,0 +1,60 @@
# -*- codeing = utf-8 -*-
# Time : 2022/10/17 11:07
# @Auther : zhouchao
# @File: config.py
# @Software:PyCharm
import json
import os
from pathlib import WindowsPath
from root_dir import ROOT_DIR
class Config(object):
model_path = ROOT_DIR / 'config.json' # “/”运算符被重载,用于连接两个路径,表示 ROOT_DIR 目录下的 config.json 文件
def __init__(self): # 初始化方法
self._param_dict = {} # 创建一个空字典
if os.path.exists(Config.model_path):
self._read() # 如果 config.json 文件存在,调用 _read 方法
else:
self.model_path = str(ROOT_DIR / 'models/model_2024-04-18_10-16.p') # 模型路径
self.data_path = str(ROOT_DIR / 'data/xiangmu_photos_wenli') # 数据路径
self.database_addr = str("mysql+pymysql://root:@localhost:3306/orm_test") # 测试用数据库地址
self._param_dict['model_path'] = self.model_path # 将模型路径写入 _param_dict 属性中
self._param_dict['data_path'] = self.data_path
self._param_dict['database_addr'] = self.database_addr
def __setitem__(self, key, value): # 重载 __setitem__ 方法
if key in self._param_dict: # 如果 key 在 _param_dict 中 key是键value是值 用于设置值
self._param_dict[key] = value # 将键值对写入 _param_dict 属性中
self._write() # 将 _param_dict 属性写入 config.json 文件
def __getitem__(self, item): # 重载 __getitem__ 方法
if item in self._param_dict: # 如果 item 在 _param_dict 中 item是键 value是值 用于获取值
return self._param_dict[item] # 返回 _param_dict 中 item 键的值
def __setattr__(self, key, value):
self.__dict__[key] = value # 直接在对象的 __dict__ 属性(这是一个存储对象所有属性的字典)中设置键(属性名)和值
if '_param_dict' in self.__dict__ and key != '_param_dict': # 如果 _param_dict 属性存在,且不是 _param_dict 属性本身
if isinstance(value, WindowsPath): # 如果 value 是 WindowsPath 对象,将其转换为字符串
value = str(value) # WindowsPath 对象不能被 json 序列化,需要转换为字符串
self.__dict__['_param_dict'][key] = value # 将键值对写入 _param_dict 属性中
self._write() # 将 _param_dict 属性写入 config.json 文件
def _read(self): # 读取 config.json 文件
with open(Config.model_path, 'r') as f: # 打开 config.json 文件
self._param_dict = json.load(f) # 读取文件内容,将其转换为字典
self.data_path = self._param_dict['data_path'] # 从字典中读取 data_path 键的值
self.model_path = self._param_dict['model_path'] # 从字典中读取 model_path 键的值
self.database_addr = self._param_dict['database_addr']
def _write(self): # 将 _param_dict 属性写入 config.json 文件
with open(Config.model_path, 'w') as f: # 打开 config.json 文件
json.dump(self._param_dict, f) # 将 _param_dict 写入文件
if __name__ == '__main__':
config = Config()
print(config.model_path)
print(config.data_path)

65
database.py Normal file
View File

@ -0,0 +1,65 @@
# -*- codeing = utf-8 -*-
# Time : 2022/10/20 13:54
# @Auther : zhouchao
# @File: database.py
# @Software:PyCharm
import datetime
import time
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.dialects.mysql import INTEGER, VARCHAR
from sqlalchemy import Column, TIMESTAMP, DATETIME
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
Base = declarative_base()
class Wood(Base):
__tablename__ = 'color'
id = Column(INTEGER, primary_key=True)
color = Column(VARCHAR(256), nullable=False)
time = Column(DATETIME, nullable=False)
def __init__(self, color):
self.time = datetime.datetime.now()
self.color = color
class Database(object):
def __init__(self, database_addr):
self.database_addr = database_addr
def init_db(self):
engine = create_engine(self.database_addr, encoding="utf-8", echo=True)
Base.metadata.create_all(engine)
print('Create table successfully!')
def add_data(self, color):
# 初始化数据库连接
engine = create_engine(self.database_addr, encoding="utf-8")
# 创建DBSession类型
DBSession = sessionmaker(bind=engine)
# 创建session对象
session = DBSession()
# 插入单条数据
# 创建新User对象
new_wood = Wood(color)
# 添加到session
session.add(new_wood)
# 提交即保存到数据库
session.commit()
if __name__ == '__main__':
test_addr = "mysql+pymysql://root:@localhost:3306/color"
database = Database(test_addr)
# database.init_db()
t1 = time.time()
for i in range(100):
database.add_data('middle')
t2 = time.time()
print((t2-t1)/100)

10
requirements.txt Normal file
View File

@ -0,0 +1,10 @@
lightgbm==4.3.0
matplotlib==3.5.2
numpy==1.23.1
numpy==1.25.1
opencv_contrib_python==4.7.0.72
opencv_python==4.6.0.66
scikit_learn==1.1.1
scipy==1.9.0
skimage==0.0
SQLAlchemy==2.0.19

13
root_dir.py Normal file
View File

@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-
"""
Created on Nov 3 21:18:26 2020
@author: l.z.y
@e-mail: li.zhenye@qq.com
"""
import pathlib
file_path = pathlib.Path(__file__) # 获得当前文件路径
ROOT_DIR = file_path.parent # 获得当前文件的父目录
# pathlib的作用是将文件路径转换为操作系统的路径格式这样可以避免不同操作系统的路径格式不同的问题

110
socket_detector.py Normal file
View File

@ -0,0 +1,110 @@
import socket
import sys
import numpy as np
import cv2
import root_dir
from classifer import TextureClass
import time
import os
# from database import Database
from root_dir import ROOT_DIR
from utils import PreSocket, receive_sock, parse_protocol, ack_sock, done_sock, DualSock, simple_sock
import logging
from config import Config
def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: TextureClass, settings: Config) -> tuple:
"""
处理指令
:param cmd: 指令类型
:param data: 指令内容
:param connected_sock: socket
:param detector: 模型
:return: 是否处理成功
"""
result = ''
if cmd == 'IM':
data = np.clip(data, 0, 255).astype(dtype=np.uint8) # 将 data 数组中的每个元素限制在范围 [0, 255] 内,将数组的数据类型转换为无符号 8 位整数
wood_wenli = detector.predict(data)
result = {0: 'Huawen', 1: 'Zhiwen'}[wood_wenli]
response = simple_sock(connected_sock, cmd_type=cmd, result=wood_wenli)
elif cmd == 'TR':
detector = TextureClass(w=4096, h=1200, debug_mode=False)
model_name = None
if "$" in data:
data, model_name = data.split("$", 1)
model_name = model_name + ".p"
settings.data_path = data
settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path, file_name=model_name)
response = simple_sock(connected_sock, cmd_type=cmd, result=result)
elif cmd == 'MD':
settings.model_path = data
detector.load(path=settings.model_path)
response = simple_sock(connected_sock, cmd_type=cmd, result=result)
else:
logging.error(f'错误指令,指令为{cmd}')
response = False
return response, result
def main(is_debug=False):
settings = Config() # 创建一个配置类的实例
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'report.log')) # 创建一个文件处理器,将日志信息写入到指定的 report.log 文件
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别, DEBUG 级别最低,可以输出所有级别的日志信息; WARNING 级别最高,只能输出 WARNING 级别及以上的日志信息
console_handler = logging.StreamHandler(sys.stdout) # 创建一个控制台处理器,将日志信息输出到控制台
console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别
logging.basicConfig(format='%(asctime)s %(filename)s[line:%(lineno)d] - %(levelname)s - %(message)s',
handlers=[file_handler, console_handler], level=logging.DEBUG)
dual_sock = DualSock(connect_ip='127.0.0.1')
# database = Database(settings.database_addr)
while not dual_sock.status:
dual_sock.reconnect()
detector = TextureClass(w=4096, h=1200, debug_mode=False)
detector.load(path=settings.model_path)
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
while True:
pack, next_pack = receive_sock(dual_sock)
if pack == b"":
time.sleep(5)
dual_sock.reconnect()
continue
cmd, data = parse_protocol(pack)
# ack_sock(received_sock, cmd_type=cmd)
response, result = process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector, settings=settings)
# if result != "":
# database.add_data(result)
if __name__ == '__main__':
# 2个端口
# 接受端口21122
# 发送端口21123
# 接收到图片 n_rows * n_bands * n_cols, float32
# 发送图片 n_rows * n_cols, uint8
main(is_debug=False)
# test(r"D:\build-tobacco-Desktop_Qt_5_9_0_MSVC2015_64bit-Release\calibrated15.raw")
# main()
# debug_main()
# test_run(all_data_dir=r'D:\数据')
# with open(r'D:\数据\虫子\valid2.raw', 'rb') as f:
# data = np.frombuffer(f.read(), dtype=np.float32).reshape(600, 29, 1024).transpose(0, 2, 1)
# plt.matshow(data[:, :, 10])
# plt.show()
# detector = SpecDetector('model_spec/model_29.p')
# result = detector.predict(data)
#
# plt.matshow(result)
# plt.show()
# result = result.reshape((600, 1024))

347
utils.py Normal file
View File

@ -0,0 +1,347 @@
# -*- coding: utf-8 -*-
"""
Created on Nov 3 21:18:26 2020
@author: l.z.y
@e-mail: li.zhenye@qq.com
"""
import logging
import os
import shutil
import time
import socket
import numpy as np
def mkdir_if_not_exist(dir_name, is_delete=False):
"""
创建文件夹
:param dir_name: 文件夹
:param is_delete: 是否删除
:return: 是否成功
"""
try:
if is_delete:
if os.path.exists(dir_name): # 如果 is_delete 为 True函数会检查 dir_name 是否已经存在。
shutil.rmtree(dir_name) # 如果存在,它会使用 shutil.rmtree 删除整个目录结构,并输出一个信息消息。
print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name)
if not os.path.exists(dir_name): # 使用 os.path.exists 来检查 dir_name 是否存在。
os.makedirs(dir_name) # 如果不存在,它会使用 os.makedirs 创建整个目录树。
print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name)
return True
except Exception as e:
print('[Exception] %s' % e)
return False
def create_file(file_name):
"""
创建文件
:param file_name: 文件名
:return: None
"""
if os.path.exists(file_name):
print("文件存在:%s" % file_name)
return False
# os.remove(file_name) # 删除已有文件
if not os.path.exists(file_name):
print("文件不存在,创建文件:%s" % file_name)
open(file_name, 'a').close()
return True
class Logger(object): # 日志类
def __init__(self, is_to_file=False, path=None):
self.is_to_file = is_to_file # 布尔值,用于决定日志是保存到文件(True)还是输出到控制台(False)
if path is None: # 如果没有指定日志文件的路径,则默认为当前目录下的 wood.log
path = "wood.log"
self.path = path # 指定保存日志的文件路径。
create_file(path) # 确保日志文件存在,如果不存在则创建它
def log(self, content):
if self.is_to_file: # 如果指定了日志文件的路径,则将日志信息写入到日志文件中
with open(self.path, "a") as f: # 以追加的方式打开日志文件,'a'Append 模式): 在这种模式下,如果文件已经存在,新的数据会被写入到文件的末尾。如果文件不存在,它会被创建。
print(time.strftime("[%Y-%m-%d_%H-%M-%S]:"), file=f) # 将当前时间戳和实际日志内容写入到文件,显示为[年-月-日_时-分-秒]: 的格式。
print(content, file=f)
else:
print(content) # 如果没有指定日志文件的路径,则将日志信息输出到控制台
def try_connect(connect_ip: str, port_number: int, is_repeat: bool = False, max_reconnect_times: int = 50) -> (
bool, socket.socket):
"""
尝试连接.
:param is_repeat: 是否是重新连接
:param max_reconnect_times:最大重连次数
:return: (连接状态True为成功, Socket / None)
"""
reconnect_time = 0
while reconnect_time < max_reconnect_times:
logging.warning(f'尝试{"重新" if is_repeat else ""}发起第{reconnect_time + 1}次连接...')
try:
connected_sock = PreSocket(socket.AF_INET, socket.SOCK_STREAM)
connected_sock.connect((connect_ip, port_number))
except Exception as e:
reconnect_time += 1
logging.error(f'{reconnect_time}次连接失败... 5秒后重新连接...\n {e}')
time.sleep(5)
continue
logging.warning(f'{"重新" if is_repeat else ""}连接成功')
return True, connected_sock
return False, None
class PreSocket(socket.socket):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pre_pack = b''
self.settimeout(5)
def receive(self, *args, **kwargs):
if self.pre_pack == b'':
return self.recv(*args, **kwargs)
else:
data_len = args[0]
required, left = self.pre_pack[:data_len], self.pre_pack[data_len:]
self.pre_pack = left
return required
def set_prepack(self, pre_pack: bytes):
temp = self.pre_pack
self.pre_pack = temp + pre_pack
class DualSock(PreSocket):
def __init__(self, connect_ip='127.0.0.1', recv_port: int = 21122, send_port: int = 21123):
super().__init__()
received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port) # 这两行代码分别设置接收和发送的sockets。
send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
self.status = received_status and send_status
def send(self, *args, **kwargs) -> int:
return self.send_sock.send(*args, **kwargs)
def receive(self, *args, **kwargs) -> bytes:
return self.received_sock.receive(*args, **kwargs)
def set_prepack(self, pre_pack: bytes):
self.received_sock.set_prepack(pre_pack)
def reconnect(self, connect_ip='127.0.0.1', recv_port:int = 21122, send_port: int = 21123):
received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port)
send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port)
return received_status and send_status
def receive_sock(recv_sock: PreSocket, pre_pack: bytes = b'', time_out: float = -1.0, time_out_single=5e20) -> (
bytes, bytes):
"""
从指定的socket中读取数据.
:param recv_sock: 指定sock
:param pre_pack: 上一包的粘包内容
:param time_out: 每隔time_out至少要发来一次指令,否则认为出现问题进行重连小于0则为一直等
:param time_out_single: 单次指令超时时间单位是秒
:return: data, next_pack
"""
recv_sock.set_prepack(pre_pack)
# 开头校验
time_start_recv = time.time()
while True:
if time_out > 0:
if (time.time() - time_start_recv) > time_out:
logging.error(f'指令接收超时')
return b'', b''
try:
temp = recv_sock.receive(1)
except ConnectionError as e:
logging.error(f'连接出错, 错误代码:\n{e}')
return b'', b''
except TimeoutError as e:
# logging.error(f'超时了,错误代码: \n{e}')
logging.info('运行中,等待指令..')
continue
except socket.timeout as e:
logging.info('运行中,等待指令..')
continue
except Exception as e:
logging.error(f'遇见未知错误,错误代码: \n{e}')
return b'', b''
if temp == b'\xaa':
break
# 接收开头后,开始进行时间记录
time_start_recv = time.time()
# 获取报文长度
temp = b''
while len(temp) < 4:
if (time.time() - time_start_recv) > time_out_single:
logging.error(f'单次指令接收超时')
return b'', b''
try:
temp += recv_sock.receive(1)
except Exception as e:
logging.error(f'接收报文的长度不正确, 错误代码: \n{e}')
return b'', b''
try:
data_len = int.from_bytes(temp, byteorder='big')
except Exception as e:
logging.error(f'转换失败,错误代码 \n{e}')
return b'', b''
# 读取报文内容
temp = b''
while len(temp) < data_len:
if (time.time() - time_start_recv) > time_out_single:
logging.error(f'单次指令接收超时')
return b'', b''
try:
temp += recv_sock.receive(data_len)
except Exception as e:
logging.error(f'接收报文内容失败, 错误代码: \n{e}')
return b'', b''
data, next_pack = temp[:data_len], temp[data_len:]
recv_sock.set_prepack(next_pack)
next_pack = b''
# 进行数据校验
temp = b''
while len(temp) < 3:
if (time.time() - time_start_recv) > time_out_single:
logging.error(f'单次指令接收超时')
return b'', b''
try:
temp += recv_sock.receive(1)
except Exception as e:
logging.error(f'接收报文校验失败, 错误代码: \n{e}')
return b'', b''
if temp == b'\xff\xff\xbb':
return data, next_pack
else:
logging.error(f"接收了一个完美的只错了校验位的报文")
return b'', b''
def parse_protocol(data: bytes) -> (str, any): # data: 参数类型为 bytes表示接收到的报文数据。返回类型 (str, any) 表示函数返回一个元组,其中第一个元素是指令类型的字符串,第二个元素是指令对应的内容,内容的类型可以是任何类型(由指令决定)。
"""
指令转换.
:param data:接收到的报文
:return: 指令类型和内容
"""
try:
assert len(data) > 4
except AssertionError:
logging.error('指令转换失败长度不足5')
return '', None
cmd, data = data[:4], data[4:] # 从 data 中取出前 4 个字节作为指令类型,剩下的部分作为指令内容。
cmd = cmd.decode('ascii').strip().upper() # 将指令类型转换为字符串,并去除首尾空格,然后转换为大写。
if cmd == 'IM':
n_rows, n_cols, img = data[:2], data[2:4], data[4:] # 按照协议,先是两个字节的行数(高),两个字节的列数(宽),后面是图像数据
try:
n_rows, n_cols = [int.from_bytes(x, byteorder='big') for x in [n_rows, n_cols]]
except Exception as e:
logging.error(f'长宽转换失败, 错误代码{e}, 报文大小: n_rows:{n_rows}, n_cols: {n_cols}')
return '', None
try:
assert n_rows * n_cols * 3 == len(img)
# 因为是float32类型 所以长度要乘12 如果是uint8则乘3
except AssertionError:
logging.error('图像指令IM转换失败数据长度错误')
return '', None
img = np.frombuffer(img, dtype=np.uint8).reshape((n_rows, n_cols, -1))
return cmd, img
elif cmd == 'TR':
data = data.decode('ascii')
return cmd, data
elif cmd == 'MD':
data = data.decode('ascii')
return cmd, data
def ack_sock(send_sock: socket.socket, cmd_type: str) -> bool: # 未使用
'''
发送应答
:param cmd_type:指令类型
:param send_sock:指定sock
:return:是否发送成功
'''
msg = b'\xaa\x00\x00\x00\x05' + (' A' + cmd_type).upper().encode('ascii') + b'\xff\xff\xff\xbb'
try:
send_sock.send(msg)
except Exception as e:
logging.error(f'发送应答失败,错误类型:{e}')
return False
return True
def done_sock(send_sock: socket.socket, cmd_type: str, result = '') -> bool: # 未使用
'''
发送任务完成指令
:param cmd_type:指令类型
:param send_sock:指定sock
:param result:数据
:return:是否发送成功
'''
cmd_type = cmd_type.strip().upper()
if (cmd_type == "TR") or (cmd_type == "MD"):
if result != '':
logging.error('结果在这种指令里很没必要')
result = b'\xff'
elif cmd_type == 'IM':
if result == 0:
result = b'H'
elif result == 1:
result = b'Z'
length = len(result) + 4
length = length.to_bytes(4, byteorder='big')
msg = b'\xaa' +length + (' D' + cmd_type).upper().encode('ascii') + result + b'\xff\xff\xbb'
try:
send_sock.send(msg)
except Exception as e:
logging.error(f'发送完成指令失败,错误类型:{e}')
return False
return True
def simple_sock(send_sock: socket.socket, cmd_type: str, result) -> bool:
'''
发送任务完成指令
:param cmd_type:指令类型
:param send_sock:指定sock
:param result:数据
:return:是否发送成功
'''
cmd_type = cmd_type.strip().upper() # 去除空格并转换为大写
if cmd_type == 'IM':
if result == 0:
msg = b'H'
elif result == 1:
msg = b'Z'
elif cmd_type == 'TR':
msg = b'A'
elif cmd_type == 'MD':
msg = b'D'
result = result.encode('ascii')
result = b',' + result
length = len(result)
msg = msg + length.to_bytes(4, 'big') + result
try:
send_sock.send(msg)
except Exception as e:
logging.error(f'发送完成指令失败,错误类型:{e}')
return False
return True
if __name__ == '__main__':
log = Logger(is_to_file=True)
log.log("nihao")
import numpy as np
a = np.ones((100, 100, 3))
log.log(a.shape)