diff --git a/classifer.py b/classifer.py index fc2a5a1..0805821 100644 --- a/classifer.py +++ b/classifer.py @@ -97,7 +97,7 @@ class WoodClass(object): sample = cv2.resize(x, (self.ww, self.hh)) return sample - def fit_pictures(self, data_path=ROOT_DIR): + def fit_pictures(self, data_path=ROOT_DIR, file_name=None): """ 根据给出的data_path 进行 fit.如果没有给出data目录,那么将会使用当前文件夹 :param data_path: @@ -110,7 +110,7 @@ class WoodClass(object): x, y = result score = self.fit(x, y) print('model score', score) - model_name = self.save() + model_name = self.save(file_name) return model_name def fit(self, x, y, test_size=0.1): diff --git a/socket_detector.py b/socket_detector.py index a9468b1..4ce2e05 100644 --- a/socket_detector.py +++ b/socket_detector.py @@ -33,8 +33,12 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo response = simple_sock(connected_sock, cmd_type=cmd, result=wood_color) elif cmd == 'TR': detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False) + model_name = None + if "/n" in data: + data, model_name = data.split("/n", 1) + model_name = model_name + ".p" settings.data_path = data - settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path) + settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path, file_name=model_name) response = simple_sock(connected_sock, cmd_type=cmd) elif cmd == 'MD': settings.model_path = data