mirror of
https://github.com/NanjingForestryUniversity/supermachine-wood.git
synced 2025-11-08 10:13:53 +00:00
修改了自动保存模型路径
This commit is contained in:
parent
f1cf707f6e
commit
4bb299d02b
16
classifer.py
16
classifer.py
@ -19,6 +19,9 @@ import matplotlib.pyplot as plt
|
|||||||
import time
|
import time
|
||||||
import pickle
|
import pickle
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
import config
|
||||||
|
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
import utils
|
import utils
|
||||||
@ -99,13 +102,14 @@ class WoodClass(object):
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
# 训练数据文件位置
|
# 训练数据文件位置
|
||||||
result = self.get_train_data(data_path)
|
result = self.get_train_data(data_path, plot_2d=False)
|
||||||
if result is False:
|
if result is False:
|
||||||
return 0
|
return 0
|
||||||
x, y = result
|
x, y = result
|
||||||
score = self.fit(x, y)
|
score = self.fit(x, y)
|
||||||
self.save()
|
print('model score', score)
|
||||||
return score
|
model_name = self.save()
|
||||||
|
return model_name
|
||||||
|
|
||||||
def fit(self, x, y, test_size=0.1):
|
def fit(self, x, y, test_size=0.1):
|
||||||
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0)
|
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=test_size, random_state=0)
|
||||||
@ -207,6 +211,7 @@ class WoodClass(object):
|
|||||||
with open(file_name, "wb") as f:
|
with open(file_name, "wb") as f:
|
||||||
pickle.dump(model_dic, f)
|
pickle.dump(model_dic, f)
|
||||||
self.log.log("Save file to '" + str(file_name) + "'")
|
self.log.log("Save file to '" + str(file_name) + "'")
|
||||||
|
return file_name
|
||||||
|
|
||||||
def load(self, path=None):
|
def load(self, path=None):
|
||||||
if path is None:
|
if path is None:
|
||||||
@ -352,13 +357,16 @@ class WoodClass(object):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
from config import Config
|
||||||
|
settings = Config()
|
||||||
# 初始化wood
|
# 初始化wood
|
||||||
wood = WoodClass(w=4096, h=1200, n=5000, p1=0.4, debug_mode=False)
|
wood = WoodClass(w=4096, h=1200, n=5000, p1=0.4, debug_mode=False)
|
||||||
print("色彩纯度控制量{}/{}".format(wood.k, wood.n))
|
print("色彩纯度控制量{}/{}".format(wood.k, wood.n))
|
||||||
|
data_path = r"C:\Users\FEIJINTI\PycharmProjects\wood_color\data\data1015"
|
||||||
# wood.correct()
|
# wood.correct()
|
||||||
# wood.load()
|
# wood.load()
|
||||||
# fit 相应的文件夹
|
# fit 相应的文件夹
|
||||||
wood.fit_pictures(data_path=r"C:\Users\FEIJINTI\PycharmProjects\wood_color\data\data20220919")
|
settings.model_path = ROOT_DIR / 'models' / wood.fit_pictures(data_path=data_path)
|
||||||
|
|
||||||
# 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的
|
# 测试单张图片的预测,predict_mode=True表示导入本地的model, False为现场训练的
|
||||||
# pic = cv2.imread(r"./data/dark/rgb60.png")
|
# pic = cv2.imread(r"./data/dark/rgb60.png")
|
||||||
|
|||||||
57
config.py
Normal file
57
config.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
# -*- codeing = utf-8 -*-
|
||||||
|
# Time : 2022/10/17 11:07
|
||||||
|
# @Auther : zhouchao
|
||||||
|
# @File: config.py
|
||||||
|
# @Software:PyCharm
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import WindowsPath
|
||||||
|
|
||||||
|
from root_dir import ROOT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
class Config(object):
|
||||||
|
model_path = ROOT_DIR / 'config.json'
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._param_dict = {}
|
||||||
|
if os.path.exists(Config.model_path):
|
||||||
|
self._read()
|
||||||
|
else:
|
||||||
|
self.model_path = str(ROOT_DIR / 'models/model_2022-10-17_11-10.p')
|
||||||
|
self.data_path = str(ROOT_DIR / 'data/data20220919')
|
||||||
|
self._param_dict['model_path'] = self.model_path
|
||||||
|
self._param_dict['data_path'] = self.data_path
|
||||||
|
|
||||||
|
def __setitem__(self, key, value):
|
||||||
|
if key in self._param_dict:
|
||||||
|
self._param_dict[key] = value
|
||||||
|
self._write()
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
if item in self._param_dict:
|
||||||
|
return self._param_dict[item]
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
self.__dict__[key] = value
|
||||||
|
if '_param_dict' in self.__dict__ and key != '_param_dict':
|
||||||
|
if isinstance(value, WindowsPath):
|
||||||
|
value = str(value)
|
||||||
|
self.__dict__['_param_dict'][key] = value
|
||||||
|
self._write()
|
||||||
|
|
||||||
|
def _read(self):
|
||||||
|
with open(Config.model_path, 'r') as f:
|
||||||
|
self._param_dict = json.load(f)
|
||||||
|
self.data_path = self._param_dict['data_path']
|
||||||
|
self.model_path = self._param_dict['model_path']
|
||||||
|
|
||||||
|
def _write(self):
|
||||||
|
with open(Config.model_path, 'w') as f:
|
||||||
|
json.dump(self._param_dict, f)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
config = Config()
|
||||||
|
print(config.model_path)
|
||||||
|
print(config.data_path)
|
||||||
@ -5,7 +5,7 @@ Created on Nov 3 21:18:26 2020
|
|||||||
@author: l.z.y
|
@author: l.z.y
|
||||||
@e-mail: li.zhenye@qq.com
|
@e-mail: li.zhenye@qq.com
|
||||||
"""
|
"""
|
||||||
import os
|
import pathlib
|
||||||
|
|
||||||
# ROOT_DIR = r"C:\Users\FEIJINTI\PycharmProjects\wood_color"
|
file_path = pathlib.Path(__file__)
|
||||||
ROOT_DIR = r"/Users/zhouchao/Library/CloudStorage/OneDrive-macrosolid/PycharmProjects/wood_color"
|
ROOT_DIR = file_path.parent
|
||||||
|
|||||||
@ -12,9 +12,10 @@ import os
|
|||||||
from root_dir import ROOT_DIR
|
from root_dir import ROOT_DIR
|
||||||
from utils import PreSocket, receive_sock, parse_protocol, ack_sock, done_sock, DualSock, simple_sock
|
from utils import PreSocket, receive_sock, parse_protocol, ack_sock, done_sock, DualSock, simple_sock
|
||||||
import logging
|
import logging
|
||||||
|
from config import Config
|
||||||
|
|
||||||
|
|
||||||
def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: WoodClass) -> bool:
|
def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: WoodClass, settings: Config) -> bool:
|
||||||
"""
|
"""
|
||||||
处理指令
|
处理指令
|
||||||
|
|
||||||
@ -30,10 +31,12 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo
|
|||||||
response = simple_sock(connected_sock, cmd_type=cmd, result=wood_color)
|
response = simple_sock(connected_sock, cmd_type=cmd, result=wood_color)
|
||||||
elif cmd == 'TR':
|
elif cmd == 'TR':
|
||||||
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
|
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
|
||||||
detector.fit_pictures(data_path=data)
|
settings.data_path = data
|
||||||
|
settings.model_path = ROOT_DIR / 'models' / detector.fit_pictures(data_path=settings.data_path)
|
||||||
response = simple_sock(connected_sock, cmd_type=cmd)
|
response = simple_sock(connected_sock, cmd_type=cmd)
|
||||||
elif cmd == 'MD':
|
elif cmd == 'MD':
|
||||||
detector.load(path=data)
|
settings.model_path = data
|
||||||
|
detector.load(path=settings.model_path)
|
||||||
response = simple_sock(connected_sock, cmd_type=cmd)
|
response = simple_sock(connected_sock, cmd_type=cmd)
|
||||||
else:
|
else:
|
||||||
logging.error(f'错误指令,指令为{cmd}')
|
logging.error(f'错误指令,指令为{cmd}')
|
||||||
@ -42,6 +45,7 @@ def process_cmd(cmd: str, data: any, connected_sock: socket.socket, detector: Wo
|
|||||||
|
|
||||||
|
|
||||||
def main(is_debug=False):
|
def main(is_debug=False):
|
||||||
|
settings = Config()
|
||||||
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'report.log'))
|
file_handler = logging.FileHandler(os.path.join(ROOT_DIR, 'report.log'))
|
||||||
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
file_handler.setLevel(logging.DEBUG if is_debug else logging.WARNING)
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
@ -52,10 +56,8 @@ def main(is_debug=False):
|
|||||||
|
|
||||||
while not dual_sock.status:
|
while not dual_sock.status:
|
||||||
dual_sock.reconnect()
|
dual_sock.reconnect()
|
||||||
model_path = os.path.join(ROOT_DIR, r"models/model_2022-09-28_13-15.p")
|
|
||||||
# model_path = os.path.join(ROOT_DIR, r"models\model_2022-09-28_13-15.p")
|
|
||||||
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
|
detector = WoodClass(w=4096, h=1200, n=3000, debug_mode=False)
|
||||||
detector.load(path=model_path)
|
detector.load(path=settings.model_path)
|
||||||
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
|
_ = detector.predict(np.random.randint(1, 254, (1200, 4096, 3), dtype=np.uint8))
|
||||||
while True:
|
while True:
|
||||||
pack, next_pack = receive_sock(dual_sock)
|
pack, next_pack = receive_sock(dual_sock)
|
||||||
@ -66,7 +68,7 @@ def main(is_debug=False):
|
|||||||
|
|
||||||
cmd, data = parse_protocol(pack)
|
cmd, data = parse_protocol(pack)
|
||||||
# ack_sock(received_sock, cmd_type=cmd)
|
# ack_sock(received_sock, cmd_type=cmd)
|
||||||
process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector)
|
process_cmd(cmd=cmd, data=data, connected_sock=dual_sock, detector=detector, settings=settings)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user