mirror of
https://github.com/Karllzy/cotton_color.git
synced 2025-11-08 18:53:53 +00:00
任务描述
This commit is contained in:
parent
5c9df03c21
commit
e3d7d3fefe
12
DPL/dgd_class/models.py
Normal file
12
DPL/dgd_class/models.py
Normal file
@ -0,0 +1,12 @@
|
||||
from symbol import pass_stmt
|
||||
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load(self, weight_path: str):
|
||||
pass
|
||||
|
||||
def predict(self, weight_path: str):
|
||||
pass
|
||||
@ -19,7 +19,7 @@ def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
|
||||
|
||||
|
||||
# 预测函数
|
||||
def visualize_model_predictions(onnx_session, img_path):
|
||||
def visualize_model_predictions(onnx_session: str, img_path: str):
|
||||
img = Image.open(img_path)
|
||||
img = img.convert('RGB') # 转换为 RGB 模式
|
||||
|
||||
@ -53,6 +53,11 @@ def visualize_model_predictions(onnx_session, img_path):
|
||||
|
||||
# 使用已训练的 ONNX 模型进行预测
|
||||
if __name__ == '__main__':
|
||||
# TODO:
|
||||
# 写一个模型类model = Model()
|
||||
# 1. 标准化模型加载接口 model.load(weight_path: str)
|
||||
# 2. 标准化预测接口调用方式为 output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict)
|
||||
|
||||
# 加载 ONNX 模型
|
||||
model_path = 'model/best_model_11.14.19.30.onnx'
|
||||
onnx_session = load_onnx_model(model_path)
|
||||
|
||||
@ -162,6 +162,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# TODO:
|
||||
# 1. 能像yolo一样使用不同参数运行
|
||||
# 2. 能够比较不同模型的预测时间,根据配置可以切换模型。
|
||||
|
||||
model_ft = models.resnet18(weights='IMAGENET1K_V1')
|
||||
num_ftrs = model_ft.fc.in_features
|
||||
model_ft.fc = nn.Linear(num_ftrs, 2)
|
||||
@ -181,3 +185,4 @@ if __name__ == '__main__':
|
||||
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
|
||||
model_save_dir=model_save_dir, save_onnx=True)
|
||||
# visualize_model(model_ft)
|
||||
# TODO: 3. 新增dgd_class/export.py, 可以把训练得到的pt文件转换为onnx
|
||||
|
||||
Loading…
Reference in New Issue
Block a user