From e3d7d3fefed06ad4f82835dd1a308d551ebe5cd1 Mon Sep 17 00:00:00 2001 From: wrz1-zzzzz <914921336@qq.com> Date: Mon, 18 Nov 2024 20:55:41 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=BB=E5=8A=A1=E6=8F=8F=E8=BF=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- DPL/dgd_class/models.py | 12 ++++++++++++ DPL/dgd_class/predict.py | 7 ++++++- DPL/dgd_class/train.py | 5 +++++ 3 files changed, 23 insertions(+), 1 deletion(-) create mode 100644 DPL/dgd_class/models.py diff --git a/DPL/dgd_class/models.py b/DPL/dgd_class/models.py new file mode 100644 index 0000000..a2882e6 --- /dev/null +++ b/DPL/dgd_class/models.py @@ -0,0 +1,12 @@ +from symbol import pass_stmt + + +class Model: + def __init__(self): + pass + + def load(self, weight_path: str): + pass + + def predict(self, weight_path: str): + pass \ No newline at end of file diff --git a/DPL/dgd_class/predict.py b/DPL/dgd_class/predict.py index 679f0fe..97ea003 100644 --- a/DPL/dgd_class/predict.py +++ b/DPL/dgd_class/predict.py @@ -19,7 +19,7 @@ def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'): # 预测函数 -def visualize_model_predictions(onnx_session, img_path): +def visualize_model_predictions(onnx_session: str, img_path: str): img = Image.open(img_path) img = img.convert('RGB') # 转换为 RGB 模式 @@ -53,6 +53,11 @@ def visualize_model_predictions(onnx_session, img_path): # 使用已训练的 ONNX 模型进行预测 if __name__ == '__main__': + # TODO: + # 写一个模型类model = Model() + # 1. 标准化模型加载接口 model.load(weight_path: str) + # 2. 标准化预测接口调用方式为 output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict) + # 加载 ONNX 模型 model_path = 'model/best_model_11.14.19.30.onnx' onnx_session = load_onnx_model(model_path) diff --git a/DPL/dgd_class/train.py b/DPL/dgd_class/train.py index e18b5dd..e8ad3e0 100644 --- a/DPL/dgd_class/train.py +++ b/DPL/dgd_class/train.py @@ -162,6 +162,10 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav if __name__ == '__main__': + # TODO: + # 1. 能像yolo一样使用不同参数运行 + # 2. 能够比较不同模型的预测时间,根据配置可以切换模型。 + model_ft = models.resnet18(weights='IMAGENET1K_V1') num_ftrs = model_ft.fc.in_features model_ft.fc = nn.Linear(num_ftrs, 2) @@ -181,3 +185,4 @@ if __name__ == '__main__': model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25, model_save_dir=model_save_dir, save_onnx=True) # visualize_model(model_ft) + # TODO: 3. 新增dgd_class/export.py, 可以把训练得到的pt文件转换为onnx