实现了聚类重新给标签,返回x,y,labels,names的功能,后续需要联调实现发送功能

This commit is contained in:
FEIJINTI 2023-03-27 16:42:32 +08:00
parent 81f7b1dedc
commit a72c954b0a

View File

@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR
import utils
FEATURE_INDEX = [0,1,2]
FEATURE_INDEX = [0,1,2]#因为显示需要所以0,1,2分别为lab是必须使用的不然会影响图片显示
delete_columns = 10 # 已弃用
num_bins = 10
@ -116,7 +116,7 @@ class WoodClass(object):
result = self.get_train_data(data_path, plot_2d=True)
if result is False:
return 0
x, y = result
x, y, name = result
score = self.fit(x, y)
print('model score', score)
model_name = self.save(file_name)
@ -354,6 +354,7 @@ class WoodClass(object):
:return: 图像数据
"""
img_data = []
img_name = []
utils.mkdir_if_not_exist(img_dir)
files = os.listdir(img_dir)
if len(files) == 0:
@ -366,18 +367,19 @@ class WoodClass(object):
train_img = self.realtime_correct(train_img, 10, 20)
data = self.extract_feature(train_img)
img_data.append(data)
img_name.append(file)
img_data = np.array(img_data)
# 提取图像名称
img_name = [os.path.splitext(file)[0] for file in files]
# 提取每个图像名称中的数字
img_name = [name[3:] for name in img_name]
# 将图像名称个位数前补零
img_name = [name.zfill(2) for name in img_name]
# 打印图像名称
print('img_name:', img_name)
# # 提取图像名称
# img_name = [os.path.splitext(file)[0] for file in files]
# # 提取每个图像名称中的数字
# img_name = [name[3:] for name in img_name]
# # 将图像名称个位数前补零
# img_name = [name.zfill(2) for name in img_name]
# # 打印图像名称
# print('img_name:', img_name)
return img_data
return img_data, img_name
def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False):
"""
@ -385,9 +387,9 @@ class WoodClass(object):
:return: x_data, y_data
"""
data_dir = os.path.join(ROOT_DIR, "data", "data20220919") if data_dir is None else data_dir
dark_data = self.get_image_data(img_dir=os.path.join(data_dir, "dark"))
middle_data = self.get_image_data(img_dir=os.path.join(data_dir, "middle"))
light_data = self.get_image_data(img_dir=os.path.join(data_dir, "light"))
dark_data, dark_name = self.get_image_data(img_dir=os.path.join(data_dir, "dark"))
middle_data, middle_name = self.get_image_data(img_dir=os.path.join(data_dir, "middle"))
light_data, light_name = self.get_image_data(img_dir=os.path.join(data_dir, "light"))
if (dark_data is False) or (middle_data is False) or (light_data is False):
return False
x_data = np.vstack((dark_data, middle_data, light_data))
@ -396,45 +398,47 @@ class WoodClass(object):
light_label = 2 * np.ones(len(light_data)).T
y_data = np.hstack((dark_label, middle_label, light_label))
x_data = x_data[:, FEATURE_INDEX]
# dark_name, middle_name, light_name三个list合并
img_name = dark_name + middle_name + light_name
# 使用KMeans算法对图片数据进行聚类
kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data)
z = kmeans.predict(x_data)
# 获取聚类后的数据
dark = x_data[kmeans.labels_ == 0]
middle = x_data[kmeans.labels_ == 1]
light = x_data[kmeans.labels_ == 2]
# 获取数据的均值
dark_mean = np.mean(dark, axis=0)
middle_mean = np.mean(middle, axis=0)
light_mean = np.mean(light, axis=0)
# 按照平均值从小到大排序
sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]])
print('sorted_cluster_indices:', sorted_cluster_indices)
# 重新编号聚类标签
sorted_labels = np.zeros(len(kmeans.labels_), dtype=int)
for i, label in enumerate(kmeans.labels_):
sorted_labels[i] = sorted_cluster_indices[label]
# 更新kmeans.labels_
kmeans.labels_ = sorted_labels
print('kmeans.labels_:', kmeans.labels_)
# 获取更新聚类后的数据
dark_new = x_data[kmeans.labels_ == 0]
middle_new = x_data[kmeans.labels_ == 1]
light_new = x_data[kmeans.labels_ == 2]
# 获取更新数据的均值
dark_mean_new = np.mean(dark_new, axis=0)
middle_mean_new = np.mean(middle_new, axis=0)
light_mean_new = np.mean(light_new, axis=0)
# 打印每个聚类的平均值
print('Dark cluster mean:', dark_mean_new)
print('Middle cluster mean:', middle_mean_new)
print('Light cluster mean:', light_mean_new)
# plot_2d
plt.figure()
plt.scatter(x_data[:, 0], x_data[:, 1], c=z)
plt.show()
# # 使用KMeans算法对图片数据进行聚类
# kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data)
# z = kmeans.predict(x_data)
# # 获取聚类后的数据
# dark = x_data[kmeans.labels_ == 0]
# middle = x_data[kmeans.labels_ == 1]
# light = x_data[kmeans.labels_ == 2]
# # 获取数据的均值
# dark_mean = np.mean(dark, axis=0)
# middle_mean = np.mean(middle, axis=0)
# light_mean = np.mean(light, axis=0)
#
# # 按照平均值从小到大排序
# sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]])
# print('sorted_cluster_indices:', sorted_cluster_indices)
# # 重新编号聚类标签
# sorted_labels = np.zeros(len(kmeans.labels_), dtype=int)
# for i, label in enumerate(kmeans.labels_):
# sorted_labels[i] = sorted_cluster_indices[label]
# # 更新kmeans.labels_
# kmeans.labels_ = sorted_labels
# print('kmeans.labels_:', kmeans.labels_)
# # 获取更新聚类后的数据
# dark_new = x_data[kmeans.labels_ == 0]
# middle_new = x_data[kmeans.labels_ == 1]
# light_new = x_data[kmeans.labels_ == 2]
# # 获取更新数据的均值
# dark_mean_new = np.mean(dark_new, axis=0)
# middle_mean_new = np.mean(middle_new, axis=0)
# light_mean_new = np.mean(light_new, axis=0)
# # 打印每个聚类的平均值
# print('Dark cluster mean:', dark_mean_new)
# print('Middle cluster mean:', middle_mean_new)
# print('Light cluster mean:', light_mean_new)
# # plot_2d
# plt.figure()
# plt.scatter(x_data[:, 0], x_data[:, 1], c=z)
# plt.show()
# 进行色彩数据校正
if self.isCorrect:
@ -464,7 +468,7 @@ class WoodClass(object):
if save_data:
with open(os.path.join("data", "data.p"), "rb") as f:
pass
return x_data, y_data
return x_data, y_data, img_name
def realtime_correction(self, img):
"""
@ -503,6 +507,40 @@ class WoodClass(object):
img = np.clip(img, 0, 255).astype(dtype=np.uint8)
return img
def get_kmeans_data(self, data_dir=None, plot_2d=False):
    """
    Cluster the training samples with KMeans (3 clusters) and renumber the
    clusters by the ascending mean of their first feature column.

    After renumbering, label 0 is the cluster with the smallest mean of
    feature column 0 and label 2 the cluster with the largest one.

    :param data_dir: image directory, forwarded unchanged to get_train_data
    :param plot_2d: forwarded to get_train_data; when true, also show a scatter
                    plot of the first two feature columns coloured by the new labels
    :return: (x_data, y_data, labels, img_names), or False when get_train_data
             reports failure by returning False
    """
    result = self.get_train_data(data_dir, plot_2d=plot_2d)
    # get_train_data signals missing data with False; propagate it instead of
    # failing with a TypeError while unpacking.
    if result is False:
        return False
    x_data, y_data, img_names = result

    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(x_data)
    cluster_ids = kmeans.labels_

    # Mean of feature column 0 for every raw cluster id.
    first_feature_means = [
        np.mean(x_data[cluster_ids == cluster], axis=0)[0]
        for cluster in range(n_clusters)
    ]

    # argsort yields "rank -> cluster id"; relabelling needs the inverse
    # permutation "cluster id -> rank". Indexing argsort directly with the
    # cluster id is only correct when the permutation is its own inverse.
    order = np.argsort(first_feature_means)
    rank_of_cluster = np.empty_like(order)
    rank_of_cluster[order] = np.arange(n_clusters)

    # Build a new array rather than overwriting kmeans.labels_ in place.
    labels = rank_of_cluster[cluster_ids].astype(cluster_ids.dtype)

    if plot_2d:
        plt.figure()
        plt.scatter(x_data[:, 0], x_data[:, 1], c=labels)
        plt.show()

    return x_data, y_data, labels, img_names
if __name__ == '__main__':
from config import Config
@ -517,6 +555,9 @@ if __name__ == '__main__':
# fit 相应的文件夹
settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
wood.get_kmeans_data(data_path, plot_2d=True)
# 测试单张图片的预测predict_mode=True表示导入本地的model, False为现场训练的
pic = cv2.imread(r"data/318/dark/rgb89.png")
start_time = time.time()