diff --git a/.gitignore b/.gitignore
index 8700afa..c871dbc 100644
--- a/.gitignore
+++ b/.gitignore
@@ -366,3 +366,4 @@
 FodyWeavers.xsd
 cmake-build-*
 .DS_Store
+DPL/dataset/*
\ No newline at end of file
diff --git a/DPL/dgd_class/export_11.19.py b/DPL/dgd_class/export.py
similarity index 100%
rename from DPL/dgd_class/export_11.19.py
rename to DPL/dgd_class/export.py
diff --git a/DPL/dgd_class/models.py b/DPL/dgd_class/models.py
index a2882e6..72c14e4 100644
--- a/DPL/dgd_class/models.py
+++ b/DPL/dgd_class/models.py
@@ -1,12 +1,30 @@
-from symbol import pass_stmt
-
+import onnxruntime as ort
+import torch
+# Model class
 class Model:
-    def __init__(self):
-        pass
+    def __init__(self, model_path: str, device: torch.device):
+        """
+        Initialize the model: load the ONNX model and set the device (CPU or GPU).
+        """
+        self.device = device
+        self.session = ort.InferenceSession(model_path)
 
-    def load(self, weight_path: str):
-        pass
+    def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
+        """
+        Run inference with the ONNX model and return the prediction.
+        """
+        # Convert to the ONNX input format
+        img_numpy = img_tensor.cpu().numpy()
 
-    def predict(self, weight_path: str):
-        pass
\ No newline at end of file
+        # Get the input name and run inference
+        inputs = {self.session.get_inputs()[0].name: img_numpy}
+        outputs = self.session.run(None, inputs)
+
+        return torch.tensor(outputs[0])
+
+    def load(self, model_path: str):
+        """
+        Reload the model.
+        """
+        self.session = ort.InferenceSession(model_path)
\ No newline at end of file
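models.py now wraps ONNX Runtime behind a small class with a torch-tensor interface. A quick smoke test of the wrapper, run from DPL/dgd_class, might look like the sketch below; the ONNX path is only the default that predict.py uses, and the 1x3x224x224 input shape follows its CenterCrop(224) preprocessing, so both are assumptions.

import torch
from models import Model  # the class added above

# Hypothetical smoke test: model path and input shape are assumptions, not part of the patch
model = Model(model_path='model/model_1.onnx', device=torch.device('cpu'))
dummy = torch.randn(1, 3, 224, 224)      # one preprocessed RGB image
logits = model.predict(dummy)            # torch.Tensor with the ONNX outputs
print(torch.max(logits, 1)[1])           # predicted class index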
diff --git a/DPL/dgd_class/predict.py b/DPL/dgd_class/predict.py
index 97ea003..bb9da80 100644
--- a/DPL/dgd_class/predict.py
+++ b/DPL/dgd_class/predict.py
@@ -1,67 +1,126 @@
-import torch
-from PIL import Image
-from torch import nn
-import onnx
-import onnxruntime as ort  # for ONNX inference
+import argparse
+from models import *
 
-from torchvision import datasets, models, transforms
+from PIL import Image
 import os
 import matplotlib.pyplot as plt
-
-from train import device, class_names, imshow
-
-
-# Load the trained ONNX model
-def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
-    # Load the model with ONNX Runtime
-    session = ort.InferenceSession(model_path)
-    return session
+from PIL.ImageDraw import Draw
+from PIL import ImageDraw, ImageFont
+from torchvision import transforms
+import time  # for timing
 
 
 # Prediction function
-def visualize_model_predictions(onnx_session: str, img_path: str):
+def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
+    """
+    Predict an image, visualize the result, and save the annotated image to the given folder.
+    """
+    start_time = time.time()  # start time
+
     img = Image.open(img_path)
     img = img.convert('RGB')  # convert to RGB mode
 
+    # Image preprocessing
     data_transforms = transforms.Compose([
         transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
     ])
-    img = data_transforms(img)
-    img = img.unsqueeze(0)
-    img = img.to(device)
+    img_transformed = data_transforms(img)
+    img_transformed = img_transformed.unsqueeze(0)
+    img_transformed = img_transformed.to(model.device)
 
-    # Convert the input to an ONNX-compatible format (numpy array)
-    img = img.cpu().numpy()
-
-    # Run inference with ONNX Runtime
-    inputs = {onnx_session.get_inputs()[0].name: img}
-    outputs = onnx_session.run(None, inputs)
-    preds = outputs[0]
+    # Run prediction with the model
+    preds = model.predict(img_transformed)
 
     # Get the predicted class
-    _, predicted_class = torch.max(torch.tensor(preds), 1)
+    _, predicted_class = torch.max(preds, 1)
 
-    # Visualize the result
-    ax = plt.subplot(2, 2, 1)
-    ax.axis('off')
-    ax.set_title(f'Predicted: {class_names[predicted_class[0]]}')
-    imshow(img[0])
+    # Add the prediction text to the image
+    predicted_label = class_names[predicted_class[0]]
+
+    # Draw the text on the image
+    img_with_text = img.copy()
+    draw = ImageDraw.Draw(img_with_text)
+    font = ImageFont.load_default()  # change the font if needed
+    text = f'Predicted: {predicted_label}'
+    text_position = (10, 10)  # text position; adjust as needed
+    draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text
+
+    # Show the annotated image
+    img_with_text.show()
+
+    # Save the annotated image
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+
+    # Build the file name and save path
+    img_name = os.path.basename(img_path)
+    save_path = os.path.join(save_dir, f"pred_{img_name}")
+
+    # Save the image
+    img_with_text.save(save_path)
+    print(f"Prediction saved at: {save_path}")
+
+    end_time = time.time()  # end time
+    processing_time = (end_time - start_time) * 1000  # convert to milliseconds
+    print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image processing time in ms
 
 
-# Run prediction with the trained ONNX model
-if __name__ == '__main__':
-    # TODO:
-    # Write a Model class: model = Model()
-    # 1. Standardize the model loading interface: model.load(weight_path: str)
-    # 2. Standardize the prediction call as output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict)
+def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
+    """
+    Process every image in a folder and predict the class of each one.
+    """
+    start_time = time.time()  # overall start time
 
-    # Load the ONNX model
-    model_path = 'model/best_model_11.14.19.30.onnx'
-    onnx_session = load_onnx_model(model_path)
+    # Collect all image files in the folder
+    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
-    # Image path
-    img_path = 'd_2/train/dgd/transformed_1.jpg'  # change to your image path
-    visualize_model_predictions(onnx_session, img_path)
+    # Iterate over all images in the folder
+    for image_file in image_files:
+        img_path = os.path.join(folder_path, image_file)
+        print(f"Processing image: {img_path}")
+        visualize_model_predictions(model, img_path, save_dir, class_names)
+
+    end_time = time.time()  # overall end time
+    total_time = (end_time - start_time) * 1000  # convert to milliseconds
+    print(f"Total time taken to process all images: {total_time:.2f} ms")  # total processing time in ms
+
+
+def main():
+    # Parse command-line arguments
+    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
+
+    # Defaults can be overridden on the command line
+    parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
+    parser.add_argument('--img-path', type=str, default='dataset/val/dgd',
+                        help='Path to image or folder for inference')
+    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
+    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
+
+    args = parser.parse_args()
+
+    # Select the device
+    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
+    if args.gpu and not torch.cuda.is_available():
+        print("GPU not available, switching to CPU.")
+
+    # Load the model
+    model = Model(model_path=args.weights, device=device)
+
+    # Stand-in for loading class_names
+    # Assume the model has two classes
+    class_names = ['class_0', 'class_1']
+
+    # Check whether the input path is a folder
+    if os.path.isdir(args.img_path):
+        # If it is a folder, process every image in it
+        process_image_folder(model, args.img_path, args.save_dir, class_names)
+    else:
+        # If it is a single image file, run prediction on it
+        visualize_model_predictions(model, args.img_path, args.save_dir, class_names)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/DPL/dgd_class/predict_11.19.py b/DPL/dgd_class/predict_11.19.py
deleted file mode 100644
index 312e700..0000000
--- a/DPL/dgd_class/predict_11.19.py
+++ /dev/null
@@ -1,154 +0,0 @@
-import argparse
-import torch
-import onnxruntime as ort
-from PIL import Image
-import os
-import matplotlib.pyplot as plt
-from PIL.ImageDraw import Draw
-from PIL import ImageDraw, ImageFont
-from torchvision import transforms
-import time  # for timing
-
-# Model class
-class Model:
-    def __init__(self, model_path: str, device: torch.device):
-        """
-        Initialize the model: load the ONNX model and set the device (CPU or GPU).
-        """
-        self.device = device
-        self.session = ort.InferenceSession(model_path)
-
-    def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
-        """
-        Run inference with the ONNX model and return the prediction.
-        """
-        # Convert to the ONNX input format
-        img_numpy = img_tensor.cpu().numpy()
-
-        # Get the input name and run inference
-        inputs = {self.session.get_inputs()[0].name: img_numpy}
-        outputs = self.session.run(None, inputs)
-
-        return torch.tensor(outputs[0])
-
-    def load(self, model_path: str):
-        """
-        Reload the model.
-        """
-        self.session = ort.InferenceSession(model_path)
-
-
-# Prediction function
-def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
-    """
-    Predict an image, visualize the result, and save the annotated image to the given folder.
-    """
-    start_time = time.time()  # start time
-
-    img = Image.open(img_path)
-    img = img.convert('RGB')  # convert to RGB mode
-
-    # Image preprocessing
-    data_transforms = transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ])
-    img_transformed = data_transforms(img)
-    img_transformed = img_transformed.unsqueeze(0)
-    img_transformed = img_transformed.to(model.device)
-
-    # Run prediction with the model
-    preds = model.predict(img_transformed)
-
-    # Get the predicted class
-    _, predicted_class = torch.max(preds, 1)
-
-    # Add the prediction text to the image
-    predicted_label = class_names[predicted_class[0]]
-
-    # Draw the text on the image
-    img_with_text = img.copy()
-    draw = ImageDraw.Draw(img_with_text)
-    font = ImageFont.load_default()  # change the font if needed
-    text = f'Predicted: {predicted_label}'
-    text_position = (10, 10)  # text position; adjust as needed
-    draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text
-
-    # Show the annotated image
-    img_with_text.show()
-
-    # Save the annotated image
-    if not os.path.exists(save_dir):
-        os.makedirs(save_dir)
-
-    # Build the file name and save path
-    img_name = os.path.basename(img_path)
-    save_path = os.path.join(save_dir, f"pred_{img_name}")
-
-    # Save the image
-    img_with_text.save(save_path)
-    print(f"Prediction saved at: {save_path}")
-
-    end_time = time.time()  # end time
-    processing_time = (end_time - start_time) * 1000  # convert to milliseconds
-    print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image processing time in ms
-
-
-def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
-    """
-    Process every image in a folder and predict the class of each one.
-    """
-    start_time = time.time()  # overall start time
-
-    # Collect all image files in the folder
-    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
-
-    # Iterate over all images in the folder
-    for image_file in image_files:
-        img_path = os.path.join(folder_path, image_file)
-        print(f"Processing image: {img_path}")
-        visualize_model_predictions(model, img_path, save_dir, class_names)
-
-    end_time = time.time()  # overall end time
-    total_time = (end_time - start_time) * 1000  # convert to milliseconds
-    print(f"Total time taken to process all images: {total_time:.2f} ms")  # total processing time in ms
-
-
-def main():
-    # Parse command-line arguments
-    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
-
-    # Defaults can be overridden on the command line
-    parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
-    parser.add_argument('--img-path', type=str, default='d_2/val/dgd',
-                        help='Path to image or folder for inference')
-    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
-    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
-
-    args = parser.parse_args()
-
-    # Select the device
-    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
-    if args.gpu and not torch.cuda.is_available():
-        print("GPU not available, switching to CPU.")
-
-    # Load the model
-    model = Model(model_path=args.weights, device=device)
-
-    # Stand-in for loading class_names
-    # Assume the model has two classes
-    class_names = ['class_0', 'class_1']
-
-    # Check whether the input path is a folder
-    if os.path.isdir(args.img_path):
-        # If it is a folder, process every image in it
-        process_image_folder(model, args.img_path, args.save_dir, class_names)
-    else:
-        # If it is a single image file, run prediction on it
-        visualize_model_predictions(model, args.img_path, args.save_dir, class_names)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/DPL/dgd_class/train.py b/DPL/dgd_class/train.py
index e8ad3e0..4f0ce9c 100644
--- a/DPL/dgd_class/train.py
+++ b/DPL/dgd_class/train.py
@@ -1,74 +1,76 @@
+import os
+import argparse
+import time
 import torch
 import torch.nn as nn
 import torch.optim as optim
 from torch.optim import lr_scheduler
-import torch.backends.cudnn as cudnn
-import numpy as np
-import torchvision
 from torchvision import datasets, models, transforms
 import matplotlib.pyplot as plt
-import time
-import os
-from PIL import Image
 
-# Initialize cudnn optimization
-cudnn.benchmark = True
-plt.ion()  # interactive mode
+# Set up cudnn optimization
+torch.backends.cudnn.benchmark = True
+plt.ion()
 
-data_transforms = {
-    'train': transforms.Compose([
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ]),
-    'val': transforms.Compose([
-        transforms.Resize(256),
-        transforms.CenterCrop(224),
-        transforms.ToTensor(),
-        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-    ]),
-}
+def get_next_model_name(save_dir, base_name, ext):
+    """
+    Scan the save directory and generate a model file name with an incremented index.
 
-data_dir = 'd_2'
-image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
-                                          data_transforms[x])
-                  for x in ['train', 'val']}
-dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
-                                              shuffle=True, num_workers=4)
-               for x in ['train', 'val']}
-dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
-class_names = image_datasets['train'].classes
+    Args:
+        save_dir (str): directory where models are saved
+        base_name (str): base name of the model file
+        ext (str): model file extension (e.g. ".pt" or ".onnx")
 
-device = torch.device("cuda:0")
-
-#
-def imshow(inp, title=None):
-    """Display image for Tensor."""
-    if isinstance(inp, np.ndarray):
-        inp = inp.transpose((1, 2, 0))
-    # inp = inp.numpy().transpose((1, 2, 0))
-    mean = np.array([0.485, 0.456, 0.406])
-    std = np.array([0.229, 0.224, 0.225])
-    inp = std * inp + mean
-    inp = np.clip(inp, 0, 1)
-    plt.imshow(inp)
-    if title is not None:
-        plt.title(title)
-    plt.pause(0.001)  # pause a bit so that plots are updated
+    Returns:
+        str: new model file name with an auto-incremented index
+    """
+    os.makedirs(save_dir, exist_ok=True)
+    existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
+    existing_numbers = []
+    for f in existing_files:
+        try:
+            num = int(f[len(base_name):].split('.')[0].strip('_'))
+            existing_numbers.append(num)
+        except ValueError:
+            continue
+    next_number = max(existing_numbers, default=0) + 1
+    return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")
 
-def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_save_dir='model', save_onnx=False):
+def get_data_loaders(data_dir, batch_size):
+    """Prepare data loaders for training and validation."""
+    data_transforms = {
+        'train': transforms.Compose([
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+        'val': transforms.Compose([
+            transforms.Resize(256),
+            transforms.CenterCrop(224),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ]),
+    }
+
+    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
+                      for x in ['train', 'val']}
+    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
+                                                  shuffle=True, num_workers=4)
+                   for x in ['train', 'val']}
+    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
+    class_names = image_datasets['train'].classes
+
+    return dataloaders, dataset_sizes, class_names
+
+
+def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
+                num_epochs, model_save_dir, model_base_name):
+    """Train the model and save the best weights."""
     since = time.time()
-
-    # Create the model save directory if it does not exist
     os.makedirs(model_save_dir, exist_ok=True)
 
-    # Set the model save paths
-    best_model_params_path = os.path.join(model_save_dir, 'best_model_params_11.14.19.30.pt')
-    best_model_onnx_path = os.path.join(model_save_dir, 'best_model_11.14.19.30.onnx')
-
-    # Save the initial model
-    torch.save(model.state_dict(), best_model_params_path)
     best_acc = 0.0
+    best_model_path = None  # path of the best model saved so far
 
     for epoch in range(num_epochs):
         print(f'Epoch {epoch}/{num_epochs - 1}')
@@ -76,16 +78,15 @@ def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_sav
 
         for phase in ['train', 'val']:
             if phase == 'train':
-                model.train()  # Set model to training mode
+                model.train()
             else:
-                model.eval()  # Set model to evaluate mode
+                model.eval()
 
             running_loss = 0.0
             running_corrects = 0
 
             for inputs, labels in dataloaders[phase]:
-                inputs = inputs.to(device)
-                labels = labels.to(device)
+                inputs, labels = inputs.to(device), labels.to(device)
 
                 optimizer.zero_grad()
@@ -111,78 +112,56 @@
 
             if phase == 'val' and epoch_acc > best_acc:
                 best_acc = epoch_acc
-                # Save the best model
-                torch.save(model.state_dict(), best_model_params_path)
-                # If an ONNX copy is needed, export it here
-                if save_onnx:
-                    # Export the model to ONNX
-                    model.eval()
-                    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # dummy input for export (same size as the training input)
-                    torch.onnx.export(model, dummy_input, best_model_onnx_path,
-                                      export_params=True, opset_version=11,
-                                      do_constant_folding=True, input_names=['input'], output_names=['output'])
-                    print(f"ONNX model saved at {best_model_onnx_path}")
+                # Save the new best model
+                best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')
 
-        print()
+                torch.save(model.state_dict(), best_model_path)
+                print(f"Best model weights saved at {best_model_path}")
 
     time_elapsed = time.time() - since
     print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
-    print(f'Best val Acc: {best_acc:4f}')
+    print(f'Best val Acc: {best_acc:.4f}')
 
-    model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
+    if best_model_path:
+        # Load only the last saved best model
+        model.load_state_dict(torch.load(best_model_path, weights_only=True))
+    else:
+        print("No best model was saved during training.")
     return model
 
 
-# def visualize_model(model, num_images=6):
-#     was_training = model.training
-#     model.eval()
-#     images_so_far = 0
-#     fig = plt.figure()
-#
-#     with torch.no_grad():
-#         for i, (inputs, labels) in enumerate(dataloaders['train']):
-#             inputs = inputs.to(device)
-#             labels = labels.to(device)
-#
-#             outputs = model(inputs)
-#             _, preds = torch.max(outputs, 1)
-#
-#             for j in range(inputs.size()[0]):
-#                 images_so_far += 1
-#                 ax = plt.subplot(num_images // 2, 2, images_so_far)
-#                 ax.axis('off')
-#                 ax.set_title(f'predicted: {class_names[preds[j]]}')
-#                 plt.imshow(inputs.cpu().data[j])
-#
-#                 if images_so_far == num_images:
-#                     model.train(mode=was_training)
-#                     return
-#         model.train(mode=was_training)
+# Argument parsing
+parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
+parser.add_argument('--data-dir', type=str, default='dataset', help='Path to the dataset directory.')
+parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
+parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
+parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
+parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
+parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
+parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
+parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
+parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
+parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
 
+def main(args):
+    """Main function to train the model based on command-line arguments."""
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-if __name__ == '__main__':
-    # TODO:
-    # 1. Run with different parameters, like YOLO
-    # 2. Compare the prediction time of different models and switch models via configuration
+    dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
 
-    model_ft = models.resnet18(weights='IMAGENET1K_V1')
-    num_ftrs = model_ft.fc.in_features
-    model_ft.fc = nn.Linear(num_ftrs, 2)
-
-    model_ft = model_ft.to(device)
+    model = models.resnet18(weights='IMAGENET1K_V1')
+    num_ftrs = model.fc.in_features
+    model.fc = nn.Linear(num_ftrs, args.num_classes)
+    model = model.to(device)
 
     criterion = nn.CrossEntropyLoss()
+    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
+    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
 
-    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
+    train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
+                args.epochs, args.model_save_dir, args.model_base_name)
 
-    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
-
-    # Directory to save models
-    model_save_dir = 'model'  # change to whatever path you prefer
-
-    # Train the model and save an ONNX copy
-    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
-                           model_save_dir=model_save_dir, save_onnx=True)
-    # visualize_model(model_ft)
-    # TODO: 3. Add dgd_class/export.py to convert the trained .pt file to ONNX
+if __name__ == '__main__':
+    args = parser.parse_args()
+    main(args)
diff --git a/DPL/dgd_class/train_11.19.py b/DPL/dgd_class/train_11.19.py
deleted file mode 100644
index a38fd23..0000000
--- a/DPL/dgd_class/train_11.19.py
+++ /dev/null
@@ -1,167 +0,0 @@
-import os
-import argparse
-import time
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.optim import lr_scheduler
-from torchvision import datasets, models, transforms
-import matplotlib.pyplot as plt
-
-# Set up cudnn optimization
-torch.backends.cudnn.benchmark = True
-plt.ion()
-
-def get_next_model_name(save_dir, base_name, ext):
-    """
-    Scan the save directory and generate a model file name with an incremented index.
-
-    Args:
-        save_dir (str): directory where models are saved
-        base_name (str): base name of the model file
-        ext (str): model file extension (e.g. ".pt" or ".onnx")
-
-    Returns:
-        str: new model file name with an auto-incremented index
-    """
-    os.makedirs(save_dir, exist_ok=True)
-    existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
-    existing_numbers = []
-    for f in existing_files:
-        try:
-            num = int(f[len(base_name):].split('.')[0].strip('_'))
-            existing_numbers.append(num)
-        except ValueError:
-            continue
-    next_number = max(existing_numbers, default=0) + 1
-    return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")
-
-
-def get_data_loaders(data_dir, batch_size):
-    """Prepare data loaders for training and validation."""
-    data_transforms = {
-        'train': transforms.Compose([
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ]),
-        'val': transforms.Compose([
-            transforms.Resize(256),
-            transforms.CenterCrop(224),
-            transforms.ToTensor(),
-            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
-        ]),
-    }
-
-    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
-                      for x in ['train', 'val']}
-    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
-                                                  shuffle=True, num_workers=4)
-                   for x in ['train', 'val']}
-    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
-    class_names = image_datasets['train'].classes
-
-    return dataloaders, dataset_sizes, class_names
-
-
-def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
-                num_epochs, model_save_dir, model_base_name):
-    """Train the model and save the best weights."""
-    since = time.time()
-    os.makedirs(model_save_dir, exist_ok=True)
-
-    best_acc = 0.0
-    best_model_path = None  # path of the best model saved so far
-
-    for epoch in range(num_epochs):
-        print(f'Epoch {epoch}/{num_epochs - 1}')
-        print('-' * 10)
-
-        for phase in ['train', 'val']:
-            if phase == 'train':
-                model.train()
-            else:
-                model.eval()
-
-            running_loss = 0.0
-            running_corrects = 0
-
-            for inputs, labels in dataloaders[phase]:
-                inputs, labels = inputs.to(device), labels.to(device)
-
-                optimizer.zero_grad()
-
-                with torch.set_grad_enabled(phase == 'train'):
-                    outputs = model(inputs)
-                    _, preds = torch.max(outputs, 1)
-                    loss = criterion(outputs, labels)
-
-                    if phase == 'train':
-                        loss.backward()
-                        optimizer.step()
-
-                running_loss += loss.item() * inputs.size(0)
-                running_corrects += torch.sum(preds == labels.data)
-
-            if phase == 'train':
-                scheduler.step()
-
-            epoch_loss = running_loss / dataset_sizes[phase]
-            epoch_acc = running_corrects.double() / dataset_sizes[phase]
-
-            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
-
-            if phase == 'val' and epoch_acc > best_acc:
-                best_acc = epoch_acc
-
-                # Save the new best model
-                best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')
-
-                torch.save(model.state_dict(), best_model_path)
-                print(f"Best model weights saved at {best_model_path}")
-
-    time_elapsed = time.time() - since
-    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
-    print(f'Best val Acc: {best_acc:.4f}')
-
-    if best_model_path:
-        # Load only the last saved best model
-        model.load_state_dict(torch.load(best_model_path, weights_only=True))
-    else:
-        print("No best model was saved during training.")
-    return model
-
-
-# Argument parsing
-parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
-parser.add_argument('--data-dir', type=str, default='d_2', help='Path to the dataset directory.')
-parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
-parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
-parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
-parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
-parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
-parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
-parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
-parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
-parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
-
-def main(args):
-    """Main function to train the model based on command-line arguments."""
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-
-    dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
-
-    model = models.resnet18(weights='IMAGENET1K_V1')
-    num_ftrs = model.fc.in_features
-    model.fc = nn.Linear(num_ftrs, args.num_classes)
-    model = model.to(device)
-
-    criterion = nn.CrossEntropyLoss()
-    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
-    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
-
-    train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
-                args.epochs, args.model_save_dir, args.model_base_name)
-
-if __name__ == '__main__':
-    args = parser.parse_args()
-    main(args)
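The renamed DPL/dgd_class/export.py keeps 100% similarity with export_11.19.py, so its contents do not appear in this diff. A minimal sketch of what such a .pt-to-ONNX converter could look like, reusing the torch.onnx.export call removed from train.py above; the checkpoint path, flag names, and class count are assumptions, not the actual export.py:

import argparse
import torch
import torch.nn as nn
from torchvision import models

# Hypothetical pt-to-ONNX converter; argument names and defaults are assumptions
parser = argparse.ArgumentParser(description="Convert a trained .pt checkpoint to ONNX.")
parser.add_argument('--weights', type=str, default='model/best_model_1.pt', help='Path to the .pt state_dict')
parser.add_argument('--output', type=str, default='model/model_1.onnx', help='Where to write the ONNX file')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes')
args = parser.parse_args()

# Rebuild the same ResNet18 head used in train.py, then load the trained weights
model = models.resnet18(weights=None)
model.fc = nn.Linear(model.fc.in_features, args.num_classes)
model.load_state_dict(torch.load(args.weights, weights_only=True))
model.eval()

# Export with the same settings train.py previously used for its ONNX copy
dummy_input = torch.randn(1, 3, 224, 224)  # same input size used during training
torch.onnx.export(model, dummy_input, args.output,
                  export_params=True, opset_version=11,
                  do_constant_folding=True, input_names=['input'], output_names=['output'])
print(f"ONNX model saved at {args.output}")

With the new command-line interfaces, a typical workflow might be: python train.py --data-dir dataset --epochs 5, then a conversion along the lines of the sketch above, then python predict.py --weights model/model_1.onnx --img-path dataset/val/dgd --save-dir detect.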