Patch summary (three files under DPL/):

  * main.py
      Drops three surplus blank lines in front of the __main__ guard.
  * model_cls/models.py
      Model.__init__ gains an optional block_size and builds the torchvision
      preprocessing pipeline once. predict() applies that pipeline, adds a
      batch axis, runs the ONNX session and returns the index of the largest
      score along dim 1. load() requests the TensorRT execution provider.
      New preprocess_predict_img() reshapes an array into blocks (or adds a
      leading axis when no block size is set) and forwards it to predict().
  * model_cls/predict.py
      Removes the inline preprocessing / arg-max code in favour of
      model.predict(img) and changes the --weights / --img-path defaults.

Review fix folded into this patch:
  preprocess_predict_img() now reshapes with the *resolved* block_size
  (explicit argument, falling back to self.block_size). The previous text
  used self.block_size inside the reshape, so an explicit argument was
  ignored and raised TypeError whenever self.block_size was None.

NOTE(review): predict() is annotated as taking np.ndarray, but predict.py
passes a PIL image and preprocess_predict_img() passes a 4-D array into the
Resize/CenterCrop/ToTensor pipeline - confirm the pipeline accepts both.
NOTE(review): print(preds) in predict.py looks like leftover debugging.
NOTE(review): the --weights default uses Windows-style backslashes while
--img-path uses forward slashes - confirm this is intended.

Context lines keep their original comment text unchanged because they must
match the target files byte-for-byte for the patch to apply.

diff --git a/DPL/main.py b/DPL/main.py
index e69e832..cbda99d 100644
--- a/DPL/main.py
+++ b/DPL/main.py
@@ -35,8 +35,5 @@ def main():
     model = ClsModel(model_path=args.weights, device=device)
 
 
-
-
-
 if __name__ == '__main__':
     main()
\ No newline at end of file
diff --git a/DPL/model_cls/models.py b/DPL/model_cls/models.py
index 72c14e4..cd2745a 100644
--- a/DPL/model_cls/models.py
+++ b/DPL/model_cls/models.py
@@ -1,30 +1,57 @@
+from typing import Optional
+from torchvision import transforms
+import numpy as np
 import onnxruntime as ort
 import torch
 
 # 模型类
 class Model:
-    def __init__(self, model_path: str, device: torch.device):
+    def __init__(self, model_path: str, device: torch.device, block_size: Optional[tuple] = None):
         """
         初始化模型,加载 ONNX 模型,并设置设备(CPU 或 GPU)。
         """
         self.device = device
         self.session = ort.InferenceSession(model_path)
+        self.block_size = block_size
+        self.data_transforms = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
 
-    def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
+    def predict(self, img: np.ndarray) -> torch.Tensor:
         """
         使用 ONNX 模型进行推理,返回预测结果。
         """
-        # 转换为 ONNX 输入格式
-        img_numpy = img_tensor.cpu().numpy()
+        img = self.data_transforms(img)
+        img = img.unsqueeze(0)
+
+        # Convert to the ONNX input format
+        img_numpy = img.cpu().numpy()
 
-        # 获取输入名称和推理
         inputs = {self.session.get_inputs()[0].name: img_numpy}
         outputs = self.session.run(None, inputs)
+        pred = torch.tensor(outputs[0])
+        _, predicted_class = torch.max(pred, 1)
 
-        return torch.tensor(outputs[0])
+        return predicted_class
 
     def load(self, model_path: str):
         """
         重新加载模型。
         """
-        self.session = ort.InferenceSession(model_path)
\ No newline at end of file
+        self.session = ort.InferenceSession(model_path, providers=['TensorrtExecutionProvider'])
+
+    def preprocess_predict_img(self, img: np.ndarray, block_size: Optional[tuple] = None) -> np.ndarray:
+        if block_size is None:
+            block_size = self.block_size
+        if block_size:
+            img = img.reshape((block_size[0], -1, block_size[1], 3))
+            img = img.transpose((1, 0, 2, 3))
+        else:
+            img = img[np.newaxis, ...]
+
+        pred = self.predict(img)
+        pred = pred.squeeze().numpy()
+        return pred
diff --git a/DPL/model_cls/predict.py b/DPL/model_cls/predict.py
index bb9da80..cbf7edc 100644
--- a/DPL/model_cls/predict.py
+++ b/DPL/model_cls/predict.py
@@ -20,25 +20,12 @@ def visualize_model_predictions(model: Model, img_path: str, save_dir: str, clas
     img = Image.open(img_path)
     img = img.convert('RGB')  # 转换为 RGB 模式
 
-    # 图像预处理
-    data_transforms = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ])
-    img_transformed = data_transforms(img)
-    img_transformed = img_transformed.unsqueeze(0)
-    img_transformed = img_transformed.to(model.device)
-
     # 使用模型进行预测
-    preds = model.predict(img_transformed)
-
-    # 获取预测类别
-    _, predicted_class = torch.max(preds, 1)
+    preds = model.predict(img)
+    print(preds)
 
     # 在图像上添加预测结果文本
-    predicted_label = class_names[predicted_class[0]]
+    predicted_label = class_names[preds[0]]
 
     # 在图片上绘制文本
     img_with_text = img.copy()
@@ -93,8 +80,9 @@ def main():
     parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
 
     # 设置默认值,并允许用户通过命令行进行修改
-    parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
-    parser.add_argument('--img-path', type=str, default='dataset/val/dgd',
+    parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
+                        help='Path to ONNX model file')
+    parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
                         help='Path to image or folder for inference')
     parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
     parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')