修改了预测函数

This commit is contained in:
ZhenyeLi 2024-11-19 18:51:06 +08:00
parent 6cb55b59b3
commit 00550bfb09
3 changed files with 40 additions and 28 deletions

View File

@ -35,8 +35,5 @@ def main():
model = ClsModel(model_path=args.weights, device=device)
if __name__ == '__main__':
main()

View File

@ -1,30 +1,57 @@
from typing import Optional
from torchvision import transforms
import numpy as np
import onnxruntime as ort
import torch
# 模型类
class Model:
def __init__(self, model_path: str, device: torch.device):
def __init__(self, model_path: str, device: torch.device, block_size: Optional[tuple] = None):
"""
初始化模型加载 ONNX 模型并设置设备CPU GPU
"""
self.device = device
self.session = ort.InferenceSession(model_path)
self.block_size = block_size
self.data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
def predict(self, img: np.ndarray) -> torch.Tensor:
"""
使用 ONNX 模型进行推理返回预测结果
"""
# 转换为 ONNX 输入格式
img_numpy = img_tensor.cpu().numpy()
img = self.data_transforms(img)
img = img.unsqueeze(0)
# # 转换为 ONNX 输入格式
img_numpy = img.cpu().numpy()
# 获取输入名称和推理
inputs = {self.session.get_inputs()[0].name: img_numpy}
outputs = self.session.run(None, inputs)
pred = torch.tensor(outputs[0])
_, predicted_class = torch.max(pred, 1)
return torch.tensor(outputs[0])
return predicted_class
def load(self, model_path: str):
"""
重新加载模型
"""
self.session = ort.InferenceSession(model_path)
self.session = ort.InferenceSession(model_path, providers=['TensorrtExecutionProvider'])
def preprocess_predict_img(self, img: np.ndarray, block_size: Optional[tuple] = None) -> np.ndarray:
if block_size is None:
block_size = self.block_size
if block_size:
img = img.reshape((self.block_size[0], -1, self.block_size[1], 3))
img = img.transpose((1, 0, 2, 3))
else:
img = img[np.newaxis, ...]
pred = self.predict(img)
pred = pred.squeeze().numpy()
return pred

View File

@ -20,25 +20,12 @@ def visualize_model_predictions(model: Model, img_path: str, save_dir: str, clas
img = Image.open(img_path)
img = img.convert('RGB') # 转换为 RGB 模式
# 图像预处理
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transformed = data_transforms(img)
img_transformed = img_transformed.unsqueeze(0)
img_transformed = img_transformed.to(model.device)
# 使用模型进行预测
preds = model.predict(img_transformed)
# 获取预测类别
_, predicted_class = torch.max(preds, 1)
preds = model.predict(img)
print(preds)
# 在图像上添加预测结果文本
predicted_label = class_names[predicted_class[0]]
predicted_label = class_names[preds[0]]
# 在图片上绘制文本
img_with_text = img.copy()
@ -93,8 +80,9 @@ def main():
parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
# 设置默认值,并允许用户通过命令行进行修改
parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='dataset/val/dgd',
parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
help='Path to image or folder for inference')
parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')