mirror of
https://github.com/NanjingForestryUniversity/supermachine-wood.git
synced 2025-11-08 10:13:53 +00:00
实现了聚类重新给标签,返回x,y,labels,names的功能,后续需要联调实现发送功能
This commit is contained in:
parent
81f7b1dedc
commit
a72c954b0a
147
classifer.py
147
classifer.py
@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
|
|||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
FEATURE_INDEX = [0,1,2]
|
FEATURE_INDEX = [0,1,2]#因为显示需要,所以0,1,2分别为lab,是必须使用的,不然会影响图片显示
|
||||||
delete_columns = 10 # 已弃用
|
delete_columns = 10 # 已弃用
|
||||||
num_bins = 10
|
num_bins = 10
|
||||||
|
|
||||||
@ -116,7 +116,7 @@ class WoodClass(object):
|
|||||||
result = self.get_train_data(data_path, plot_2d=True)
|
result = self.get_train_data(data_path, plot_2d=True)
|
||||||
if result is False:
|
if result is False:
|
||||||
return 0
|
return 0
|
||||||
x, y = result
|
x, y, name = result
|
||||||
score = self.fit(x, y)
|
score = self.fit(x, y)
|
||||||
print('model score', score)
|
print('model score', score)
|
||||||
model_name = self.save(file_name)
|
model_name = self.save(file_name)
|
||||||
@ -354,6 +354,7 @@ class WoodClass(object):
|
|||||||
:return: 图像数据
|
:return: 图像数据
|
||||||
"""
|
"""
|
||||||
img_data = []
|
img_data = []
|
||||||
|
img_name = []
|
||||||
utils.mkdir_if_not_exist(img_dir)
|
utils.mkdir_if_not_exist(img_dir)
|
||||||
files = os.listdir(img_dir)
|
files = os.listdir(img_dir)
|
||||||
if len(files) == 0:
|
if len(files) == 0:
|
||||||
@ -366,18 +367,19 @@ class WoodClass(object):
|
|||||||
train_img = self.realtime_correct(train_img, 10, 20)
|
train_img = self.realtime_correct(train_img, 10, 20)
|
||||||
data = self.extract_feature(train_img)
|
data = self.extract_feature(train_img)
|
||||||
img_data.append(data)
|
img_data.append(data)
|
||||||
|
img_name.append(file)
|
||||||
img_data = np.array(img_data)
|
img_data = np.array(img_data)
|
||||||
|
|
||||||
# 提取图像名称
|
# # 提取图像名称
|
||||||
img_name = [os.path.splitext(file)[0] for file in files]
|
# img_name = [os.path.splitext(file)[0] for file in files]
|
||||||
# 提取每个图像名称中的数字
|
# # 提取每个图像名称中的数字
|
||||||
img_name = [name[3:] for name in img_name]
|
# img_name = [name[3:] for name in img_name]
|
||||||
# 将图像名称个位数前补零
|
# # 将图像名称个位数前补零
|
||||||
img_name = [name.zfill(2) for name in img_name]
|
# img_name = [name.zfill(2) for name in img_name]
|
||||||
# 打印图像名称
|
# # 打印图像名称
|
||||||
print('img_name:', img_name)
|
# print('img_name:', img_name)
|
||||||
|
|
||||||
return img_data
|
return img_data, img_name
|
||||||
|
|
||||||
def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False):
|
def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False):
|
||||||
"""
|
"""
|
||||||
@ -385,9 +387,9 @@ class WoodClass(object):
|
|||||||
:return: x_data, y_data
|
:return: x_data, y_data
|
||||||
"""
|
"""
|
||||||
data_dir = os.path.join(ROOT_DIR, "data", "data20220919") if data_dir is None else data_dir
|
data_dir = os.path.join(ROOT_DIR, "data", "data20220919") if data_dir is None else data_dir
|
||||||
dark_data = self.get_image_data(img_dir=os.path.join(data_dir, "dark"))
|
dark_data, dark_name = self.get_image_data(img_dir=os.path.join(data_dir, "dark"))
|
||||||
middle_data = self.get_image_data(img_dir=os.path.join(data_dir, "middle"))
|
middle_data, middle_name = self.get_image_data(img_dir=os.path.join(data_dir, "middle"))
|
||||||
light_data = self.get_image_data(img_dir=os.path.join(data_dir, "light"))
|
light_data, light_name = self.get_image_data(img_dir=os.path.join(data_dir, "light"))
|
||||||
if (dark_data is False) or (middle_data is False) or (light_data is False):
|
if (dark_data is False) or (middle_data is False) or (light_data is False):
|
||||||
return False
|
return False
|
||||||
x_data = np.vstack((dark_data, middle_data, light_data))
|
x_data = np.vstack((dark_data, middle_data, light_data))
|
||||||
@ -396,45 +398,47 @@ class WoodClass(object):
|
|||||||
light_label = 2 * np.ones(len(light_data)).T
|
light_label = 2 * np.ones(len(light_data)).T
|
||||||
y_data = np.hstack((dark_label, middle_label, light_label))
|
y_data = np.hstack((dark_label, middle_label, light_label))
|
||||||
x_data = x_data[:, FEATURE_INDEX]
|
x_data = x_data[:, FEATURE_INDEX]
|
||||||
|
# dark_name, middle_name, light_name三个list合并
|
||||||
|
img_name = dark_name + middle_name + light_name
|
||||||
|
|
||||||
# 使用KMeans算法对图片数据进行聚类
|
# # 使用KMeans算法对图片数据进行聚类
|
||||||
kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data)
|
# kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data)
|
||||||
z = kmeans.predict(x_data)
|
# z = kmeans.predict(x_data)
|
||||||
# 获取聚类后的数据
|
# # 获取聚类后的数据
|
||||||
dark = x_data[kmeans.labels_ == 0]
|
# dark = x_data[kmeans.labels_ == 0]
|
||||||
middle = x_data[kmeans.labels_ == 1]
|
# middle = x_data[kmeans.labels_ == 1]
|
||||||
light = x_data[kmeans.labels_ == 2]
|
# light = x_data[kmeans.labels_ == 2]
|
||||||
# 获取数据的均值
|
# # 获取数据的均值
|
||||||
dark_mean = np.mean(dark, axis=0)
|
# dark_mean = np.mean(dark, axis=0)
|
||||||
middle_mean = np.mean(middle, axis=0)
|
# middle_mean = np.mean(middle, axis=0)
|
||||||
light_mean = np.mean(light, axis=0)
|
# light_mean = np.mean(light, axis=0)
|
||||||
|
#
|
||||||
# 按照平均值从小到大排序
|
# # 按照平均值从小到大排序
|
||||||
sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]])
|
# sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]])
|
||||||
print('sorted_cluster_indices:', sorted_cluster_indices)
|
# print('sorted_cluster_indices:', sorted_cluster_indices)
|
||||||
# 重新编号聚类标签
|
# # 重新编号聚类标签
|
||||||
sorted_labels = np.zeros(len(kmeans.labels_), dtype=int)
|
# sorted_labels = np.zeros(len(kmeans.labels_), dtype=int)
|
||||||
for i, label in enumerate(kmeans.labels_):
|
# for i, label in enumerate(kmeans.labels_):
|
||||||
sorted_labels[i] = sorted_cluster_indices[label]
|
# sorted_labels[i] = sorted_cluster_indices[label]
|
||||||
# 更新kmeans.labels_
|
# # 更新kmeans.labels_
|
||||||
kmeans.labels_ = sorted_labels
|
# kmeans.labels_ = sorted_labels
|
||||||
print('kmeans.labels_:', kmeans.labels_)
|
# print('kmeans.labels_:', kmeans.labels_)
|
||||||
# 获取更新聚类后的数据
|
# # 获取更新聚类后的数据
|
||||||
dark_new = x_data[kmeans.labels_ == 0]
|
# dark_new = x_data[kmeans.labels_ == 0]
|
||||||
middle_new = x_data[kmeans.labels_ == 1]
|
# middle_new = x_data[kmeans.labels_ == 1]
|
||||||
light_new = x_data[kmeans.labels_ == 2]
|
# light_new = x_data[kmeans.labels_ == 2]
|
||||||
# 获取更新数据的均值
|
# # 获取更新数据的均值
|
||||||
dark_mean_new = np.mean(dark_new, axis=0)
|
# dark_mean_new = np.mean(dark_new, axis=0)
|
||||||
middle_mean_new = np.mean(middle_new, axis=0)
|
# middle_mean_new = np.mean(middle_new, axis=0)
|
||||||
light_mean_new = np.mean(light_new, axis=0)
|
# light_mean_new = np.mean(light_new, axis=0)
|
||||||
# 打印每个聚类的平均值
|
# # 打印每个聚类的平均值
|
||||||
print('Dark cluster mean:', dark_mean_new)
|
# print('Dark cluster mean:', dark_mean_new)
|
||||||
print('Middle cluster mean:', middle_mean_new)
|
# print('Middle cluster mean:', middle_mean_new)
|
||||||
print('Light cluster mean:', light_mean_new)
|
# print('Light cluster mean:', light_mean_new)
|
||||||
# plot_2d
|
# # plot_2d
|
||||||
plt.figure()
|
# plt.figure()
|
||||||
plt.scatter(x_data[:, 0], x_data[:, 1], c=z)
|
# plt.scatter(x_data[:, 0], x_data[:, 1], c=z)
|
||||||
plt.show()
|
# plt.show()
|
||||||
|
|
||||||
# 进行色彩数据校正
|
# 进行色彩数据校正
|
||||||
if self.isCorrect:
|
if self.isCorrect:
|
||||||
@ -464,7 +468,7 @@ class WoodClass(object):
|
|||||||
if save_data:
|
if save_data:
|
||||||
with open(os.path.join("data", "data.p"), "rb") as f:
|
with open(os.path.join("data", "data.p"), "rb") as f:
|
||||||
pass
|
pass
|
||||||
return x_data, y_data
|
return x_data, y_data, img_name
|
||||||
|
|
||||||
def realtime_correction(self, img):
|
def realtime_correction(self, img):
|
||||||
"""
|
"""
|
||||||
@ -503,6 +507,40 @@ class WoodClass(object):
|
|||||||
img = np.clip(img, 0, 255).astype(dtype=np.uint8)
|
img = np.clip(img, 0, 255).astype(dtype=np.uint8)
|
||||||
return img
|
return img
|
||||||
|
|
||||||
|
def get_kmeans_data(self, data_dir=None, plot_2d=False):
    """
    Cluster the training features with KMeans and renumber the clusters.

    Features are obtained from ``self.get_train_data``. Three clusters are
    fitted, then renumbered by the mean of their first feature column in
    ascending order: label 0 is the cluster with the smallest mean and
    label 2 the one with the largest.
    NOTE(review): the first feature is presumably the L (lightness) channel,
    per the comment on FEATURE_INDEX -- confirm.

    :param data_dir: image directory, forwarded to ``get_train_data``
    :param plot_2d: when true, scatter-plot the first two feature columns
        coloured by the renumbered labels (also forwarded to
        ``get_train_data``)
    :return: ``(x_data, y_data, labels, img_names)``, or ``False`` when
        ``get_train_data`` signals failure by returning ``False``
    """
    result = self.get_train_data(data_dir, plot_2d=plot_2d)
    # get_train_data returns False when a class folder yields no data;
    # propagate that instead of failing on tuple unpacking.
    if result is False:
        return False
    x_data, y_data, img_names = result

    n_clusters = 3
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(x_data)
    raw_labels = kmeans.labels_

    # Mean of the first feature column for each raw cluster id.
    first_feature_means = [
        np.mean(x_data[raw_labels == cluster], axis=0)[0]
        for cluster in range(n_clusters)
    ]

    # argsort yields rank -> cluster id; relabelling needs the inverse
    # mapping (cluster id -> rank), so invert the permutation explicitly.
    order = np.argsort(first_feature_means)
    rank_of_cluster = np.empty_like(order)
    rank_of_cluster[order] = np.arange(n_clusters)

    # Build a fresh label array (same dtype as KMeans' own labels) rather
    # than rewriting kmeans.labels_ element by element.
    labels = rank_of_cluster[raw_labels].astype(raw_labels.dtype)

    if plot_2d:
        plt.figure()
        plt.scatter(x_data[:, 0], x_data[:, 1], c=labels)
        plt.show()

    return x_data, y_data, labels, img_names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
from config import Config
|
from config import Config
|
||||||
@ -517,6 +555,9 @@ if __name__ == '__main__':
|
|||||||
# fit 相应的文件夹
|
# fit 相应的文件夹
|
||||||
settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
|
settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
|
||||||
|
|
||||||
|
wood.get_kmeans_data(data_path, plot_2d=True)
|
||||||
|
|
||||||
|
|
||||||
# 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的
|
# 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的
|
||||||
pic = cv2.imread(r"data/318/dark/rgb89.png")
|
pic = cv2.imread(r"data/318/dark/rgb89.png")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user