调了一个还不错的参数

This commit is contained in:
duanmu 2023-03-16 21:39:36 +08:00
parent b37e064594
commit af148ed971

View File

@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR from root_dir import ROOT_DIR
import utils import utils
FEATURE_INDEX = [0,1,2] FEATURE_INDEX = [1,2,3]
delete_columns = 10 # 已弃用 delete_columns = 10 # 已弃用
num_bins = 10 num_bins = 10
@ -60,10 +60,10 @@ class WoodClass(object):
self._single_pick = single_pick_mode self._single_pick = single_pick_mode
self.set_purity(self.pur) self.set_purity(self.pur)
self.change_pick_mode(single_pick_mode) self.change_pick_mode(single_pick_mode)
# self.model = LogisticRegression(C=1e5) self.model = LogisticRegression(C=1e5)
self.left_correct = left_correct self.left_correct = left_correct
# self.model = KNeighborsClassifier() # self.model = KNeighborsClassifier()
self.model = DecisionTreeClassifier() # self.model = DecisionTreeClassifier()
else: else:
self.load(load_from) self.load(load_from)
self.isCorrect = False self.isCorrect = False
@ -122,7 +122,7 @@ class WoodClass(object):
model_name = self.save(file_name) model_name = self.save(file_name)
return model_name return model_name
def fit(self, x, y, test_size=0.3): def fit(self, x, y, test_size=0.7):
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0) x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0)
self.model.fit(x_train, y_train) self.model.fit(x_train, y_train)
y_pred = self.model.predict(x_test) y_pred = self.model.predict(x_test)
@ -317,7 +317,7 @@ class WoodClass(object):
(x[:, 0] > bins[second_hist_number]) & (x[:, 0] < bins[second_hist_number + 1])), :] (x[:, 0] > bins[second_hist_number]) & (x[:, 0] < bins[second_hist_number + 1])), :]
# hist, bins = np.histogram(x[:, 0], bins=5) # hist, bins = np.histogram(x[:, 0], bins=9)
# sorted_indices = np.argsort(hist) # sorted_indices = np.argsort(hist)
# hist_number = sorted_indices[-1] # hist_number = sorted_indices[-1]
# second_hist_number = sorted_indices[-2] # second_hist_number = sorted_indices[-2]