diff --git a/classifer.py b/classifer.py index b45dc6e..f30192c 100644 --- a/classifer.py +++ b/classifer.py @@ -461,6 +461,40 @@ class WoodClass(object): img = np.clip(img, 0, 255).astype(dtype=np.uint8) return img + def get_luminance_data(self, data_dir=None, plot_2d=False): + """ + 获取l数据 + :param data_dir: 图片路径 + :param plot_2d: 是否绘制二维图 + :return: + """ + x_data, y_data, img_names = self.get_train_data(data_dir, plot_2d=plot_2d) + x_data = x_data[:, [0, 1, 2]] + dark_num = len(x_data[y_data == 0]) + middle_num = len(x_data[y_data == 1]) + light_num = len(x_data[y_data == 2]) + # 将数据按照亮度进行排序 + x_data = x_data[np.argsort(x_data[:, 0])] + # 按照x的顺序,将y和names也进行排序 + y_data = y_data[np.argsort(x_data[:, 0])] + # 按照x的顺序,将img_names也进行排序,但是这里需要注意,img_names是一个list,所以需要先转换成np.array + img_names = [img_names[i] for i in np.argsort(x_data[:, 0])] + + # 创建一个labels,用于存储每个像素点的标签 + labels = np.zeros_like(x_data[:, 0]) + # 将亮度最低的dark_num个像素点标记为0 + labels[:dark_num] = 0 + # 将亮度最高的light_num个像素点标记为2 + labels[-light_num:] = 2 + # 将中间的middle_num个像素点标记为1 + labels[dark_num:-light_num] = 1 + if plot_2d: + plt.figure() + plt.scatter(x_data[:, 0], x_data[:, 1], c=labels) + plt.show() + + return x_data, y_data, labels, img_names + def get_kmeans_data(self, data_dir=None, plot_2d=False): """ 获取kmeans数据 diff --git a/socket_detector.py b/socket_detector.py index 6a4f8d6..c1f1ada 100644 --- a/socket_detector.py +++ b/socket_detector.py @@ -46,7 +46,7 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo detector.load(path=settings.model_path) response = simple_sock(connected_sock, cmd_type=cmd, result=result) elif cmd == 'KM': - x_data, y_data, labels, img_names = detector.get_kmeans_data(data, plot_2d=True) + x_data, y_data, labels, img_names = detector.get_luminance_data(data, plot_2d=True) result = detector.data_adjustments(x_data, y_data, labels, img_names) result = ','.join([str(x) for x in result]) response = simple_sock(connected_sock, cmd_type=cmd, result=result)