添加了参数

This commit is contained in:
FEIJINTI 2023-03-06 15:35:40 +08:00
parent bf45d9360a
commit b12bb52a6d

View File

@ -14,6 +14,7 @@ from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from scipy.stats import binom from scipy.stats import binom
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import time import time
@ -26,7 +27,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR from root_dir import ROOT_DIR
import utils import utils
FEATURE_INDEX = [1, 2] FEATURE_INDEX = [0, 1]
class WoodClass(object): class WoodClass(object):
@ -53,6 +54,7 @@ class WoodClass(object):
self.set_purity(self.pur) self.set_purity(self.pur)
self.change_pick_mode(single_pick_mode) self.change_pick_mode(single_pick_mode)
self.model = LogisticRegression(C=1e5) self.model = LogisticRegression(C=1e5)
# self.model = KNeighborsClassifier()
else: else:
self.load(load_from) self.load(load_from)
self.isCorrect = False self.isCorrect = False
@ -102,7 +104,7 @@ class WoodClass(object):
:return: :return:
""" """
# 训练数据文件位置 # 训练数据文件位置
result = self.get_train_data(data_path, plot_2d=False) result = self.get_train_data(data_path, plot_2d=True)
if result is False: if result is False:
return 0 return 0
x, y = result x, y = result
@ -150,6 +152,9 @@ class WoodClass(object):
feature = feature.reshape(1, -1)[:, FEATURE_INDEX] feature = feature.reshape(1, -1)[:, FEATURE_INDEX]
if self.isCorrect: if self.isCorrect:
feature = feature / (self.correct_color+1e-4) feature = feature / (self.correct_color+1e-4)
plt.figure()
plt.scatter(feature[:, 0], feature[:, 1])
plt.show()
pred_color = self.model.predict(feature) pred_color = self.model.predict(feature)
if self.debug_mode: if self.debug_mode:
self.log.log(feature) self.log.log(feature)
@ -360,20 +365,20 @@ if __name__ == '__main__':
from config import Config from config import Config
settings = Config() settings = Config()
# 初始化wood # 初始化wood
wood = WoodClass(w=4096, h=1200, n=5000, p1=0.4, debug_mode=False) wood = WoodClass(w=4096, h=1200, n=12500, p1=0.4, debug_mode=False)
print("色彩纯度控制量{}/{}".format(wood.k, wood.n)) print("色彩纯度控制量{}/{}".format(wood.k, wood.n))
data_path = r"C:\Users\FEIJINTI\PycharmProjects\wood_color\data\data1015" data_path = settings.data_path
# wood.correct() # wood.correct()
# wood.load() # wood.load()
# fit 相应的文件夹 # fit 相应的文件夹
settings.model_path = ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path) settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
# 测试单张图片的预测predict_mode=True表示导入本地的model, False为现场训练的 # 测试单张图片的预测predict_mode=True表示导入本地的model, False为现场训练的
# pic = cv2.imread(r"./data/dark/rgb60.png") pic = cv2.imread(r"data/duizhao/rgb7.png")
# start_time = time.time() start_time = time.time()
# for i in range(100): # for i in range(100):
# wood_color = wood.predict(pic) wood_color = wood.predict(pic)
# end_time = time.time() end_time = time.time()
# print("time consume:"+str((end_time - start_time)/100)) print("time consume:"+str((end_time - start_time)/100))
# print("wood_color:"+str(wood_color)) print("wood_color:"+str(wood_color))