From b12bb52a6d71479e358a1f51889ff4cca0906449 Mon Sep 17 00:00:00 2001 From: FEIJINTI <83849113+FEIJINTI@users.noreply.github.com> Date: Mon, 6 Mar 2023 15:35:40 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=8F=82=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- classifer.py | 27 ++++++++++++++++----------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/classifer.py b/classifer.py index 07ddfb2..fc2a5a1 100644 --- a/classifer.py +++ b/classifer.py @@ -14,6 +14,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.model_selection import train_test_split from sklearn.metrics import accuracy_score from sklearn.tree import DecisionTreeClassifier +from sklearn.neighbors import KNeighborsClassifier from scipy.stats import binom import matplotlib.pyplot as plt import time @@ -26,7 +27,7 @@ sys.path.append(os.getcwd()) from root_dir import ROOT_DIR import utils -FEATURE_INDEX = [1, 2] +FEATURE_INDEX = [0, 1] class WoodClass(object): @@ -53,6 +54,7 @@ class WoodClass(object): self.set_purity(self.pur) self.change_pick_mode(single_pick_mode) self.model = LogisticRegression(C=1e5) + # self.model = KNeighborsClassifier() else: self.load(load_from) self.isCorrect = False @@ -102,7 +104,7 @@ class WoodClass(object): :return: """ # 训练数据文件位置 - result = self.get_train_data(data_path, plot_2d=False) + result = self.get_train_data(data_path, plot_2d=True) if result is False: return 0 x, y = result @@ -150,6 +152,9 @@ class WoodClass(object): feature = feature.reshape(1, -1)[:, FEATURE_INDEX] if self.isCorrect: feature = feature / (self.correct_color+1e-4) + plt.figure() + plt.scatter(feature[:, 0], feature[:, 1]) + plt.show() pred_color = self.model.predict(feature) if self.debug_mode: self.log.log(feature) @@ -360,20 +365,20 @@ if __name__ == '__main__': from config import Config settings = Config() # 初始化wood - wood = WoodClass(w=4096, h=1200, n=5000, p1=0.4, debug_mode=False) + wood = WoodClass(w=4096, h=1200, n=12500, p1=0.4, debug_mode=False) print("色彩纯度控制量{}/{}".format(wood.k, wood.n)) - data_path = r"C:\Users\FEIJINTI\PycharmProjects\wood_color\data\data1015" + data_path = settings.data_path # wood.correct() # wood.load() # fit 相应的文件夹 - settings.model_path = ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path) + settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path)) # 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的 - # pic = cv2.imread(r"./data/dark/rgb60.png") - # start_time = time.time() + pic = cv2.imread(r"data/duizhao/rgb7.png") + start_time = time.time() # for i in range(100): - # wood_color = wood.predict(pic) - # end_time = time.time() - # print("time consume:"+str((end_time - start_time)/100)) - # print("wood_color:"+str(wood_color)) + wood_color = wood.predict(pic) + end_time = time.time() + print("time consume:"+str((end_time - start_time)/100)) + print("wood_color:"+str(wood_color))