From f593a67e76292cd9ce4055a01b4f594829ad1e08 Mon Sep 17 00:00:00 2001 From: duanmu <774052669@qq.com> Date: Sun, 26 Mar 2023 15:51:44 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=9D=E8=AF=95=E5=8A=A0=E4=B8=AA=E5=8A=9F?= =?UTF-8?q?=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- classifer.py | 52 +++++++++++++++++++++++++++++++++++++++++++++- socket_detector.py | 2 ++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/classifer.py b/classifer.py index 70bd9da..a5b9c2c 100755 --- a/classifer.py +++ b/classifer.py @@ -31,7 +31,7 @@ sys.path.append(os.getcwd()) from root_dir import ROOT_DIR import utils -FEATURE_INDEX = [0,1,2,3,4,5] +FEATURE_INDEX = [0,1,2] delete_columns = 10 # 已弃用 num_bins = 10 @@ -367,6 +367,16 @@ class WoodClass(object): data = self.extract_feature(train_img) img_data.append(data) img_data = np.array(img_data) + + # 提取图像名称 + img_name = [os.path.splitext(file)[0] for file in files] + # 提取每个图像名称中的数字 + img_name = [name[3:] for name in img_name] + # 将图像名称个位数前补零 + img_name = [name.zfill(2) for name in img_name] + # 打印图像名称 + print('img_name:', img_name) + return img_data def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False): @@ -386,6 +396,46 @@ class WoodClass(object): light_label = 2 * np.ones(len(light_data)).T y_data = np.hstack((dark_label, middle_label, light_label)) x_data = x_data[:, FEATURE_INDEX] + + # 使用KMeans算法对图片数据进行聚类 + kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data) + z = kmeans.predict(x_data) + # 获取聚类后的数据 + dark = x_data[kmeans.labels_ == 0] + middle = x_data[kmeans.labels_ == 1] + light = x_data[kmeans.labels_ == 2] + # 获取数据的均值 + dark_mean = np.mean(dark, axis=0) + middle_mean = np.mean(middle, axis=0) + light_mean = np.mean(light, axis=0) + + # 按照平均值从小到大排序 + sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]]) + print('sorted_cluster_indices:', sorted_cluster_indices) + # 重新编号聚类标签 + sorted_labels = np.zeros(len(kmeans.labels_), dtype=int) + for i, label in enumerate(kmeans.labels_): + sorted_labels[i] = sorted_cluster_indices[label] + # 更新kmeans.labels_ + kmeans.labels_ = sorted_labels + print('kmeans.labels_:', kmeans.labels_) + # 获取更新聚类后的数据 + dark_new = x_data[kmeans.labels_ == 0] + middle_new = x_data[kmeans.labels_ == 1] + light_new = x_data[kmeans.labels_ == 2] + # 获取更新数据的均值 + dark_mean_new = np.mean(dark_new, axis=0) + middle_mean_new = np.mean(middle_new, axis=0) + light_mean_new = np.mean(light_new, axis=0) + # 打印每个聚类的平均值 + print('Dark cluster mean:', dark_mean_new) + print('Middle cluster mean:', middle_mean_new) + print('Light cluster mean:', light_mean_new) + # plot_2d + plt.figure() + plt.scatter(x_data[:, 0], x_data[:, 1], c=z) + plt.show() + # 进行色彩数据校正 if self.isCorrect: x_data = x_data / (self.correct_color + 1e-4) diff --git a/socket_detector.py b/socket_detector.py index af56d1d..69c8e1a 100644 --- a/socket_detector.py +++ b/socket_detector.py @@ -45,6 +45,8 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo settings.model_path = data detector.load(path=settings.model_path) response = simple_sock(connected_sock, cmd_type=cmd) + elif cmd == 'DT': + pass else: logging.error(f'错误指令,指令为{cmd}') response = False