Code engineering

ZhenyeLi 2024-11-19 14:31:56 +08:00
parent ba15b49cc2
commit f3a9e625d8
7 changed files with 227 additions and 491 deletions

.gitignore
View File

@@ -366,3 +366,4 @@ FodyWeavers.xsd
cmake-build-*
.DS_Store
DPL/dataset/*

View File

@@ -1,12 +1,30 @@
from symbol import pass_stmt
import onnxruntime as ort
import torch
# Model class
class Model:
def __init__(self):
pass
def __init__(self, model_path: str, device: torch.device):
"""
初始化模型加载 ONNX 模型并设置设备CPU GPU
"""
self.device = device
self.session = ort.InferenceSession(model_path)
def load(self, weight_path: str):
pass
def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
"""
使用 ONNX 模型进行推理返回预测结果
"""
# 转换为 ONNX 输入格式
img_numpy = img_tensor.cpu().numpy()
def predict(self, weight_path: str):
pass
# Get the input name and run inference
inputs = {self.session.get_inputs()[0].name: img_numpy}
outputs = self.session.run(None, inputs)
return torch.tensor(outputs[0])
def load(self, model_path: str):
"""
重新加载模型
"""
self.session = ort.InferenceSession(model_path)
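A minimal usage sketch of the Model wrapper above (not part of the commit; it assumes this class lives in models.py, since the inference script below does `from models import *`, and it reuses the 1x3x224x224 input shape and the default ONNX path from that script):

import torch
from models import Model  # assumed module name for this file

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Model(model_path='model/model_1.onnx', device=device)  # default weight path used by the inference script
dummy = torch.randn(1, 3, 224, 224)  # same shape the preprocessing pipeline produces
logits = model.predict(dummy)        # torch.Tensor of raw class scores
print('predicted class index:', int(logits.argmax(dim=1)[0]))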

View File

@@ -1,67 +1,126 @@
import torch
from PIL import Image
from torch import nn
import onnx
import onnxruntime as ort  # for ONNX inference
import argparse
from models import *
from torchvision import datasets, models, transforms
from PIL import Image
import os
import matplotlib.pyplot as plt
from train import device, class_names, imshow
# Load the trained ONNX model
def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
# Load the model with ONNX Runtime
session = ort.InferenceSession(model_path)
return session
from PIL.ImageDraw import Draw
from PIL import ImageDraw, ImageFont
from torchvision import transforms
import time  # timing
# Prediction function
def visualize_model_predictions(onnx_session: str, img_path: str):
def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
"""
Predict an image, visualize the result, and save the annotated image to the given directory.
"""
start_time = time.time()  # start time
img = Image.open(img_path)
img = img.convert('RGB')  # convert to RGB
# Image preprocessing
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img = data_transforms(img)
img = img.unsqueeze(0)
img = img.to(device)
img_transformed = data_transforms(img)
img_transformed = img_transformed.unsqueeze(0)
img_transformed = img_transformed.to(model.device)
# Convert the input to an ONNX-compatible format (a numpy array)
img = img.cpu().numpy()
# Run inference with ONNX Runtime
inputs = {onnx_session.get_inputs()[0].name: img}
outputs = onnx_session.run(None, inputs)
preds = outputs[0]
# Run prediction with the model
preds = model.predict(img_transformed)
# Get the predicted class
_, predicted_class = torch.max(torch.tensor(preds), 1)
_, predicted_class = torch.max(preds, 1)
# Visualize the result
ax = plt.subplot(2, 2, 1)
ax.axis('off')
ax.set_title(f'Predicted: {class_names[predicted_class[0]]}')
imshow(img[0])
# Add the prediction label to the image
predicted_label = class_names[predicted_class[0]]
# Draw the text on the image
img_with_text = img.copy()
draw = ImageDraw.Draw(img_with_text)
font = ImageFont.load_default()  # change the font if needed
text = f'Predicted: {predicted_label}'
text_position = (10, 10)  # text position, adjust as needed
draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text
# Show the annotated image
img_with_text.show()
# Save the annotated image
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Build the output file name and save path
img_name = os.path.basename(img_path)
save_path = os.path.join(save_dir, f"pred_{img_name}")
# Save the image
img_with_text.save(save_path)
print(f"Prediction saved at: {save_path}")
end_time = time.time()  # end time
processing_time = (end_time - start_time) * 1000  # convert to milliseconds
print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image processing time (ms)
# Run prediction with the trained ONNX model
if __name__ == '__main__':
# TODO:
# Write a model class: model = Model()
# 1. Standardize the model-loading interface: model.load(weight_path: str)
# 2. Standardize the prediction interface, called as output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict)
def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
"""
处理文件夹中的所有图像并预测每张图像的类别
"""
start_time = time.time() # 记录总开始时间
# 加载 ONNX 模型
model_path = 'model/best_model_11.14.19.30.onnx'
onnx_session = load_onnx_model(model_path)
# 获取文件夹中所有图片文件
image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
# 图像路径
img_path = 'd_2/train/dgd/transformed_1.jpg' # 更改为你的图像路径
visualize_model_predictions(onnx_session, img_path)
# Iterate over all images in the folder
for image_file in image_files:
img_path = os.path.join(folder_path, image_file)
print(f"Processing image: {img_path}")
visualize_model_predictions(model, img_path, save_dir, class_names)
end_time = time.time()  # record the overall end time
total_time = (end_time - start_time) * 1000  # convert to milliseconds
print(f"Total time taken to process all images: {total_time:.2f} ms")  # total processing time (ms)
def main():
# Command-line argument parsing
parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
# Set defaults; users can override them on the command line
parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='dataset/val/dgd',
help='Path to image or folder for inference')
parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
args = parser.parse_args()
# Select the device
device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
if args.gpu and not torch.cuda.is_available():
print("GPU not available, switching to CPU.")
# Load the model
model = Model(model_path=args.weights, device=device)
# Stand-in for loading class_names
# assume the model has 2 classes
class_names = ['class_0', 'class_1']
# Check whether the input path is a folder
if os.path.isdir(args.img_path):
# If it is a folder, process every image in it
process_image_folder(model, args.img_path, args.save_dir, class_names)
else:
# If it is a single image file, run prediction on it
visualize_model_predictions(model, args.img_path, args.save_dir, class_names)
if __name__ == "__main__":
main()
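The script above already prints per-image and total processing times. As a rough follow-up to the "compare the prediction time of different models" TODO in the training script, here is a minimal hedged sketch that times several ONNX weight files on one dummy input (the second weight path is a placeholder, and the Model wrapper is assumed to be importable from models.py):

import time
import torch
from models import Model  # the ONNX wrapper introduced in this commit (assumed module name)

def average_latency_ms(model: Model, img_tensor: torch.Tensor, runs: int = 20) -> float:
    """Run `runs` forward passes and return the mean latency in milliseconds."""
    model.predict(img_tensor)  # one warm-up pass before timing
    start = time.time()
    for _ in range(runs):
        model.predict(img_tensor)
    return (time.time() - start) * 1000 / runs

if __name__ == '__main__':
    device = torch.device('cpu')
    dummy = torch.randn(1, 3, 224, 224)  # same input shape as the preprocessing pipeline
    for weights in ['model/model_1.onnx', 'model/model_2.onnx']:  # placeholder weight files
        print(weights, f"{average_latency_ms(Model(weights, device), dummy):.2f} ms")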

View File

@@ -1,154 +0,0 @@
import argparse
import torch
import onnxruntime as ort
from PIL import Image
import os
import matplotlib.pyplot as plt
from PIL.ImageDraw import Draw
from PIL import ImageDraw, ImageFont
from torchvision import transforms
import time  # timing
# Model class
class Model:
def __init__(self, model_path: str, device: torch.device):
"""
Initialize the model: load the ONNX model and set the device (CPU or GPU).
"""
self.device = device
self.session = ort.InferenceSession(model_path)
def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
"""
Run inference with the ONNX model and return the prediction.
"""
# Convert to the ONNX input format (a numpy array)
img_numpy = img_tensor.cpu().numpy()
# Get the input name and run inference
inputs = {self.session.get_inputs()[0].name: img_numpy}
outputs = self.session.run(None, inputs)
return torch.tensor(outputs[0])
def load(self, model_path: str):
"""
重新加载模型
"""
self.session = ort.InferenceSession(model_path)
# 预测函数
def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
"""
预测图像并可视化结果保存预测后的图片到指定文件夹
"""
start_time = time.time() # 开始时间
img = Image.open(img_path)
img = img.convert('RGB') # 转换为 RGB 模式
# 图像预处理
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transformed = data_transforms(img)
img_transformed = img_transformed.unsqueeze(0)
img_transformed = img_transformed.to(model.device)
# Run prediction with the model
preds = model.predict(img_transformed)
# Get the predicted class
_, predicted_class = torch.max(preds, 1)
# Add the prediction label to the image
predicted_label = class_names[predicted_class[0]]
# Draw the text on the image
img_with_text = img.copy()
draw = ImageDraw.Draw(img_with_text)
font = ImageFont.load_default()  # change the font if needed
text = f'Predicted: {predicted_label}'
text_position = (10, 10)  # text position, adjust as needed
draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text
# Show the annotated image
img_with_text.show()
# Save the annotated image
if not os.path.exists(save_dir):
os.makedirs(save_dir)
# Build the output file name and save path
img_name = os.path.basename(img_path)
save_path = os.path.join(save_dir, f"pred_{img_name}")
# Save the image
img_with_text.save(save_path)
print(f"Prediction saved at: {save_path}")
end_time = time.time()  # end time
processing_time = (end_time - start_time) * 1000  # convert to milliseconds
print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image processing time (ms)
def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
"""
处理文件夹中的所有图像并预测每张图像的类别
"""
start_time = time.time() # 记录总开始时间
# 获取文件夹中所有图片文件
image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
# Iterate over all images in the folder
for image_file in image_files:
img_path = os.path.join(folder_path, image_file)
print(f"Processing image: {img_path}")
visualize_model_predictions(model, img_path, save_dir, class_names)
end_time = time.time()  # record the overall end time
total_time = (end_time - start_time) * 1000  # convert to milliseconds
print(f"Total time taken to process all images: {total_time:.2f} ms")  # total processing time (ms)
def main():
# Command-line argument parsing
parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
# Set defaults; users can override them on the command line
parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='d_2/val/dgd',
help='Path to image or folder for inference')
parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
args = parser.parse_args()
# Select the device
device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
if args.gpu and not torch.cuda.is_available():
print("GPU not available, switching to CPU.")
# Load the model
model = Model(model_path=args.weights, device=device)
# Stand-in for loading class_names
# assume the model has 2 classes
class_names = ['class_0', 'class_1']
# Check whether the input path is a folder
if os.path.isdir(args.img_path):
# If it is a folder, process every image in it
process_image_folder(model, args.img_path, args.save_dir, class_names)
else:
# If it is a single image file, run prediction on it
visualize_model_predictions(model, args.img_path, args.save_dir, class_names)
if __name__ == "__main__":
main()

View File

@@ -1,74 +1,76 @@
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
# Enable cudnn auto-tuning
cudnn.benchmark = True
plt.ion()  # interactive mode
# Enable cudnn auto-tuning
torch.backends.cudnn.benchmark = True
plt.ion()
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
def get_next_model_name(save_dir, base_name, ext):
"""
Scan the save directory and generate a model file name with an auto-incremented index.
data_dir = 'd_2'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
Args:
save_dir (str): directory where models are saved
base_name (str): base name for the model files
ext (str): model file extension, e.g. ".pt" or ".onnx"
device = torch.device("cuda:0")
#
def imshow(inp, title=None):
"""Display image for Tensor."""
if isinstance(inp, np.ndarray):
inp = inp.transpose((1, 2, 0))
# inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
Returns:
str: new model file name with an auto-incremented index
"""
os.makedirs(save_dir, exist_ok=True)
existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
existing_numbers = []
for f in existing_files:
try:
num = int(f[len(base_name):].split('.')[0].strip('_'))
existing_numbers.append(num)
except ValueError:
continue
next_number = max(existing_numbers, default=0) + 1
return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_save_dir='model', save_onnx=False):
def get_data_loaders(data_dir, batch_size):
"""Prepare data loaders for training and validation."""
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
return dataloaders, dataset_sizes, class_names
def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
num_epochs, model_save_dir, model_base_name):
"""Train the model and save the best weights."""
since = time.time()
# Create the model save directory if it does not exist
os.makedirs(model_save_dir, exist_ok=True)
# Set the model save paths
best_model_params_path = os.path.join(model_save_dir, 'best_model_params_11.14.19.30.pt')
best_model_onnx_path = os.path.join(model_save_dir, 'best_model_11.14.19.30.onnx')
# Save the initial model weights
torch.save(model.state_dict(), best_model_params_path)
best_acc = 0.0
best_model_path = None  # path of the best model saved so far
for epoch in range(num_epochs):
print(f'Epoch {epoch}/{num_epochs - 1}')
@@ -76,16 +78,15 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
model.train()
else:
model.eval() # Set model to evaluate mode
model.eval()
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
@@ -111,78 +112,56 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
# Save the best model weights
torch.save(model.state_dict(), best_model_params_path)
# Optionally export the model in ONNX format here
if save_onnx:
# Export the model to ONNX format
model.eval()
dummy_input = torch.randn(1, 3, 224, 224).to(device)  # dummy input for export (same size as the training input)
torch.onnx.export(model, dummy_input, best_model_onnx_path,
export_params=True, opset_version=11,
do_constant_folding=True, input_names=['input'], output_names=['output'])
print(f"ONNX model saved at {best_model_onnx_path}")
# Save the new best model
best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')
print()
torch.save(model.state_dict(), best_model_path)
print(f"Best model weights saved at {best_model_path}")
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best val Acc: {best_acc:4f}')
print(f'Best val Acc: {best_acc:.4f}')
model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
if best_model_path:
# Load only the last saved best model
model.load_state_dict(torch.load(best_model_path, weights_only=True))
else:
print("No best model was saved during training.")
return model
# def visualize_model(model, num_images=6):
# was_training = model.training
# model.eval()
# images_so_far = 0
# fig = plt.figure()
#
# with torch.no_grad():
# for i, (inputs, labels) in enumerate(dataloaders['train']):
# inputs = inputs.to(device)
# labels = labels.to(device)
#
# outputs = model(inputs)
# _, preds = torch.max(outputs, 1)
#
# for j in range(inputs.size()[0]):
# images_so_far += 1
# ax = plt.subplot(num_images // 2, 2, images_so_far)
# ax.axis('off')
# ax.set_title(f'predicted: {class_names[preds[j]]}')
# plt.imshow(inputs.cpu().data[j])
#
# if images_so_far == num_images:
# model.train(mode=was_training)
# return
# model.train(mode=was_training)
# Argument parsing
parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
parser.add_argument('--data-dir', type=str, default='dataset', help='Path to the dataset directory.')
parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
def main(args):
"""Main function to train the model based on command-line arguments."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if __name__ == '__main__':
# TODO:
# 1. Support running with different parameters, like YOLO
# 2. Support comparing the prediction time of different models and switching models via configuration (see the timing sketch after the inference script above)
dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
args.epochs, args.model_save_dir, args.model_base_name)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
# Directory in which to save models
model_save_dir = 'model'  # change to your preferred save path
# Train the model and also save it in ONNX format
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
model_save_dir=model_save_dir, save_onnx=True)
# visualize_model(model_ft)
# TODO: 3. Add dgd_class/export.py to convert trained .pt files to ONNX (a sketch follows below)
if __name__ == '__main__':
args = parser.parse_args()
main(args)
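TODO 3 above asks for a dgd_class/export.py that converts a trained .pt checkpoint into ONNX. A minimal hedged sketch of what such a script could look like (the default paths, the ResNet18 backbone, and the num-classes head mirror the training defaults in this file; nothing here is part of the commit itself):

import argparse
import torch
import torch.nn as nn
from torchvision import models

def export_onnx(weights: str, output: str, num_classes: int) -> None:
    """Rebuild the ResNet18 classifier head, load the .pt weights, and export to ONNX."""
    model = models.resnet18(weights=None)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model.load_state_dict(torch.load(weights, map_location='cpu'))
    model.eval()
    dummy = torch.randn(1, 3, 224, 224)  # matches the 224x224 crops used at train/inference time
    torch.onnx.export(model, dummy, output,
                      export_params=True, opset_version=11,
                      do_constant_folding=True,
                      input_names=['input'], output_names=['output'])
    print(f"ONNX model saved at {output}")

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Export a trained .pt checkpoint to ONNX.")
    parser.add_argument('--weights', type=str, default='model/best_model_1.pt', help='Path to the .pt checkpoint.')
    parser.add_argument('--output', type=str, default='model/model_1.onnx', help='Path of the ONNX file to write.')
    parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
    args = parser.parse_args()
    export_onnx(args.weights, args.output, args.num_classes)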

View File

@@ -1,167 +0,0 @@
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
# Enable cudnn auto-tuning
torch.backends.cudnn.benchmark = True
plt.ion()
def get_next_model_name(save_dir, base_name, ext):
"""
检索保存目录生成递增编号的模型文件名
Args:
save_dir (str): 模型保存目录
base_name (str): 模型基础名称
ext (str): 模型文件扩展名例如 ".pt" ".onnx"
Returns:
str: 自动递增编号的新模型文件名
"""
os.makedirs(save_dir, exist_ok=True)
existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
existing_numbers = []
for f in existing_files:
try:
num = int(f[len(base_name):].split('.')[0].strip('_'))
existing_numbers.append(num)
except ValueError:
continue
next_number = max(existing_numbers, default=0) + 1
return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")
def get_data_loaders(data_dir, batch_size):
"""Prepare data loaders for training and validation."""
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
return dataloaders, dataset_sizes, class_names
def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
num_epochs, model_save_dir, model_base_name):
"""Train the model and save the best weights."""
since = time.time()
os.makedirs(model_save_dir, exist_ok=True)
best_acc = 0.0
best_model_path = None  # path of the best model saved so far
for epoch in range(num_epochs):
print(f'Epoch {epoch}/{num_epochs - 1}')
print('-' * 10)
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders[phase]:
inputs, labels = inputs.to(device), labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
if phase == 'train':
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
# Save the new best model
best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')
torch.save(model.state_dict(), best_model_path)
print(f"Best model weights saved at {best_model_path}")
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
print(f'Best val Acc: {best_acc:.4f}')
if best_model_path:
# Load only the last saved best model
model.load_state_dict(torch.load(best_model_path, weights_only=True))
else:
print("No best model was saved during training.")
return model
# Argument parsing
parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
parser.add_argument('--data-dir', type=str, default='d_2', help='Path to the dataset directory.')
parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
def main(args):
"""Main function to train the model based on command-line arguments."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
args.epochs, args.model_save_dir, args.model_base_name)
if __name__ == '__main__':
args = parser.parse_args()
main(args)