添加了参数

This commit is contained in:
FEIJINTI 2023-03-06 15:35:40 +08:00
parent bf45d9360a
commit b12bb52a6d

View File

@ -14,6 +14,7 @@ from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from scipy.stats import binom
import matplotlib.pyplot as plt
import time
@ -26,7 +27,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR
import utils
FEATURE_INDEX = [1, 2]
FEATURE_INDEX = [0, 1]
class WoodClass(object):
@ -53,6 +54,7 @@ class WoodClass(object):
self.set_purity(self.pur)
self.change_pick_mode(single_pick_mode)
self.model = LogisticRegression(C=1e5)
# self.model = KNeighborsClassifier()
else:
self.load(load_from)
self.isCorrect = False
@ -102,7 +104,7 @@ class WoodClass(object):
:return:
"""
# 训练数据文件位置
result = self.get_train_data(data_path, plot_2d=False)
result = self.get_train_data(data_path, plot_2d=True)
if result is False:
return 0
x, y = result
@ -150,6 +152,9 @@ class WoodClass(object):
feature = feature.reshape(1, -1)[:, FEATURE_INDEX]
if self.isCorrect:
feature = feature / (self.correct_color+1e-4)
plt.figure()
plt.scatter(feature[:, 0], feature[:, 1])
plt.show()
pred_color = self.model.predict(feature)
if self.debug_mode:
self.log.log(feature)
@ -360,20 +365,20 @@ if __name__ == '__main__':
from config import Config
settings = Config()
# 初始化wood
wood = WoodClass(w=4096, h=1200, n=5000, p1=0.4, debug_mode=False)
wood = WoodClass(w=4096, h=1200, n=12500, p1=0.4, debug_mode=False)
print("色彩纯度控制量{}/{}".format(wood.k, wood.n))
data_path = r"C:\Users\FEIJINTI\PycharmProjects\wood_color\data\data1015"
data_path = settings.data_path
# wood.correct()
# wood.load()
# fit 相应的文件夹
settings.model_path = ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path)
settings.model_path = str(ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path))
# 测试单张图片的预测predict_mode=True表示导入本地的model, False为现场训练的
# pic = cv2.imread(r"./data/dark/rgb60.png")
# start_time = time.time()
pic = cv2.imread(r"data/duizhao/rgb7.png")
start_time = time.time()
# for i in range(100):
# wood_color = wood.predict(pic)
# end_time = time.time()
# print("time consume:"+str((end_time - start_time)/100))
# print("wood_color:"+str(wood_color))
wood_color = wood.predict(pic)
end_time = time.time()
print("time consume:"+str((end_time - start_time)/100))
print("wood_color:"+str(wood_color))