mirror of
https://github.com/NanjingForestryUniversity/supermachine-wood.git
synced 2025-11-08 10:13:53 +00:00
尝试加个功能
This commit is contained in:
parent
f79393100b
commit
f593a67e76
52
classifer.py
52
classifer.py
@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
|
|||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
import utils
|
import utils
|
||||||
|
|
||||||
FEATURE_INDEX = [0,1,2,3,4,5]
|
FEATURE_INDEX = [0,1,2]
|
||||||
delete_columns = 10 # 已弃用
|
delete_columns = 10 # 已弃用
|
||||||
num_bins = 10
|
num_bins = 10
|
||||||
|
|
||||||
@ -367,6 +367,16 @@ class WoodClass(object):
|
|||||||
data = self.extract_feature(train_img)
|
data = self.extract_feature(train_img)
|
||||||
img_data.append(data)
|
img_data.append(data)
|
||||||
img_data = np.array(img_data)
|
img_data = np.array(img_data)
|
||||||
|
|
||||||
|
# 提取图像名称
|
||||||
|
img_name = [os.path.splitext(file)[0] for file in files]
|
||||||
|
# 提取每个图像名称中的数字
|
||||||
|
img_name = [name[3:] for name in img_name]
|
||||||
|
# 将图像名称个位数前补零
|
||||||
|
img_name = [name.zfill(2) for name in img_name]
|
||||||
|
# 打印图像名称
|
||||||
|
print('img_name:', img_name)
|
||||||
|
|
||||||
return img_data
|
return img_data
|
||||||
|
|
||||||
def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False):
|
def get_train_data(self, data_dir=None, plot_2d=False, plot_data_3d=False, save_data=False):
|
||||||
@ -386,6 +396,46 @@ class WoodClass(object):
|
|||||||
light_label = 2 * np.ones(len(light_data)).T
|
light_label = 2 * np.ones(len(light_data)).T
|
||||||
y_data = np.hstack((dark_label, middle_label, light_label))
|
y_data = np.hstack((dark_label, middle_label, light_label))
|
||||||
x_data = x_data[:, FEATURE_INDEX]
|
x_data = x_data[:, FEATURE_INDEX]
|
||||||
|
|
||||||
|
# 使用KMeans算法对图片数据进行聚类
|
||||||
|
kmeans = KMeans(n_clusters=3, random_state=0).fit(x_data)
|
||||||
|
z = kmeans.predict(x_data)
|
||||||
|
# 获取聚类后的数据
|
||||||
|
dark = x_data[kmeans.labels_ == 0]
|
||||||
|
middle = x_data[kmeans.labels_ == 1]
|
||||||
|
light = x_data[kmeans.labels_ == 2]
|
||||||
|
# 获取数据的均值
|
||||||
|
dark_mean = np.mean(dark, axis=0)
|
||||||
|
middle_mean = np.mean(middle, axis=0)
|
||||||
|
light_mean = np.mean(light, axis=0)
|
||||||
|
|
||||||
|
# 按照平均值从小到大排序
|
||||||
|
sorted_cluster_indices = np.argsort([dark_mean[0], middle_mean[0], light_mean[0]])
|
||||||
|
print('sorted_cluster_indices:', sorted_cluster_indices)
|
||||||
|
# 重新编号聚类标签
|
||||||
|
sorted_labels = np.zeros(len(kmeans.labels_), dtype=int)
|
||||||
|
for i, label in enumerate(kmeans.labels_):
|
||||||
|
sorted_labels[i] = sorted_cluster_indices[label]
|
||||||
|
# 更新kmeans.labels_
|
||||||
|
kmeans.labels_ = sorted_labels
|
||||||
|
print('kmeans.labels_:', kmeans.labels_)
|
||||||
|
# 获取更新聚类后的数据
|
||||||
|
dark_new = x_data[kmeans.labels_ == 0]
|
||||||
|
middle_new = x_data[kmeans.labels_ == 1]
|
||||||
|
light_new = x_data[kmeans.labels_ == 2]
|
||||||
|
# 获取更新数据的均值
|
||||||
|
dark_mean_new = np.mean(dark_new, axis=0)
|
||||||
|
middle_mean_new = np.mean(middle_new, axis=0)
|
||||||
|
light_mean_new = np.mean(light_new, axis=0)
|
||||||
|
# 打印每个聚类的平均值
|
||||||
|
print('Dark cluster mean:', dark_mean_new)
|
||||||
|
print('Middle cluster mean:', middle_mean_new)
|
||||||
|
print('Light cluster mean:', light_mean_new)
|
||||||
|
# plot_2d
|
||||||
|
plt.figure()
|
||||||
|
plt.scatter(x_data[:, 0], x_data[:, 1], c=z)
|
||||||
|
plt.show()
|
||||||
|
|
||||||
# 进行色彩数据校正
|
# 进行色彩数据校正
|
||||||
if self.isCorrect:
|
if self.isCorrect:
|
||||||
x_data = x_data / (self.correct_color + 1e-4)
|
x_data = x_data / (self.correct_color + 1e-4)
|
||||||
|
|||||||
@ -45,6 +45,8 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo
|
|||||||
settings.model_path = data
|
settings.model_path = data
|
||||||
detector.load(path=settings.model_path)
|
detector.load(path=settings.model_path)
|
||||||
response = simple_sock(connected_sock, cmd_type=cmd)
|
response = simple_sock(connected_sock, cmd_type=cmd)
|
||||||
|
elif cmd == 'DT':
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
logging.error(f'错误指令,指令为{cmd}')
|
logging.error(f'错误指令,指令为{cmd}')
|
||||||
response = False
|
response = False
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user