Mirror of https://github.com/Karllzy/cotton_color.git (synced 2025-11-08 18:53:53 +00:00)
Modified the prediction function
This commit is contained in:
parent 6cb55b59b3
commit 00550bfb09
@@ -35,8 +35,5 @@ def main():
     model = ClsModel(model_path=args.weights, device=device)


 if __name__ == '__main__':
     main()
@@ -1,30 +1,57 @@
+from typing import Optional
+from torchvision import transforms
+import numpy as np
 import onnxruntime as ort
 import torch


 # Model class
 class Model:
-    def __init__(self, model_path: str, device: torch.device):
+    def __init__(self, model_path: str, device: torch.device, block_size: Optional[tuple] = None):
         """
         Initialize the model: load the ONNX model and set the device (CPU or GPU).
         """
         self.device = device
         self.session = ort.InferenceSession(model_path)
+        self.block_size = block_size
+        self.data_transforms = transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])

-    def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
+    def predict(self, img: np.ndarray) -> torch.Tensor:
         """
         Run inference with the ONNX model and return the prediction.
         """
-        # Convert to ONNX input format
-        img_numpy = img_tensor.cpu().numpy()
+        img = self.data_transforms(img)
+        img = img.unsqueeze(0)
+
+        # # Convert to ONNX input format
+        img_numpy = img.cpu().numpy()

         # Get the input name and run inference
         inputs = {self.session.get_inputs()[0].name: img_numpy}
         outputs = self.session.run(None, inputs)
+        pred = torch.tensor(outputs[0])
+        _, predicted_class = torch.max(pred, 1)

-        return torch.tensor(outputs[0])
+        return predicted_class

     def load(self, model_path: str):
         """
         Reload the model.
         """
-        self.session = ort.InferenceSession(model_path)
+        self.session = ort.InferenceSession(model_path, providers=['TensorrtExecutionProvider'])
+
+    def preprocess_predict_img(self, img: np.ndarray, block_size: Optional[tuple] = None) -> np.ndarray:
+        if block_size is None:
+            block_size = self.block_size
+        if block_size:
+            img = img.reshape((self.block_size[0], -1, self.block_size[1], 3))
+            img = img.transpose((1, 0, 2, 3))
+        else:
+            img = img[np.newaxis, ...]
+
+        pred = self.predict(img)
+        pred = pred.squeeze().numpy()
+        return pred
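For reference, a minimal usage sketch of the reworked Model class after this change, following the call pattern used in visualize_model_predictions below. The module name, image path, and the assumption that the ONNX model accepts a single 3x224x224 input are illustrative, not taken from the diff:

from PIL import Image
import torch

from model import Model  # hypothetical module name; the diff does not show the file name

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
clf = Model(model_path=r'..\onnxs\dgd_class_11.14.onnx', device=device)

# predict() now takes the raw image, applies the Resize/CenterCrop/ToTensor/Normalize
# pipeline internally, and returns the arg-max class index instead of raw logits.
img = Image.open('../dataset/val/dgd/sample.jpg').convert('RGB')  # hypothetical image path
predicted_class = clf.predict(img)
print(int(predicted_class[0]))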
@@ -20,25 +20,12 @@ def visualize_model_predictions(model: Model, img_path: str, save_dir: str, clas
     img = Image.open(img_path)
     img = img.convert('RGB')  # convert to RGB mode

-    # Image preprocessing
-    data_transforms = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ])
-    img_transformed = data_transforms(img)
-    img_transformed = img_transformed.unsqueeze(0)
-    img_transformed = img_transformed.to(model.device)
-
     # Run prediction with the model
-    preds = model.predict(img_transformed)
-
-    # Get the predicted class
-    _, predicted_class = torch.max(preds, 1)
+    preds = model.predict(img)
+    print(preds)

     # Add the prediction text to the image
-    predicted_label = class_names[predicted_class[0]]
+    predicted_label = class_names[preds[0]]

     # Draw the text on the image
     img_with_text = img.copy()
@@ -93,8 +80,9 @@ def main():
     parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")

     # Set default values and allow users to override them from the command line
-    parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
-    parser.add_argument('--img-path', type=str, default='dataset/val/dgd',
+    parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
+                        help='Path to ONNX model file')
+    parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
                         help='Path to image or folder for inference')
     parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
     parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
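As a quick check of the new defaults, a condensed, self-contained sketch of the argument parsing this hunk implies; the device-selection line is an assumption about how --gpu is consumed elsewhere in the script, not something shown in this diff:

import argparse
import torch

def parse_args():
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
    parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
                        help='Path to ONNX model file')
    parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
                        help='Path to image or folder for inference')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
    return parser.parse_args()

if __name__ == '__main__':
    args = parse_args()
    # Assumed usage: fall back to CPU when --gpu is not passed or CUDA is unavailable.
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    print(f'weights={args.weights}, images={args.img_path}, save_dir={args.save_dir}, device={device}')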