调了一个现场还不错的参数

This commit is contained in:
duanmu 2023-03-17 15:25:37 +08:00
parent af148ed971
commit 99a72c86ac
2 changed files with 13 additions and 11 deletions

View File

@ -31,7 +31,7 @@ sys.path.append(os.getcwd())
from root_dir import ROOT_DIR
import utils
FEATURE_INDEX = [1,2,3]
FEATURE_INDEX = [0,1,2,4]
delete_columns = 10 # 已弃用
num_bins = 10
@ -60,10 +60,10 @@ class WoodClass(object):
self._single_pick = single_pick_mode
self.set_purity(self.pur)
self.change_pick_mode(single_pick_mode)
self.model = LogisticRegression(C=1e5)
# self.model = LogisticRegression(C=1e5)
self.left_correct = left_correct
# self.model = KNeighborsClassifier()
# self.model = DecisionTreeClassifier()
self.model = DecisionTreeClassifier()
else:
self.load(load_from)
self.isCorrect = False
@ -113,7 +113,7 @@ class WoodClass(object):
:return:
"""
# 训练数据文件位置
result = self.get_train_data(data_path, plot_2d=True)
result = self.get_train_data(data_path, plot_2d=False)
if result is False:
return 0
x, y = result
@ -182,9 +182,11 @@ class WoodClass(object):
feature = feature.reshape(1, -1)[:, FEATURE_INDEX]
if self.isCorrect:
feature = feature / (self.correct_color + 1e-4)
plt.figure()
plt.scatter(feature[:, 0], feature[:, 1])
plt.show()
# plt.figure()
# plt.scatter(feature[:, 0], feature[:, 1])
# plt.show()
pred_color = self.model.predict(feature)
if self.debug_mode:
self.log.log(feature)

View File

@ -33,7 +33,7 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo
result = {0: 'dark', 1: 'middle', 2: 'light'}[wood_color]
response = simple_sock(connected_sock, cmd_type=cmd, result=wood_color)
elif cmd == 'TR':
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
detector = WoodClass(w=4096, h=1200, n=8000, p1=0.8, debug_mode=False)
model_name = None
if ":" in data:
data, model_name = data.split(":", 1)
@ -66,7 +66,7 @@ def main(is_debug=False):
while not dual_sock.status:
dual_sock.reconnect()
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
detector = WoodClass(w=4096, h=1200, n=8000, p1=0.8, debug_mode=False)
detector.load(path=settings.model_path)
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
while True:
@ -80,8 +80,8 @@ def main(is_debug=False):
# ack_sock(received_sock, cmd_type=cmd)
response, result = process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector, settings=settings)
if result != "":
database.add_data(result)
# if result != "":
# database.add_data(result)