任务描述

This commit is contained in:
wrz1-zzzzz 2024-11-18 20:55:41 +08:00
parent 5c9df03c21
commit e3d7d3fefe
3 changed files with 23 additions and 1 deletions

12
DPL/dgd_class/models.py Normal file
View File

@ -0,0 +1,12 @@
from symbol import pass_stmt
class Model:
def __init__(self):
pass
def load(self, weight_path: str):
pass
def predict(self, weight_path: str):
pass

View File

@ -19,7 +19,7 @@ def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
# 预测函数
def visualize_model_predictions(onnx_session, img_path):
def visualize_model_predictions(onnx_session: str, img_path: str):
img = Image.open(img_path)
img = img.convert('RGB') # 转换为 RGB 模式
@ -53,6 +53,11 @@ def visualize_model_predictions(onnx_session, img_path):
# 使用已训练的 ONNX 模型进行预测
if __name__ == '__main__':
# TODO:
# 写一个模型类model = Model()
# 1. 标准化模型加载接口 model.load(weight_path: str)
# 2. 标准化预测接口调用方式为 output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict)
# 加载 ONNX 模型
model_path = 'model/best_model_11.14.19.30.onnx'
onnx_session = load_onnx_model(model_path)

View File

@ -162,6 +162,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav
if __name__ == '__main__':
# TODO:
# 1. 能像yolo一样使用不同参数运行
# 2. 能够比较不同模型的预测时间,根据配置可以切换模型。
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
@ -181,3 +185,4 @@ if __name__ == '__main__':
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
model_save_dir=model_save_dir, save_onnx=True)
# visualize_model(model_ft)
# TODO: 3. 新增dgd_class/export.py, 可以把训练得到的pt文件转换为onnx