Convert the drip irrigation tape classification data to YOLO format and add export.py

wrz1-zzzzz 2024-11-19 12:09:14 +08:00
parent e3d7d3fefe
commit ba15b49cc2
3 changed files with 419 additions and 0 deletions

@@ -0,0 +1,98 @@
import argparse
import torch
from torchvision import models
from torch import nn
import os
import re
def load_pytorch_model(model_path, num_classes, device):
    """
    Load the PyTorch weights and set up the final classification layer.
    """
    print(f"Loading PyTorch model from {model_path}...")
    model = models.resnet18(weights=None)  # ResNet18 is used as the example backbone
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace the final layer
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)  # move the model to the target device
    model.eval()
    print("PyTorch model loaded successfully.")
    return model
def export_to_onnx(model, onnx_path, img_size, batch_size, device):
    """
    Export the PyTorch model to ONNX, auto-incrementing the output file name; GPU export is supported.
    """
    # Resolve `onnx_path` into a concrete output directory
    if not onnx_path.endswith('.onnx'):
        os.makedirs(onnx_path, exist_ok=True)  # create the output directory
        base_dir = onnx_path
    else:
        base_dir = os.path.dirname(onnx_path) or '.'  # extract the directory part
        os.makedirs(base_dir, exist_ok=True)

    # Auto-increment the output file name
    base_name = "model"
    extension = ".onnx"
    existing_files = [f for f in os.listdir(base_dir) if f.endswith(extension)]
    # Match existing file names against the "model_<n>.onnx" pattern
    pattern = re.compile(rf"^{base_name}_(\d+){re.escape(extension)}$")  # escape so '.' is literal
    numbers = [
        int(match.group(1)) for f in existing_files if (match := pattern.match(f))
    ]
    next_number = max(numbers, default=0) + 1  # next available index
    final_name = f"{base_name}_{next_number}{extension}"
    final_path = os.path.join(base_dir, final_name)

    print(f"Exporting model to ONNX format at {final_path}...")
    # Create a dummy input tensor on the target device
    dummy_input = torch.randn(batch_size, 3, img_size, img_size, device=device)
    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        final_path,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,  # ONNX opset version
        dynamic_axes={
            'input': {0: 'batch_size'},  # dynamic batch dimension
            'output': {0: 'batch_size'}
        }
    )
    print(f"Model exported successfully to {final_path}.")
def main():
    parser = argparse.ArgumentParser(description="Export a PyTorch model to ONNX format.")
    parser.add_argument('--weights', type=str, required=True, help='Path to PyTorch model weights (.pt file)')
    parser.add_argument('--onnx-path', type=str, default='onnx', help='Output directory or .onnx file path for the exported model')
    parser.add_argument('--img-size', type=int, default=224, help='Input image size (default: 224)')
    parser.add_argument('--batch-size', type=int, default=8, help='Input batch size (default: 8)')
    parser.add_argument('--num-classes', type=int, default=2, help='Number of classes in the model (default: 2)')
    parser.add_argument('--use-gpu', action='store_true', help='Enable GPU support during export')
    args = parser.parse_args()

    # Select the device
    device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')
    if args.use_gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Check that the weights file exists
    if not os.path.isfile(args.weights):
        raise FileNotFoundError(f"Model weights file not found: {args.weights}")

    # Load the model
    model = load_pytorch_model(args.weights, args.num_classes, device)
    # Export to ONNX
    export_to_onnx(model, args.onnx_path, args.img_size, args.batch_size, device)


if __name__ == "__main__":
    main()
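A quick sanity check of the exported file can be done with the onnx package. This is a minimal sketch and not part of the commit; onnx/model_1.onnx is just an example of the auto-incremented output path.

import onnx

# Load the exported graph and run ONNX's structural checker (example path).
onnx_model = onnx.load("onnx/model_1.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX model check passed.")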

@@ -0,0 +1,154 @@
import argparse
import os
import time  # used to time per-image and total inference

import torch
import onnxruntime as ort
from PIL import Image, ImageDraw, ImageFont
from torchvision import transforms
# Model wrapper
class Model:
    def __init__(self, model_path: str, device: torch.device):
        """
        Initialize the model: load the ONNX model and remember the target device (CPU or GPU).
        """
        self.device = device
        self.session = ort.InferenceSession(model_path)

    def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
        """
        Run inference with the ONNX model and return the predictions.
        """
        # Convert the tensor to the NumPy format expected by ONNX Runtime
        img_numpy = img_tensor.cpu().numpy()
        # Build the input feed and run the session
        inputs = {self.session.get_inputs()[0].name: img_numpy}
        outputs = self.session.run(None, inputs)
        return torch.tensor(outputs[0])

    def load(self, model_path: str):
        """
        Reload the model from disk.
        """
        self.session = ort.InferenceSession(model_path)
# Prediction and visualization
def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
    """
    Predict the class of one image, draw the result on it, and save the annotated image to save_dir.
    """
    start_time = time.time()  # start timer
    img = Image.open(img_path)
    img = img.convert('RGB')  # ensure RGB mode

    # Image preprocessing
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img_transformed = data_transforms(img)
    img_transformed = img_transformed.unsqueeze(0)
    img_transformed = img_transformed.to(model.device)

    # Run the model
    preds = model.predict(img_transformed)
    # Take the predicted class
    _, predicted_class = torch.max(preds, 1)
    predicted_label = class_names[predicted_class[0]]

    # Draw the prediction text on a copy of the image
    img_with_text = img.copy()
    draw = ImageDraw.Draw(img_with_text)
    font = ImageFont.load_default()  # change the font here if needed
    text = f'Predicted: {predicted_label}'
    text_position = (10, 10)  # adjust the text position if needed
    draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text

    # Show the annotated image
    img_with_text.show()

    # Save the annotated image
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Build the output file name and path
    img_name = os.path.basename(img_path)
    save_path = os.path.join(save_dir, f"pred_{img_name}")
    img_with_text.save(save_path)
    print(f"Prediction saved at: {save_path}")

    end_time = time.time()  # end timer
    processing_time = (end_time - start_time) * 1000  # convert to milliseconds
    print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image time
def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
    """
    Predict the class of every image in a folder.
    """
    start_time = time.time()  # total start time
    # Collect all image files in the folder
    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    # Process each image
    for image_file in image_files:
        img_path = os.path.join(folder_path, image_file)
        print(f"Processing image: {img_path}")
        visualize_model_predictions(model, img_path, save_dir, class_names)
    end_time = time.time()  # total end time
    total_time = (end_time - start_time) * 1000  # convert to milliseconds
    print(f"Total time taken to process all images: {total_time:.2f} ms")
def main():
    # Command-line arguments; defaults can be overridden on the command line
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
    parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
    parser.add_argument('--img-path', type=str, default='d_2/val/dgd',
                        help='Path to an image or a folder of images for inference')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
    args = parser.parse_args()

    # Select the device
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    if args.gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Load the model
    model = Model(model_path=args.weights, device=device)

    # Hard-coded class names; the model is assumed to have two classes
    class_names = ['class_0', 'class_1']

    # If the input path is a folder, process every image in it; otherwise predict a single image
    if os.path.isdir(args.img_path):
        process_image_folder(model, args.img_path, args.save_dir, class_names)
    else:
        visualize_model_predictions(model, args.img_path, args.save_dir, class_names)


if __name__ == "__main__":
    main()
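Note that --gpu above only controls where the input tensor is staged; ort.InferenceSession(model_path) is created without an explicit provider list, so inference still runs on the default CPU provider. A sketch (not part of the commit) of how the session could be opened with CUDA preferred, assuming onnxruntime-gpu is installed:

import onnxruntime as ort

# Prefer the CUDA provider and fall back to CPU if it is unavailable.
providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
session = ort.InferenceSession('model/model_1.onnx', providers=providers)
print(session.get_providers())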

@@ -0,0 +1,167 @@
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
# Enable cuDNN auto-tuning for fixed input sizes
torch.backends.cudnn.benchmark = True
plt.ion()
def get_next_model_name(save_dir, base_name, ext):
    """
    Scan the save directory and generate the next auto-incremented model file name.

    Args:
        save_dir (str): directory where models are saved
        base_name (str): base name of the model file
        ext (str): model file extension, e.g. ".pt" or ".onnx"

    Returns:
        str: path of the new model file with the next index
    """
    os.makedirs(save_dir, exist_ok=True)
    existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
    existing_numbers = []
    for f in existing_files:
        try:
            num = int(f[len(base_name):].split('.')[0].strip('_'))
            existing_numbers.append(num)
        except ValueError:
            continue
    next_number = max(existing_numbers, default=0) + 1
    return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")
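# Usage sketch (hypothetical files): with model/best_model_1.pt and model/best_model_2.pt
# already on disk, get_next_model_name('model', 'best_model', '.pt') returns
# 'model/best_model_3.pt'; on an empty directory it returns 'model/best_model_1.pt'.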
def get_data_loaders(data_dir, batch_size):
    """Prepare data loaders for training and validation."""
    data_transforms = {
        # Note: the training pipeline has no Resize/CenterCrop, so training images are
        # expected to already share a fixed size suitable for batching.
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }
    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                  shuffle=True, num_workers=4)
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes
    return dataloaders, dataset_sizes, class_names
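# Expected dataset layout (an assumption implied by datasets.ImageFolder; 'd_2' is the
# default --data-dir):
#   d_2/
#       train/<class_name>/*.jpg
#       val/<class_name>/*.jpg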
def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
                num_epochs, model_save_dir, model_base_name):
    """Train the model and save the best weights."""
    since = time.time()
    os.makedirs(model_save_dir, exist_ok=True)

    best_acc = 0.0
    best_model_path = None  # path of the best checkpoint saved so far

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save a new best checkpoint under an auto-incremented name
                best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')
                torch.save(model.state_dict(), best_model_path)
                print(f"Best model weights saved at {best_model_path}")

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    if best_model_path:
        # Reload only the most recently saved best checkpoint
        model.load_state_dict(torch.load(best_model_path, weights_only=True))
    else:
        print("No best model was saved during training.")
    return model
# Command-line argument parsing
parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
parser.add_argument('--data-dir', type=str, default='d_2', help='Path to the dataset directory.')
parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
def main(args):
    """Main function to train the model based on command-line arguments."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)

    # ResNet18 pretrained on ImageNet, with the classifier head replaced
    model = models.resnet18(weights='IMAGENET1K_V1')
    num_ftrs = model.fc.in_features
    model.fc = nn.Linear(num_ftrs, args.num_classes)
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
                args.epochs, args.model_save_dir, args.model_base_name)


if __name__ == '__main__':
    args = parser.parse_args()
    main(args)
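The best checkpoint saved above can be restored the same way export.py loads it. A minimal sketch, assuming the default two-class head and a hypothetical checkpoint path model/best_model_1.pt:

import torch
from torch import nn
from torchvision import models

# Rebuild the architecture, load the saved state dict, and switch to eval mode.
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, 2)  # two classes, matching --num-classes
state = torch.load("model/best_model_1.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state)
model.eval()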