diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5511f63 --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +data/correct/*.bmp +data/dark/*.bmp +data/light/*.bmp +data/middle/*.bmp +data1 +data2 +data3 +data4 +data5 +.idea +__pycache__ +*.pyc +test.py +*.log +.models +data +*.png +models +*.json \ No newline at end of file diff --git a/QT_test.py b/QT_test.py new file mode 100644 index 0000000..d93bb0e --- /dev/null +++ b/QT_test.py @@ -0,0 +1,162 @@ +# -*- codeing = utf-8 -*- +# Time : 2022/9/17 15:05 +# @Auther : zhouchao +# @File: QT_test.py +# @Software:PyCharm +import logging +import socket +import numpy as np +import cv2 + + + +def rec_socket(recv_sock: socket.socket, cmd_type: str, ack: bool) -> bool: # 未使用 + if ack: + cmd = 'A' + cmd_type + else: + cmd = 'D' + cmd_type + while True: + try: + temp = recv_sock.recv(1) + except ConnectionError as e: + logging.error(f'连接出错, 错误代码:\n{e}') + return False + except TimeoutError as e: + logging.error(f'超时了,错误代码: \n{e}') + return False + except Exception as e: + logging.error(f'遇见未知错误,错误代码: \n{e}') + return False + if temp == b'\xaa': + break + + # 获取报文长度 + temp = b'' + while len(temp) < 4: + try: + temp += recv_sock.recv(1) + except Exception as e: + logging.error(f'接收报文长度失败, 错误代码: \n{e}') + return False + try: + data_len = int.from_bytes(temp, byteorder='big') + except Exception as e: + logging.error(f'转换失败,错误代码 \n{e}, \n报文内容\n{temp}') + return False + + # 读取报文内容 + temp = b'' + while len(temp) < data_len: + try: + temp += recv_sock.recv(data_len) + except Exception as e: + logging.error(f'接收报文内容失败, 错误代码: \n{e},\n报文内容\n{temp}') + return False + data = temp + if cmd.strip().upper() != data[:4].decode('ascii').strip().upper(): + logging.error(f'客户端接收指令错误,\n指令内容\n{data}') + return False + else: + if cmd == 'DIM': + print(data) + + # 进行数据校验 + temp = b'' + while len(temp) < 3: + try: + temp += recv_sock.recv(1) + except Exception as e: + logging.error(f'接收报文校验失败, 错误代码: \n{e}') + return False + if temp == b'\xff\xff\xbb': + return True + else: + logging.error(f"接收了一个完美的只错了校验位的报文,\n data: {data}") + return False + + +def main(): # 在一个端口上接收文件,在另一个端口上接收控制命令 + socket_receive = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # 创建一个socket对象,AF_INET是地址簇,SOCK_STREAM是socket类型,表示TCP连接 + socket_receive.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + socket_receive.bind(('127.0.0.1', 21123)) # 127.0.0.1是本机的回环地址,意味着socket仅接受从同一台机器上发起的连接请求。21123是端口号,它是计算机上用于区分不同服务的数字标签。 + socket_receive.listen(5) # 开始监听传入的连接,5是指在拒绝连接之前,操作系统可以挂起的最大连接数量 + socket_send = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + socket_send.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + socket_send.bind(('127.0.0.1', 21122)) + socket_send.listen(5) + print('等待连接') + socket_send_1, receive_addr_1 = socket_send.accept() + print("连接成功:", receive_addr_1) + # socket_send_2 = socket_send_1 + socket_send_2, receive_addr_2 = socket_receive.accept() + print("连接成功:", receive_addr_2) + while True: + cmd = input().strip().upper() + if cmd == 'IM': + # img = cv2.imread(r"/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919/dark/rgb60.png") + img = cv2.imread(r"D:\Projects\PycharmProjects\xiangmu_wenli\data\wenli\Huawen\huaweisecha (132).png") # 读取图片,返回的img对象是一个NumPy数组,包含图像的像素数据。 + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 将BGR格式的图像转换为RGB格式 + img = np.asarray(img, dtype=np.uint8) # 通过np.asarray()确保图像数据是NumPy数组格式,dtype=np.uint8表示使用8位无符号整数格式存储每个颜色通道,这是图像处理中常用的数据类型。 + height = img.shape[0] # 获取图像的高度 + width = img.shape[1] # 获取图像的宽度 + img_bytes = img.tobytes() # 将图像数据转换为字节流,以便通过网络传输 + length = len(img_bytes) + 8 # 计算报文长度,包括命令部分、宽度、高度和图像数据,以及结束符。 + 8:这个加法操作包括额外的协议或消息格式所需的字节长度 4 字节用于表示命令类型(例如 'IM')。在某些实现中可能已经固定包含在消息的开始部分。2 字节用于图像的宽度。2 字节用于图像的高度。 + length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序 + height = height.to_bytes(2, byteorder='big') # 将图像高度转换为2字节的大端字节序,这样可以确保图像的宽度和高度在网络传输中的顺序是正确的。 + width = width.to_bytes(2, byteorder='big') # 将图像宽度转换为2字节的大端字节序 + send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + height + width + img_bytes + b'\xff\xff\xbb' # 消息以'\xaa'开始,包含消息长度、命令代码、图像宽度和高度、图像数据本身,以及结束符'\xff\xff\xbb'来标记消息结束。 + socket_send_1.send(send_message) + print('发送成功') + result = socket_send_2.recv(1) + print(result) + # if rec_socket(socket_send_2, cmd_type=cmd, ack=True): + # print('接收指令成功') + # else: + # print('接收指令失败') + # if rec_socket(socket_send_2, cmd_type=cmd, ack=False): + # print('指令执行完毕') + # else: + # print('指令执行失败') + elif cmd == 'TR': + # model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/data/data20220919" + model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\data\xiangmu_photos_wenli" # 数据路径 + model = model.encode('ascii') # 将字符串转换为字节流 + length = len(model) + 4 # 计算报文长度 + 4:这个加法操作通常包括额外的协议或消息格式所需的字节长度,特别是:4 字节用于存储整个消息长度的数值本身,表示消息的起始部分。 + length = length.to_bytes(4, byteorder='big') # 将报文长度转换为4字节的大端字节序 + send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb' + socket_send_1.send(send_message) + print('发送成功') + result = socket_send_2.recv(1) + print(result) + # if rec_socket(socket_send_2, cmd_type=cmd, ack=True): + # print('接收指令成功') + # else: + # print('接收指令失败') + # if rec_socket(socket_send_2, cmd_type=cmd, ack=False): + # print('指令执行完毕') + # else: + # print('指令执行失败') + elif cmd == 'MD': + # model = "/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color/models/model_2020-11-08_20-49.p" + # model = r"C:\Users\FEIJINTI\OneDrive\PycharmProjects\wood_color\models\model_2023-03-27_16-32.p" + model = r"D:\Projects\PycharmProjects\xiangmu_wenli_2\models\model_2024-05-07_13-58.p" # 模型路径 + model = model.encode('ascii') # 将字符串转换为字节流 + length = len(model) + 4 + length = length.to_bytes(4, byteorder='big') + send_message = b'\xaa' + length + (' ' + cmd).upper().encode('ascii') + model + b'\xff\xff\xbb' + socket_send_1.send(send_message) + print('发送成功') + result = socket_send_2.recv(1) + print(result) + # if rec_socket(socket_send_2, cmd_type=cmd, ack=True): + # print('接收指令成功') + # else: + # print('接收指令失败') + # if rec_socket(socket_send_2, cmd_type=cmd, ack=False): + # print('指令执行完毕') + # else: + # print('指令执行失败') + +if __name__ == '__main__': + main() + diff --git a/classifer.py b/classifer.py new file mode 100644 index 0000000..a14201c --- /dev/null +++ b/classifer.py @@ -0,0 +1,329 @@ +# -*- coding: UTF-8 -*- +# @Time : 2024/4/15 19:03 +# @Auther : DUANMU +# @File : wenli_classifer.py +# @Software : PyCharm +import logging +import numpy as np +import cv2 +import os +import random + +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score, confusion_matrix +from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier, GradientBoostingClassifier, ExtraTreesClassifier +from skimage.filters import sobel +from skimage import filters +from skimage.morphology import opening, square +from skimage.color import rgb2gray +from skimage.filters import threshold_otsu +from skimage.feature import local_binary_pattern, graycomatrix, graycoprops +from scipy.stats import kurtosis, skew +from skimage import transform, util +from skimage.feature import hog +from skimage import color, exposure +import lightgbm as lgb +from sklearn.decomposition import PCA + + +import matplotlib.pyplot as plt +import pickle +import time + +from root_dir import ROOT_DIR +import utils + +class TextureClass: + def __init__(self, load_from=None, w=4096, h=1200, debug_mode=False): + """ + 初始化纹理分类器 + :param load_from: + :param w: + :param h: + :param debug_mode: + """ + if load_from is None: + if w is None or h is None: + print("It will damage your performance if you don't set w and h") + raise ValueError("w or h is None") + self.w, self.h= w, h + # self.model = RandomForestClassifier(n_estimators=100) + # self.model = lgb.LGBMClassifier( + # n_estimators=100, + # max_depth=5, + # num_leaves=16, # 减小叶节点数 + # min_data_in_leaf=40, # 增加每个叶子的最小数据量 + # lambda_l1=0.1, # 增强L1正则化 + # lambda_l2=0.1, # 增强L2正则化 + # boosting_type='gbdt', # 使用传统的梯度提升决策树 + # objective='binary', # 二分类问题 + # learning_rate=0.05, # 可以尝试降低学习速率 + # subsample=0.8, # 子样本比率 + # colsample_bytree=0.8, # 基于树的列采样 + # metric='binary_logloss', # 评价指标 + # verbose=-1 # 不输出任何东西,包括警告 + # ) + self.model = lgb.LGBMClassifier(n_estimators=100 ,verbose=-1) + # self.model = RandomForestClassifier(n_estimators=150, random_state=42, max_depth=10,min_samples_leaf=2, + # min_samples_split=4,max_features='sqrt') + # self.model = AdaBoostClassifier() + # self.model = GradientBoostingClassifier() + # self.model = ExtraTreesClassifier() + # self.model = GradientBoostingClassifier(n_estimators=100, learning_rate=0.1, max_depth=3, random_state=42) + else: + self.load(load_from) + self.debug_mode = debug_mode + self.log = utils.Logger(is_to_file=debug_mode) + self.image_num = 0 + + def extract_features(self, img): + # """使用局部二值模式(LBP)提取图像的纹理特征""" + # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + # if gray.dtype != np.uint8: + # gray = (gray * 255).astype(np.uint8) + # + # # 设置LBP参数 + # radius = 1 # LBP算法中圆的半径 + # n_points = 8 * radius # 统一模式用的点数 + # lbp = local_binary_pattern(gray, n_points, radius, method='uniform') + # + # # 计算LBP的直方图 + # n_bins = int(lbp.max() + 1) # 加1因为直方图的区间是开区间 + # hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True) + # festures = hist + # + # return festures + + """使用局部二值模式(LBP)提取图像的纹理特征,优化版""" + # 图像下采样 + img_resized = cv2.resize(img, (img.shape[1] // 2, img.shape[0] // 2)) + gray = cv2.cvtColor(img_resized, cv2.COLOR_BGR2GRAY) + # gray = self.preprocess_image(img_resized) + if gray.dtype != np.uint8: + gray = (gray * 255).astype(np.uint8) + + # 设置LBP参数 + radius = 1 + n_points = 8 * radius + lbp = local_binary_pattern(gray, n_points, radius, method='uniform') + + # 计算LBP的直方图 + n_bins = int(lbp.max() + 1) + hist, _ = np.histogram(lbp, bins=n_bins, range=(0, n_bins), density=True) + features = hist + return features + + def augment_image(self, img): + """对输入图像应用随机数据增强""" + # 随机旋转 ±30 度 + angle = np.random.uniform(-30, 30) + M = cv2.getRotationMatrix2D((img.shape[1] / 2, img.shape[0] / 2), angle, 1) + img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0])) + + # 随机水平翻转 + if np.random.rand() > 0.5: + img = cv2.flip(img, 1) + + # 随机调整亮度 ±50 + value = int(np.random.uniform(-50, 50)) + hsv = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) + h, s, v = cv2.split(hsv) + v = cv2.add(v, value) + v[v > 255] = 255 + v[v < 0] = 0 + final_hsv = cv2.merge((h, s, v)) + img = cv2.cvtColor(final_hsv, cv2.COLOR_HSV2BGR) + + return img + + def get_image_data(self, img_dir="./data/wenli/Huawen",augment=False): + """ + :param img_dir: 图像文件的路径 + :return: 图像数据 + """ + img_data = [] + img_name = [] + utils.mkdir_if_not_exist(img_dir) + files = os.listdir(img_dir) + if len(files) == 0: + return False + for file in files: + path = os.path.join(img_dir, file) + if self.debug_mode: + self.log.log(path) + train_img = cv2.imread(path) + if augment: + train_img = self.augment_image(train_img) # 应用数据增强 + data = self.extract_features(train_img) + img_data.append(data) + img_name.append(file) + img_data = np.array(img_data) + + return img_data, img_name + + def get_train_data(self, data_dir=None, plot_2d=False, save_data=False, augment=False): + """ + 获取图像数据 + :return: x_data, y_data + """ + print("开始加载训练数据...") + data_dir = os.path.join(ROOT_DIR, "data", "wenli") if data_dir is None else data_dir + hw_data, hw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Huawen"), augment=augment) + zw_data, zw_name = self.get_image_data(img_dir=os.path.join(data_dir, "Zhiwen"), augment=augment) + if (hw_data is False) or (zw_data is False) : + return False + x_data = np.vstack((hw_data, zw_data)) + + hw_label = np.zeros(len(hw_data)).T # 为"Huawen"类图像赋值标签0 + zw_label = np.ones(len(zw_data)).T # 为"Zhiwen"类图像赋值标签1 + y_data = np.hstack((hw_label, zw_label)) + + img_name = hw_name + zw_name + + # if plot_2d: + # plt.scatter(x_data[:, 0], x_data[:, 1], c=y_data, cmap='viridis') + # plt.show() + if save_data: + with open(os.path.join("data", "data.p"), "rb") as f: + pass + if (hw_data is False) or (zw_data is False): + print("未找到有效的训练数据") + return False + print("训练数据加载完成") + return x_data, y_data, img_name + + def fit_pictures(self, data_path=ROOT_DIR, file_name=None, augment=False): + """ + 根据给出的data_path 进行 fit.如果没有给出data目录,那么将会使用当前文件夹 + :param data_path: + :return: + """ + print("开始训练模型...") + # 训练数据文件位置 + result = self.get_train_data(data_path, plot_2d=True, augment=augment) + if result is False: + print("训练数据加载失败,中止训练") + return 0 + x, y, name = result + print("训练数据加载成功,开始模型训练") + score = self.fit(x, y) + print('model score', score) + model_name = self.save(file_name) + return model_name + + def fit(self, X, y): + """训练模型""" + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + self.model.fit(X_train, y_train) + y_pred = self.model.predict(X_test) + + print(confusion_matrix(y_test, y_pred)) + + pre_score = accuracy_score(y_test, y_pred) # 计算测试集的准确率 + self.log.log("Test accuracy is:" + str(pre_score * 100) + "%.") + + y_pred = self.model.predict(X_train) + pre_score = accuracy_score(y_train, y_pred) + self.log.log("Train accuracy is:" + str(pre_score * 100) + "%.") + + y_pred = self.model.predict(X) + pre_score = accuracy_score(y, y_pred) + self.log.log("Total accuracy is:" + str(pre_score * 100) + "%.") + + return int(pre_score * 100) + + def predict(self, img): + """预测图像纹理""" + if self.debug_mode: + cv2.imwrite(str(self.image_num) + ".bmp", img) + self.image_num += 1 + features = self.extract_features(img) # 提取图像特征 + features = np.array(features) # 将列表转换为 Numpy 数组 + features = features.reshape(1, -1) # 使用 reshape(1, -1) 将特征数组转换成适合模型预测的形状。这里的 1 表示样本数为一(单个图像),-1 表示自动计算特征数量。 + pred_wenli = self.model.predict(features) # 使用模型预测图像的纹理 + if self.debug_mode: + self.log.log(features) + return int(pred_wenli[0]) # 从模型预测结果中提取第一个元素(假设预测结果是一个数组),并将其转换为整数。这通常是分类任务的类别标签。 + + + def save(self, file_name): + """保存模型到文件""" + if file_name is None: + file_name = "model_" + time.strftime("%Y-%m-%d_%H-%M") + ".p" + file_name = os.path.join(ROOT_DIR, "models", file_name) + model_dic = { "model": self.model,"w": self.w, "h": self.h} + with open(file_name, "wb") as f: + pickle.dump(model_dic, f) + self.log.log("Save file to '" + str(file_name) + "'") + return file_name + + + def load(self, path=None): + if path is None: + path = os.path.join(ROOT_DIR, "models") + utils.mkdir_if_not_exist(path) + model_files = os.listdir(path) + if len(model_files) == 0: + self.log.log("No model found!") + return 1 + self.log.log("./ Models Found:") + _ = [self.log.log("├--" + str(model_file)) for model_file in model_files] + file_times = [model_file[6:-2] for model_file in model_files] + latest_model = model_files[int(np.argmax(file_times))] + self.log.log("└--Using the latest model: " + str(latest_model)) + path = os.path.join(ROOT_DIR, "models", str(latest_model)) + if not os.path.isabs(path): + logging.warning('给的是相对路径') + return -1 + if not os.path.exists(path): + logging.warning('文件不存在') + return -1 + with open(path, "rb") as f: + model_dic = pickle.load(f) + + self.model = model_dic["model"] + self.w = model_dic["w"] + self.h = model_dic["h"] + + return 0 + +if __name__ == '__main__': + from config import Config + + # 加载配置设置 + settings = Config() + + # 初始化 TextureClass 实例 + texture = TextureClass(w=4096, h=1200, debug_mode=False) + print("初始化纹理分类器完成") + + # 获取数据路径 + data_path = settings.data_path + # texture.correct() # 如果有色彩校正的步骤可以添加 + # texture.load() # 如果需要加载先前的模型可以调用 + + # 训练模型并保存 + model_path = texture.fit_pictures(data_path=data_path) + print(f"模型保存在 {model_path}") + + # # 加载K-means数据并进行数据调整 + # x_data, y_data, labels, img_names = texture.get_kmeans_data(data_path, plot_2d=True) + # send_data = texture.data_adjustments(x_data, y_data, labels, img_names) + # print(f"调整后的数据已发送/保存") + + # 测试单张图片的预测性能 + pic = cv2.imread(r"data/wenli/Zhiwen/rgb98.png") + start_time = time.time() + texture_type = texture.predict(pic) + end_time = time.time() + print("单次预测耗时:", (end_time - start_time) * 1000, "毫秒") + print("预测的纹理类型:", texture_type) + + # 如果有批量处理或性能测试需求 + # 总时间 = 0 + # for i in range(100): + # start_time = time.time() + # _ = texture.predict(pic) + # end_time = time.time() + # total_time += (end_time - start_time) + # print("平均预测时间:", (total_time / 100) * 1000, "毫秒") diff --git a/config.py b/config.py new file mode 100644 index 0000000..dceb8d6 --- /dev/null +++ b/config.py @@ -0,0 +1,60 @@ +# -*- codeing = utf-8 -*- +# Time : 2022/10/17 11:07 +# @Auther : zhouchao +# @File: config.py +# @Software:PyCharm +import json +import os +from pathlib import WindowsPath + +from root_dir import ROOT_DIR + + +class Config(object): + model_path = ROOT_DIR / 'config.json' # “/”运算符被重载,用于连接两个路径,表示 ROOT_DIR 目录下的 config.json 文件 + + def __init__(self): # 初始化方法 + self._param_dict = {} # 创建一个空字典 + if os.path.exists(Config.model_path): + self._read() # 如果 config.json 文件存在,调用 _read 方法 + else: + self.model_path = str(ROOT_DIR / 'models/model_2024-04-18_10-16.p') # 模型路径 + self.data_path = str(ROOT_DIR / 'data/xiangmu_photos_wenli') # 数据路径 + self.database_addr = str("mysql+pymysql://root:@localhost:3306/orm_test") # 测试用数据库地址 + self._param_dict['model_path'] = self.model_path # 将模型路径写入 _param_dict 属性中 + self._param_dict['data_path'] = self.data_path + self._param_dict['database_addr'] = self.database_addr + + def __setitem__(self, key, value): # 重载 __setitem__ 方法 + if key in self._param_dict: # 如果 key 在 _param_dict 中 key是键,value是值 用于设置值 + self._param_dict[key] = value # 将键值对写入 _param_dict 属性中 + self._write() # 将 _param_dict 属性写入 config.json 文件 + + def __getitem__(self, item): # 重载 __getitem__ 方法 + if item in self._param_dict: # 如果 item 在 _param_dict 中 item是键 value是值 用于获取值 + return self._param_dict[item] # 返回 _param_dict 中 item 键的值 + + def __setattr__(self, key, value): + self.__dict__[key] = value # 直接在对象的 __dict__ 属性(这是一个存储对象所有属性的字典)中设置键(属性名)和值 + if '_param_dict' in self.__dict__ and key != '_param_dict': # 如果 _param_dict 属性存在,且不是 _param_dict 属性本身 + if isinstance(value, WindowsPath): # 如果 value 是 WindowsPath 对象,将其转换为字符串 + value = str(value) # WindowsPath 对象不能被 json 序列化,需要转换为字符串 + self.__dict__['_param_dict'][key] = value # 将键值对写入 _param_dict 属性中 + self._write() # 将 _param_dict 属性写入 config.json 文件 + + def _read(self): # 读取 config.json 文件 + with open(Config.model_path, 'r') as f: # 打开 config.json 文件 + self._param_dict = json.load(f) # 读取文件内容,将其转换为字典 + self.data_path = self._param_dict['data_path'] # 从字典中读取 data_path 键的值 + self.model_path = self._param_dict['model_path'] # 从字典中读取 model_path 键的值 + self.database_addr = self._param_dict['database_addr'] + + def _write(self): # 将 _param_dict 属性写入 config.json 文件 + with open(Config.model_path, 'w') as f: # 打开 config.json 文件 + json.dump(self._param_dict, f) # 将 _param_dict 写入文件 + + +if __name__ == '__main__': + config = Config() + print(config.model_path) + print(config.data_path) diff --git a/database.py b/database.py new file mode 100644 index 0000000..d070334 --- /dev/null +++ b/database.py @@ -0,0 +1,65 @@ +# -*- codeing = utf-8 -*- +# Time : 2022/10/20 13:54 +# @Auther : zhouchao +# @File: database.py +# @Software:PyCharm +import datetime +import time + +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker +from sqlalchemy.dialects.mysql import INTEGER, VARCHAR +from sqlalchemy import Column, TIMESTAMP, DATETIME +from sqlalchemy import create_engine +from sqlalchemy.ext.declarative import declarative_base +Base = declarative_base() + + +class Wood(Base): + __tablename__ = 'color' + + id = Column(INTEGER, primary_key=True) + color = Column(VARCHAR(256), nullable=False) + time = Column(DATETIME, nullable=False) + + + def __init__(self, color): + self.time = datetime.datetime.now() + self.color = color + + +class Database(object): + def __init__(self, database_addr): + self.database_addr = database_addr + + def init_db(self): + engine = create_engine(self.database_addr, encoding="utf-8", echo=True) + Base.metadata.create_all(engine) + print('Create table successfully!') + + def add_data(self, color): + # 初始化数据库连接 + engine = create_engine(self.database_addr, encoding="utf-8") + # 创建DBSession类型 + DBSession = sessionmaker(bind=engine) + + # 创建session对象 + session = DBSession() + # 插入单条数据 + # 创建新User对象 + new_wood = Wood(color) + # 添加到session + session.add(new_wood) + # 提交即保存到数据库 + session.commit() + + +if __name__ == '__main__': + test_addr = "mysql+pymysql://root:@localhost:3306/color" + database = Database(test_addr) + # database.init_db() + t1 = time.time() + for i in range(100): + database.add_data('middle') + t2 = time.time() + print((t2-t1)/100) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..82b3c4d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,10 @@ +lightgbm==4.3.0 +matplotlib==3.5.2 +numpy==1.23.1 +numpy==1.25.1 +opencv_contrib_python==4.7.0.72 +opencv_python==4.6.0.66 +scikit_learn==1.1.1 +scipy==1.9.0 +skimage==0.0 +SQLAlchemy==2.0.19 diff --git a/root_dir.py b/root_dir.py new file mode 100644 index 0000000..f0b470d --- /dev/null +++ b/root_dir.py @@ -0,0 +1,13 @@ +# -*- coding: utf-8 -*- +""" +Created on Nov 3 21:18:26 2020 + +@author: l.z.y +@e-mail: li.zhenye@qq.com +""" +import pathlib + +file_path = pathlib.Path(__file__) # 获得当前文件路径 +ROOT_DIR = file_path.parent # 获得当前文件的父目录 + +# pathlib的作用是将文件路径转换为操作系统的路径格式,这样可以避免不同操作系统的路径格式不同的问题 \ No newline at end of file diff --git a/socket_detector.py b/socket_detector.py new file mode 100644 index 0000000..84bf738 --- /dev/null +++ b/socket_detector.py @@ -0,0 +1,110 @@ +import socket +import sys + +import numpy as np +import cv2 + +import root_dir +from classifer import TextureClass +import time +import os + +# from database import Database +from root_dir import ROOT_DIR +from utils import PreSocket, receive_sock, parse_protocol, ack_sock, done_sock, DualSock, simple_sock +import logging +from config import Config + + +def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: TextureClass, settings: Config) -> tuple: + """ + 处理指令 + + :param cmd: 指令类型 + :param data: 指令内容 + :param connected_sock: socket + :param detector: 模型 + :return: 是否处理成功 + """ + result = '' + if cmd == 'IM': + data = np.clip(data, 0, 255).astype(dtype=np.uint8) # 将 data 数组中的每个元素限制在范围 [0, 255] 内,将数组的数据类型转换为无符号 8 位整数 + wood_wenli = detector.predict(data) + result = {0: 'Huawen', 1: 'Zhiwen'}[wood_wenli] + response = simple_sock(connected_sock, cmd_type=cmd, result=wood_wenli) + elif cmd == 'TR': + detector = TextureClass(w=4096, h=1200, debug_mode=False) + model_name = None + if "$" in data: + data, model_name = data.split("$", 1) + model_name = model_name + ".p" + settings.data_path = data + settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path, file_name=model_name) + response = simple_sock(connected_sock, cmd_type=cmd, result=result) + elif cmd == 'MD': + settings.model_path = data + detector.load(path=settings.model_path) + response = simple_sock(connected_sock, cmd_type=cmd, result=result) + + + else: + logging.error(f'错误指令,指令为{cmd}') + response = False + return response, result + + +def main(is_debug=False): + settings = Config() # 创建一个配置类的实例 + file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'report.log')) # 创建一个文件处理器,将日志信息写入到指定的 report.log 文件 + file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别, DEBUG 级别最低,可以输出所有级别的日志信息; WARNING 级别最高,只能输出 WARNING 级别及以上的日志信息 + console_handler = logging.StreamHandler(sys.stdout) # 创建一个控制台处理器,将日志信息输出到控制台 + console_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING) # 设置日志级别 + logging.basicConfig(format='%(asctime)s %(filename)s[line:%(lineno)d] - %(levelname)s - %(message)s', + handlers=[file_handler, console_handler], level=logging.DEBUG) + dual_sock = DualSock(connect_ip='127.0.0.1') + + # database = Database(settings.database_addr) + + + while not dual_sock.status: + dual_sock.reconnect() + detector = TextureClass(w=4096, h=1200, debug_mode=False) + detector.load(path=settings.model_path) + _ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8)) + while True: + pack, next_pack = receive_sock(dual_sock) + if pack == b"": + time.sleep(5) + dual_sock.reconnect() + continue + + cmd, data = parse_protocol(pack) + # ack_sock(received_sock, cmd_type=cmd) + response, result = process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector, settings=settings) + + # if result != "": + # database.add_data(result) + + + +if __name__ == '__main__': + # 2个端口 + # 接受端口21122 + # 发送端口21123 + # 接收到图片 n_rows * n_bands * n_cols, float32 + # 发送图片 n_rows * n_cols, uint8 + main(is_debug=False) + # test(r"D:\build-tobacco-Desktop_Qt_5_9_0_MSVC2015_64bit-Release\calibrated15.raw") + # main() + # debug_main() + # test_run(all_data_dir=r'D:\数据') + # with open(r'D:\数据\虫子\valid2.raw', 'rb') as f: + # data = np.frombuffer(f.read(), dtype=np.float32).reshape(600, 29, 1024).transpose(0, 2, 1) + # plt.matshow(data[:, :, 10]) + # plt.show() + # detector = SpecDetector('model_spec/model_29.p') + # result = detector.predict(data) + # + # plt.matshow(result) + # plt.show() + # result = result.reshape((600, 1024)) diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..2666e62 --- /dev/null +++ b/utils.py @@ -0,0 +1,347 @@ +# -*- coding: utf-8 -*- +""" +Created on Nov 3 21:18:26 2020 + +@author: l.z.y +@e-mail: li.zhenye@qq.com +""" +import logging +import os +import shutil +import time +import socket +import numpy as np + + +def mkdir_if_not_exist(dir_name, is_delete=False): + """ + 创建文件夹 + :param dir_name: 文件夹 + :param is_delete: 是否删除 + :return: 是否成功 + """ + try: + if is_delete: + if os.path.exists(dir_name): # 如果 is_delete 为 True,函数会检查 dir_name 是否已经存在。 + shutil.rmtree(dir_name) # 如果存在,它会使用 shutil.rmtree 删除整个目录结构,并输出一个信息消息。 + print('[Info] 文件夹 "%s" 存在, 删除文件夹.' % dir_name) + + if not os.path.exists(dir_name): # 使用 os.path.exists 来检查 dir_name 是否存在。 + os.makedirs(dir_name) # 如果不存在,它会使用 os.makedirs 创建整个目录树。 + print('[Info] 文件夹 "%s" 不存在, 创建文件夹.' % dir_name) + return True + except Exception as e: + print('[Exception] %s' % e) + return False + + +def create_file(file_name): + """ + 创建文件 + :param file_name: 文件名 + :return: None + """ + if os.path.exists(file_name): + print("文件存在:%s" % file_name) + return False + # os.remove(file_name) # 删除已有文件 + if not os.path.exists(file_name): + print("文件不存在,创建文件:%s" % file_name) + open(file_name, 'a').close() + return True + + +class Logger(object): # 日志类 + def __init__(self, is_to_file=False, path=None): + self.is_to_file = is_to_file # 布尔值,用于决定日志是保存到文件(True)还是输出到控制台(False) + if path is None: # 如果没有指定日志文件的路径,则默认为当前目录下的 wood.log + path = "wood.log" + self.path = path # 指定保存日志的文件路径。 + create_file(path) # 确保日志文件存在,如果不存在则创建它 + + def log(self, content): + if self.is_to_file: # 如果指定了日志文件的路径,则将日志信息写入到日志文件中 + with open(self.path, "a") as f: # 以追加的方式打开日志文件,'a'(Append 模式): 在这种模式下,如果文件已经存在,新的数据会被写入到文件的末尾。如果文件不存在,它会被创建。 + print(time.strftime("[%Y-%m-%d_%H-%M-%S]:"), file=f) # 将当前时间戳和实际日志内容写入到文件,显示为[年-月-日_时-分-秒]: 的格式。 + print(content, file=f) + else: + print(content) # 如果没有指定日志文件的路径,则将日志信息输出到控制台 + + +def try_connect(connect_ip: str, port_number: int, is_repeat: bool = False, max_reconnect_times: int = 50) -> ( + bool, socket.socket): + """ + 尝试连接. + + :param is_repeat: 是否是重新连接 + :param max_reconnect_times:最大重连次数 + :return: (连接状态True为成功, Socket / None) + """ + reconnect_time = 0 + while reconnect_time < max_reconnect_times: + logging.warning(f'尝试{"重新" if is_repeat else ""}发起第{reconnect_time + 1}次连接...') + try: + connected_sock = PreSocket(socket.AF_INET, socket.SOCK_STREAM) + connected_sock.connect((connect_ip, port_number)) + except Exception as e: + reconnect_time += 1 + logging.error(f'第{reconnect_time}次连接失败... 5秒后重新连接...\n {e}') + time.sleep(5) + continue + logging.warning(f'{"重新" if is_repeat else ""}连接成功') + return True, connected_sock + return False, None + + +class PreSocket(socket.socket): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.pre_pack = b'' + self.settimeout(5) + + def receive(self, *args, **kwargs): + if self.pre_pack == b'': + return self.recv(*args, **kwargs) + else: + data_len = args[0] + required, left = self.pre_pack[:data_len], self.pre_pack[data_len:] + self.pre_pack = left + return required + + def set_prepack(self, pre_pack: bytes): + temp = self.pre_pack + self.pre_pack = temp + pre_pack + + +class DualSock(PreSocket): + def __init__(self, connect_ip='127.0.0.1', recv_port: int = 21122, send_port: int = 21123): + super().__init__() + received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port) # 这两行代码分别设置接收和发送的sockets。 + send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port) + self.status = received_status and send_status + + def send(self, *args, **kwargs) -> int: + return self.send_sock.send(*args, **kwargs) + + def receive(self, *args, **kwargs) -> bytes: + return self.received_sock.receive(*args, **kwargs) + + def set_prepack(self, pre_pack: bytes): + self.received_sock.set_prepack(pre_pack) + + def reconnect(self, connect_ip='127.0.0.1', recv_port:int = 21122, send_port: int = 21123): + received_status, self.received_sock = try_connect(connect_ip=connect_ip, port_number=recv_port) + send_status, self.send_sock = try_connect(connect_ip=connect_ip, port_number=send_port) + return received_status and send_status + + +def receive_sock(recv_sock: PreSocket, pre_pack: bytes = b'', time_out: float = -1.0, time_out_single=5e20) -> ( +bytes, bytes): + """ + 从指定的socket中读取数据. + + :param recv_sock: 指定sock + :param pre_pack: 上一包的粘包内容 + :param time_out: 每隔time_out至少要发来一次指令,否则认为出现问题进行重连,小于0则为一直等 + :param time_out_single: 单次指令超时时间,单位是秒 + :return: data, next_pack + """ + recv_sock.set_prepack(pre_pack) + # 开头校验 + time_start_recv = time.time() + while True: + if time_out > 0: + if (time.time() - time_start_recv) > time_out: + logging.error(f'指令接收超时') + return b'', b'' + try: + temp = recv_sock.receive(1) + except ConnectionError as e: + logging.error(f'连接出错, 错误代码:\n{e}') + return b'', b'' + except TimeoutError as e: + # logging.error(f'超时了,错误代码: \n{e}') + logging.info('运行中,等待指令..') + continue + except socket.timeout as e: + logging.info('运行中,等待指令..') + continue + except Exception as e: + logging.error(f'遇见未知错误,错误代码: \n{e}') + return b'', b'' + if temp == b'\xaa': + break + + # 接收开头后,开始进行时间记录 + time_start_recv = time.time() + + # 获取报文长度 + temp = b'' + while len(temp) < 4: + if (time.time() - time_start_recv) > time_out_single: + logging.error(f'单次指令接收超时') + return b'', b'' + try: + temp += recv_sock.receive(1) + except Exception as e: + logging.error(f'接收报文的长度不正确, 错误代码: \n{e}') + return b'', b'' + try: + data_len = int.from_bytes(temp, byteorder='big') + except Exception as e: + logging.error(f'转换失败,错误代码 \n{e}') + return b'', b'' + + # 读取报文内容 + temp = b'' + while len(temp) < data_len: + if (time.time() - time_start_recv) > time_out_single: + logging.error(f'单次指令接收超时') + return b'', b'' + try: + temp += recv_sock.receive(data_len) + except Exception as e: + logging.error(f'接收报文内容失败, 错误代码: \n{e}') + return b'', b'' + data, next_pack = temp[:data_len], temp[data_len:] + recv_sock.set_prepack(next_pack) + next_pack = b'' + + # 进行数据校验 + temp = b'' + while len(temp) < 3: + if (time.time() - time_start_recv) > time_out_single: + logging.error(f'单次指令接收超时') + return b'', b'' + try: + temp += recv_sock.receive(1) + except Exception as e: + logging.error(f'接收报文校验失败, 错误代码: \n{e}') + return b'', b'' + if temp == b'\xff\xff\xbb': + return data, next_pack + else: + logging.error(f"接收了一个完美的只错了校验位的报文") + return b'', b'' + + +def parse_protocol(data: bytes) -> (str, any): # data: 参数类型为 bytes,表示接收到的报文数据。返回类型 (str, any) 表示函数返回一个元组,其中第一个元素是指令类型的字符串,第二个元素是指令对应的内容,内容的类型可以是任何类型(由指令决定)。 + """ + 指令转换. + + :param data:接收到的报文 + :return: 指令类型和内容 + """ + try: + assert len(data) > 4 + except AssertionError: + logging.error('指令转换失败,长度不足5') + return '', None + cmd, data = data[:4], data[4:] # 从 data 中取出前 4 个字节作为指令类型,剩下的部分作为指令内容。 + cmd = cmd.decode('ascii').strip().upper() # 将指令类型转换为字符串,并去除首尾空格,然后转换为大写。 + if cmd == 'IM': + n_rows, n_cols, img = data[:2], data[2:4], data[4:] # 按照协议,先是两个字节的行数(高),两个字节的列数(宽),后面是图像数据 + try: + n_rows, n_cols = [int.from_bytes(x, byteorder='big') for x in [n_rows, n_cols]] + except Exception as e: + logging.error(f'长宽转换失败, 错误代码{e}, 报文大小: n_rows:{n_rows}, n_cols: {n_cols}') + return '', None + try: + assert n_rows * n_cols * 3 == len(img) + # 因为是float32类型 所以长度要乘12 ,如果是uint8则乘3 + except AssertionError: + logging.error('图像指令IM转换失败,数据长度错误') + return '', None + img = np.frombuffer(img, dtype=np.uint8).reshape((n_rows, n_cols, -1)) + return cmd, img + elif cmd == 'TR': + data = data.decode('ascii') + return cmd, data + elif cmd == 'MD': + data = data.decode('ascii') + return cmd, data + + +def ack_sock(send_sock: socket.socket, cmd_type: str) -> bool: # 未使用 + ''' + 发送应答 + :param cmd_type:指令类型 + :param send_sock:指定sock + :return:是否发送成功 + ''' + msg = b'\xaa\x00\x00\x00\x05' + (' A' + cmd_type).upper().encode('ascii') + b'\xff\xff\xff\xbb' + try: + send_sock.send(msg) + except Exception as e: + logging.error(f'发送应答失败,错误类型:{e}') + return False + return True + + +def done_sock(send_sock: socket.socket, cmd_type: str, result = '') -> bool: # 未使用 + ''' + 发送任务完成指令 + :param cmd_type:指令类型 + :param send_sock:指定sock + :param result:数据 + :return:是否发送成功 + ''' + cmd_type = cmd_type.strip().upper() + if (cmd_type == "TR") or (cmd_type == "MD"): + if result != '': + logging.error('结果在这种指令里很没必要') + result = b'\xff' + elif cmd_type == 'IM': + if result == 0: + result = b'H' + elif result == 1: + result = b'Z' + length = len(result) + 4 + length = length.to_bytes(4, byteorder='big') + msg = b'\xaa' +length + (' D' + cmd_type).upper().encode('ascii') + result + b'\xff\xff\xbb' + try: + send_sock.send(msg) + except Exception as e: + logging.error(f'发送完成指令失败,错误类型:{e}') + return False + return True + + +def simple_sock(send_sock: socket.socket, cmd_type: str, result) -> bool: + ''' + 发送任务完成指令 + :param cmd_type:指令类型 + :param send_sock:指定sock + :param result:数据 + :return:是否发送成功 + ''' + cmd_type = cmd_type.strip().upper() # 去除空格并转换为大写 + if cmd_type == 'IM': + if result == 0: + msg = b'H' + elif result == 1: + msg = b'Z' + elif cmd_type == 'TR': + msg = b'A' + elif cmd_type == 'MD': + msg = b'D' + result = result.encode('ascii') + result = b',' + result + length = len(result) + msg = msg + length.to_bytes(4, 'big') + result + try: + send_sock.send(msg) + except Exception as e: + logging.error(f'发送完成指令失败,错误类型:{e}') + return False + return True + + + +if __name__ == '__main__': + log = Logger(is_to_file=True) + log.log("nihao") + import numpy as np + + a = np.ones((100, 100, 3)) + log.log(a.shape)