diff --git a/classifer.py b/classifer.py index a5b9c2c..28d9121 100644 --- a/classifer.py +++ b/classifer.py @@ -31,7 +31,7 @@ sys.path.append(os.getcwd()) from root_dir import ROOT_DIR import utils -FEATURE_INDEX = [0,1,2] +FEATURE_INDEX = [0,1,2]#因为显示需要,所以0,1,2分别为lab,是必须使用的,不然会影响图片显示 delete_columns = 10 # 已弃用 num_bins = 10 @@ -116,7 +116,7 @@ class WoodClass(object): result = self.get_train_data(data_path, plot_2d=True) if result is False: return 0 - x, y = result + x, y, name = result score = self.fit(x, y) print('model score', score) model_name = self.save(file_name) @@ -354,6 +354,7 @@ class WoodClass(object): :return: 图像数据 """ img_data = [] + img_name = [] utils.mkdir_if_not_exist(img_dir) files = os.listdir(img_dir) if len(files) == 0: @@ -366,18 +367,19 @@ class WoodClass(object): train_img = self.realtime_correct(train_img, 10, 20) data = self.extract_feature(train_img) img_data.append(data) + img_name.append(file) img_data = np.array(img_data) - # 提取图像名称 - img_name = [os.path.splitext(file)[0] for file in files] - # 提取每个图像名称中的数字 - img_name = [name[3:] for name in img_name] - # 将图像名称个位数前补零 - img_name = [name.zfill(2) for name in img_name] - # 打印图像名称 - print('img_name:', img_name) + # # 提取图像名称 + # img_name = [os.path.splitext(file)[0] for file in files] + # # 提取每个图像名称中的数字 + # img_name = [name[3:] for name in img_name] + # # 将图像名称个位数前补零 + # img_name = [name.zfill(2) for name in img_name] + # # 打印图像名称 + # print('img_name:', img_name) - return img_data + return img_data, img_name def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False): """ @@ -385,9 +387,9 @@ class WoodClass(object): :return: x_data, y_data """ data_dir = os.path.join(ROOT_DIR, "data", "data20220919") if data_dir is None else data_dir - dark_data = self.get_image_data(img_dir=os.path.join(data_dir, "dark")) - middle_data = self.get_image_data(img_dir=os.path.join(data_dir, "middle")) - light_data = self.get_image_data(img_dir=os.path.join(data_dir, "light")) 
+ dark_data, dark_name = self.get_image_data(img_dir=os.path.join(data_dir, "dark")) + middle_data, middle_name = self.get_image_data(img_dir=os.path.join(data_dir, "middle")) + light_data, light_name = self.get_image_data(img_dir=os.path.join(data_dir, "light")) if (dark_data is False) or (middle_data is False) or (light_data is False): return False x_data = np.vstack((dark_data, middle_data, light_data)) @@ -396,45 +398,47 @@ class WoodClass(object): light_label = 2 * np.ones(len(light_data)).T y_data = np.hstack((dark_label, middle_label, light_label)) x_data = x_data[:, FEATURE_INDEX] + # dark_name, middle_name, light_name三个list合并 + img_name = dark_name + middle_name + light_name - # 使用KMeans算法对图片数据进行聚类 - kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data) - z = kmeans.predict(x_data) - # 获取聚类后的数据 - dark = x_data[kmeans.labels_ == 0] - middle = x_data[kmeans.labels_ == 1] - light = x_data[kmeans.labels_ == 2] - # 获取数据的均值 - dark_mean = np.mean(dark, axis=0) - middle_mean = np.mean(middle, axis=0) - light_mean = np.mean(light, axis=0) - - # 按照平均值从小到大排序 - sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]]) - print('sorted_cluster_indices:', sorted_cluster_indices) - # 重新编号聚类标签 - sorted_labels = np.zeros(len(kmeans.labels_), dtype=int) - for i, label in enumerate(kmeans.labels_): - sorted_labels[i] = sorted_cluster_indices[label] - # 更新kmeans.labels_ - kmeans.labels_ = sorted_labels - print('kmeans.labels_:', kmeans.labels_) - # 获取更新聚类后的数据 - dark_new = x_data[kmeans.labels_ == 0] - middle_new = x_data[kmeans.labels_ == 1] - light_new = x_data[kmeans.labels_ == 2] - # 获取更新数据的均值 - dark_mean_new = np.mean(dark_new, axis=0) - middle_mean_new = np.mean(middle_new, axis=0) - light_mean_new = np.mean(light_new, axis=0) - # 打印每个聚类的平均值 - print('Dark cluster mean:', dark_mean_new) - print('Middle cluster mean:', middle_mean_new) - print('Light cluster mean:', light_mean_new) - # plot_2d - plt.figure() - plt.scatter(x_data[:, 0], x_data[:, 
def get_kmeans_data(self, data_dir=None, plot_2d=False):
    """
    Cluster the training samples into three groups with KMeans and relabel
    the clusters by ascending mean of the first feature column.

    :param data_dir: image directory, forwarded unchanged to ``get_train_data``
    :param plot_2d: if True, show a scatter plot of the first two feature
        columns coloured by the ordered cluster label (the flag is also
        forwarded to ``get_train_data``)
    :return: ``(x_data, y_data, labels, img_names)`` where ``labels[i]`` is the
        rank of the cluster sample ``i`` belongs to (0 = smallest
        first-feature mean, 2 = largest); ``False`` when ``get_train_data``
        returns ``False`` because the training data could not be loaded
    """
    n_clusters = 3

    result = self.get_train_data(data_dir, plot_2d=plot_2d)
    # get_train_data signals missing data with False instead of a tuple;
    # propagate that instead of failing on the unpack below.
    if result is False:
        return False
    x_data, y_data, img_names = result

    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(x_data)
    raw_labels = kmeans.labels_

    # Mean of the first feature for every raw cluster id.
    # NOTE(review): feature 0 is presumably the L (lightness) channel, which
    # would make rank 0/1/2 correspond to dark/middle/light - confirm.
    first_feature_means = np.array(
        [x_data[raw_labels == cluster_id, 0].mean() for cluster_id in range(n_clusters)]
    )

    # argsort yields rank -> cluster id; relabelling needs the inverse mapping
    # (cluster id -> rank), so invert the permutation explicitly.
    clusters_by_rank = np.argsort(first_feature_means)
    rank_of_cluster = np.empty(n_clusters, dtype=int)
    rank_of_cluster[clusters_by_rank] = np.arange(n_clusters)

    # Fancy indexing builds a new array, leaving kmeans.labels_ untouched.
    labels = rank_of_cluster[raw_labels]

    if plot_2d:
        plt.figure()
        plt.scatter(x_data[:, 0], x_data[:, 1], c=labels)
        plt.show()

    return x_data, y_data, labels, img_names