Mask code changes

wrz1-zzzzz 2024-11-20 16:51:42 +08:00
commit 81ccb6681a
14 changed files with 346 additions and 337 deletions

.gitignore

@@ -366,3 +366,4 @@ FodyWeavers.xsd
cmake-build-*
.DS_Store
DPL/dataset/*

Deleted file

@@ -1,12 +0,0 @@
class Model:
def __init__(self):
pass
def load(self, weight_path: str):
pass
def predict(self, weight_path: str):
pass

Deleted file

@@ -1,67 +0,0 @@
import torch
from PIL import Image
from torch import nn
import onnx
import onnxruntime as ort  # used for ONNX inference
from torchvision import datasets, models, transforms
import os
import matplotlib.pyplot as plt
from train import device, class_names, imshow
# Load the trained ONNX model
def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
    # Load the model with ONNX Runtime
    session = ort.InferenceSession(model_path)
    return session
# Prediction function
def visualize_model_predictions(onnx_session: ort.InferenceSession, img_path: str):
img = Image.open(img_path)
    img = img.convert('RGB')  # convert to RGB mode
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img = data_transforms(img)
img = img.unsqueeze(0)
img = img.to(device)
    # Convert the input to an ONNX-compatible format (a NumPy array)
    img = img.cpu().numpy()
    # Run inference with ONNX Runtime
inputs = {onnx_session.get_inputs()[0].name: img}
outputs = onnx_session.run(None, inputs)
preds = outputs[0]
    # Get the predicted class
_, predicted_class = torch.max(torch.tensor(preds), 1)
    # Visualize the result
ax = plt.subplot(2, 2, 1)
ax.axis('off')
ax.set_title(f'Predicted: {class_names[predicted_class[0]]}')
imshow(img[0])
# Run prediction with the trained ONNX model
if __name__ == '__main__':
    # TODO:
    # Write a model class: model = Model()
    # 1. Standardize the model-loading interface: model.load(weight_path: str)
    # 2. Standardize the prediction interface, called as: output_image: torch.tensor = model.predict(input_image: torch.tensor, param: dict)
    # Load the ONNX model
model_path = 'model/best_model_11.14.19.30.onnx'
onnx_session = load_onnx_model(model_path)
    # Image path
    img_path = 'd_2/train/dgd/transformed_1.jpg'  # change to your image path
visualize_model_predictions(onnx_session, img_path)

Deleted file

@@ -1,188 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
# Enable cudnn auto-tuning
cudnn.benchmark = True
plt.ion() # interactive mode
data_transforms = {
'train': transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
}
data_dir = 'd_2'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
data_transforms[x])
for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
shuffle=True, num_workers=4)
for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def imshow(inp, title=None):
"""Display image for Tensor."""
if isinstance(inp, np.ndarray):
inp = inp.transpose((1, 2, 0))
# inp = inp.numpy().transpose((1, 2, 0))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = np.clip(inp, 0, 1)
plt.imshow(inp)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_save_dir='model', save_onnx=False):
since = time.time()
    # Create the directory for saving models (if it does not exist)
    os.makedirs(model_save_dir, exist_ok=True)
    # Set the model save paths
    best_model_params_path = os.path.join(model_save_dir, 'best_model_params_11.14.19.30.pt')
    best_model_onnx_path = os.path.join(model_save_dir, 'best_model_11.14.19.30.onnx')
    # Save the initial model
torch.save(model.state_dict(), best_model_params_path)
best_acc = 0.0
for epoch in range(num_epochs):
print(f'Epoch {epoch}/{num_epochs - 1}')
print('-' * 10)
for phase in ['train', 'val']:
if phase == 'train':
model.train() # Set model to training mode
else:
model.eval() # Set model to evaluate mode
running_loss = 0.0
running_corrects = 0
for inputs, labels in dataloaders[phase]:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
if phase == 'train':
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
if phase == 'train':
scheduler.step()
epoch_loss = running_loss / dataset_sizes[phase]
epoch_acc = running_corrects.double() / dataset_sizes[phase]
print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
if phase == 'val' and epoch_acc > best_acc:
best_acc = epoch_acc
                # Save the best model
                torch.save(model.state_dict(), best_model_params_path)
                # If an ONNX copy is wanted, export it here
                if save_onnx:
                    # Export the model to ONNX format
                    model.eval()
                    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # dummy input for export (same shape as training inputs)
torch.onnx.export(model, dummy_input, best_model_onnx_path,
export_params=True, opset_version=11,
do_constant_folding=True, input_names=['input'], output_names=['output'])
print(f"ONNX model saved at {best_model_onnx_path}")
print()
time_elapsed = time.time() - since
print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')
model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
return model
# def visualize_model(model, num_images=6):
# was_training = model.training
# model.eval()
# images_so_far = 0
# fig = plt.figure()
#
# with torch.no_grad():
# for i, (inputs, labels) in enumerate(dataloaders['train']):
# inputs = inputs.to(device)
# labels = labels.to(device)
#
# outputs = model(inputs)
# _, preds = torch.max(outputs, 1)
#
# for j in range(inputs.size()[0]):
# images_so_far += 1
# ax = plt.subplot(num_images // 2, 2, images_so_far)
# ax.axis('off')
# ax.set_title(f'predicted: {class_names[preds[j]]}')
# plt.imshow(inputs.cpu().data[j])
#
# if images_so_far == num_images:
# model.train(mode=was_training)
# return
# model.train(mode=was_training)
if __name__ == '__main__':
    # TODO:
    # 1. Support running with different parameters, YOLO-style
    # 2. Support comparing prediction time across models, with model switching via config
model_ft = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model_ft.fc.in_features
model_ft.fc = nn.Linear(num_ftrs, 2)
model_ft = model_ft.to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
    # Directory for saving models
    model_save_dir = 'model'  # change to your preferred save path
    # Train the model and also save an ONNX copy
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
model_save_dir=model_save_dir, save_onnx=True)
# visualize_model(model_ft)
    # TODO: 3. Add dgd_class/export.py to convert trained .pt files to ONNX

DPL/main.py (new file)

@@ -0,0 +1,39 @@
import argparse
import os
import sys
import torch
from model_cls.models import Model as ClsModel
from pathlib import Path
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0] # YOLOv5 root directory
if str(ROOT) not in sys.path:
sys.path.append(str(ROOT)) # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
def main():
    # Command-line argument parsing
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
    # Defaults below; all can be overridden on the command line
    parser.add_argument('--weights', type=str, required=True, help='Path to the ONNX model file')
    parser.add_argument('--input-dir', type=str, default='dataset/', help='Directory of input images')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')
    args = parser.parse_args()
    # Select the device
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    if args.gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")
    # Load the model
    model = ClsModel(model_path=args.weights, device=device)
if __name__ == '__main__':
main()
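
main() currently stops after loading the model; a sketch of the inference loop it seems to be building toward, assuming Model.predict accepts a PIL image as in DPL/model_cls/models.py (the run_inference helper and the *.jpg glob are illustrative assumptions, not part of this commit):

from pathlib import Path
from PIL import Image

def run_inference(model, input_dir: str):
    # Hypothetical helper: classify every .jpg under input_dir and print its class index
    for img_path in sorted(Path(input_dir).glob('*.jpg')):
        img = Image.open(img_path).convert('RGB')
        pred = model.predict(img)
        print(f'{img_path.name}: class {int(pred[0])}')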

DPL/model_cls/models.py (new file)

@@ -0,0 +1,57 @@
from typing import Optional
from torchvision import transforms
import numpy as np
import onnxruntime as ort
import torch
# Model class
class Model:
def __init__(self, model_path: str, device: torch.device, block_size: Optional[tuple] = None):
"""
        Initialize the model: load the ONNX model and set the device (CPU or GPU).
"""
self.device = device
self.session = ort.InferenceSession(model_path)
self.block_size = block_size
self.data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
    def predict(self, img) -> torch.Tensor:
        """
        Run inference with the ONNX model and return the predicted class index.
        """
        # The transforms expect a PIL image (Resize and CenterCrop run before ToTensor)
        img = self.data_transforms(img)
        img = img.unsqueeze(0)
        # Convert to the ONNX input format (a NumPy array)
        img_numpy = img.cpu().numpy()
        inputs = {self.session.get_inputs()[0].name: img_numpy}
outputs = self.session.run(None, inputs)
pred = torch.tensor(outputs[0])
_, predicted_class = torch.max(pred, 1)
return predicted_class
def load(self, model_path: str):
"""
        Reload the model.
"""
self.session = ort.InferenceSession(model_path, providers=['TensorrtExecutionProvider'])
    def preprocess_predict_img(self, img: np.ndarray, block_size: Optional[tuple] = None) -> np.ndarray:
        if block_size is None:
            block_size = self.block_size
        if block_size:
            # Use the locally resolved block_size, not self.block_size, so a per-call override takes effect
            img = img.reshape((block_size[0], -1, block_size[1], 3))
            img = img.transpose((1, 0, 2, 3))
else:
img = img[np.newaxis, ...]
pred = self.predict(img)
pred = pred.squeeze().numpy()
return pred
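
A minimal usage sketch for this class, assuming an existing ONNX weights file and a test image (both paths are hypothetical); predict returns the argmax class index as a one-element tensor:

import torch
from PIL import Image
from model_cls.models import Model

model = Model(model_path='weights.onnx', device=torch.device('cpu'), block_size=(170, 170))
img = Image.open('sample.jpg').convert('RGB')  # the transforms in predict expect a PIL image
print(int(model.predict(img)[0]))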

Modified file

@@ -1,6 +1,6 @@
import argparse
import torch
import onnxruntime as ort
from models import *
from PIL import Image
import os
import matplotlib.pyplot as plt
@@ -9,34 +9,6 @@ from PIL import ImageDraw, ImageFont
from torchvision import transforms
import time  # import the time module
# Model class
class Model:
def __init__(self, model_path: str, device: torch.device):
"""
初始化模型加载 ONNX 模型并设置设备CPU GPU
"""
self.device = device
self.session = ort.InferenceSession(model_path)
def predict(self, img_tensor: torch.Tensor) -> torch.Tensor:
"""
使用 ONNX 模型进行推理返回预测结果
"""
# 转换为 ONNX 输入格式
img_numpy = img_tensor.cpu().numpy()
# 获取输入名称和推理
inputs = {self.session.get_inputs()[0].name: img_numpy}
outputs = self.session.run(None, inputs)
return torch.tensor(outputs[0])
def load(self, model_path: str):
"""
重新加载模型
"""
self.session = ort.InferenceSession(model_path)
# Prediction function
def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
@@ -48,25 +20,12 @@ def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
img = Image.open(img_path)
    img = img.convert('RGB')  # convert to RGB mode
    # Image preprocessing
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
img_transformed = data_transforms(img)
img_transformed = img_transformed.unsqueeze(0)
img_transformed = img_transformed.to(model.device)
    # Predict with the model
preds = model.predict(img_transformed)
    # Get the predicted class
_, predicted_class = torch.max(preds, 1)
preds = model.predict(img)
print(preds)
    # Add the prediction text to the image
predicted_label = class_names[predicted_class[0]]
predicted_label = class_names[preds[0]]
    # Draw the text on the image
img_with_text = img.copy()
@@ -121,8 +80,9 @@ def main():
parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")
    # Defaults below; all can be overridden on the command line
parser.add_argument('--weights', type=str, default='model/model_1.onnx', help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='d_2/val/dgd',
parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
help='Path to ONNX model file')
parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
help='Path to image or folder for inference')
parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')

DPL/model_cls/preprocess.py (new file)

@@ -0,0 +1,196 @@
import argparse
import os
import random
import shutil
import cv2
import numpy as np
import json
from shapely.geometry import Polygon, box
from shapely.affinity import translate
def create_dataset_from_folder(image_folder, label_folder, block_size, output_dir, roi=None):
    """
    Batch-process the images in a folder, together with their labels, into a classification dataset.
    Args:
        image_folder (str): Path to the image folder
        label_folder (str): Path to the label folder (Labelme-generated JSON files)
        block_size (tuple): Block size (width, height)
        output_dir (str): Root directory of the output dataset
        roi (tuple, optional): Region of interest (x_min, y_min, x_max, y_max) applied before splitting
    """
    # Create the output folders
has_label_dir = os.path.join(output_dir, "has_label")
no_label_dir = os.path.join(output_dir, "no_label")
os.makedirs(has_label_dir, exist_ok=True)
os.makedirs(no_label_dir, exist_ok=True)
    # Iterate over the image folder
for filename in os.listdir(image_folder):
if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif')):
image_path = os.path.join(image_folder, filename)
label_path = os.path.join(label_folder, os.path.splitext(filename)[0] + ".json")
            # Check whether the label file exists
if not os.path.exists(label_path):
print(f"Label file not found for image: {filename}")
continue
print(f"Processing {filename}...")
            process_single_image(image_path, label_path, block_size, has_label_dir, no_label_dir, roi)
def process_single_image(image_path, label_path, block_size, has_label_dir, no_label_dir, roi=None):
    """
    Process a single image, splitting it into blocks saved to the corresponding folders.
    Args:
        image_path (str): Image path
        label_path (str): Label path
        block_size (tuple): Block size (width, height)
        has_label_dir (str): Output directory for blocks containing annotations
        no_label_dir (str): Output directory for blocks without annotations
        roi (tuple, optional): Region of interest (x_min, y_min, x_max, y_max)
    """
    # Load the image
image = cv2.imread(image_path)
img_height, img_width, _ = image.shape
    # Load the Labelme JSON file
with open(label_path, 'r', encoding='utf-8') as f:
label_data = json.load(f)
    # Extract polygon annotations
polygons = []
for shape in label_data['shapes']:
if shape['shape_type'] == 'polygon':
points = shape['points']
polygons.append(Polygon(points))
if roi:
x_min, y_min, x_max, y_max = roi
x_min, y_min = max(0, x_min), max(0, y_min)
x_max, y_max = min(img_width, x_max), min(img_height, y_max)
image = image[y_min:y_max, x_min:x_max]
img_height, img_width = y_max - y_min, x_max - x_min
        # Clip the annotated polygons to the ROI and shift them into its frame
polygons = [translate(poly.intersection(box(x_min, y_min, x_max, y_max)), -x_min, -y_min) for poly in polygons]
    # Split the image into blocks and save them to the corresponding folders
block_width, block_height = block_size
block_id = 0
base_name = os.path.splitext(os.path.basename(image_path))[0]
for y in range(0, img_height, block_height):
for x in range(0, img_width, block_width):
            # Bounding box of the current block
block_polygon = box(x, y, x + block_width, y + block_height)
            # Check whether it intersects any annotated polygon
contains_label = any(poly.intersects(block_polygon) for poly in polygons)
            # Crop the current block
block = image[y:y + block_height, x:x + block_width]
            # Save to the corresponding folder
folder = has_label_dir if contains_label else no_label_dir
block_filename = os.path.join(folder, f"{base_name}_block_{block_id}.jpg")
cv2.imwrite(block_filename, block)
block_id += 1
print(f"Saved {block_filename} to {'has_label' if contains_label else 'no_label'} folder.")
def split_dataset(input_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    """
    Split the classified dataset into train, val, and test sets, preserving the source directory structure.
    Args:
        input_dir (str): Root directory of the classified data (one subfolder per label)
        output_dir (str): Output directory; train, val, and test folders are created under it
        train_ratio (float): Train split, as a ratio (< 1) or a fixed count
        val_ratio (float): Val split, as a ratio (< 1) or a fixed count
        test_ratio (float): Test split, as a ratio (< 1) or a fixed count
    """
    # Define the output subdirectories
train_dir = os.path.join(output_dir, "train")
val_dir = os.path.join(output_dir, "val")
test_dir = os.path.join(output_dir, "test")
    # Iterate over all subfolders (one per label)
for category in os.listdir(input_dir):
category_dir = os.path.join(input_dir, category)
        if not os.path.isdir(category_dir):  # skip non-directories
continue
        # Create the same subfolder structure for this category under train/val/test
os.makedirs(os.path.join(train_dir, category), exist_ok=True)
os.makedirs(os.path.join(val_dir, category), exist_ok=True)
os.makedirs(os.path.join(test_dir, category), exist_ok=True)
        # Collect all files in this category
files = os.listdir(category_dir)
random.shuffle(files)
        # Compute the split points
total_files = len(files)
if train_ratio < 1:
train_count = int(total_files * train_ratio)
val_count = int(total_files * val_ratio)
test_count = total_files - train_count - val_count
else:
train_count = int(train_ratio)
val_count = int(val_ratio)
test_count = int(test_ratio)
        # Make sure the counts do not exceed the total number of files
train_count = min(train_count, total_files)
val_count = min(val_count, total_files - train_count)
test_count = min(test_count, total_files - train_count - val_count)
print(f"Category {category}: {train_count} train, {val_count} val, {test_count} test files")
        # Split the file list
train_files = files[:train_count]
val_files = files[train_count:train_count + val_count]
test_files = files[train_count + val_count:]
        # Copy files into the corresponding subdirectories
for file in train_files:
shutil.copy(os.path.join(category_dir, file), os.path.join(train_dir, category, file))
for file in val_files:
shutil.copy(os.path.join(category_dir, file), os.path.join(val_dir, category, file))
for file in test_files:
shutil.copy(os.path.join(category_dir, file), os.path.join(test_dir, category, file))
print(f"Category {category} processed. Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Preprocess images to required shapes")
    # Defaults below; all can be overridden on the command line
parser.add_argument('--img-dir', type=str, default=r'..\dataset\dgd\test_img', help='Directory to input images')
parser.add_argument('--label-dir', type=str, default=r'..\dataset\dgd\test_img', help='Directory to input labels')
parser.add_argument('--output-dir', type=str, default=r'..\dataset\dgd\runs', help='Directory to save output images')
parser.add_argument('--roi', type=int, nargs=4, default=None, help='ROI region (x_min y_min x_max y_max)')
parser.add_argument('--train-ratio', type=float, default=0.7, help='Train set ratio or count')
parser.add_argument('--val-ratio', type=float, default=0.2, help='Validation set ratio or count')
parser.add_argument('--test-ratio', type=float, default=0.1, help='Test set ratio or count')
args = parser.parse_args()
    # Input folder paths
image_folder = args.img_dir
label_folder = args.label_dir
    # Output path
output_dir = args.output_dir
    # Block size
    block_size = (170, 170)  # replace with the desired block size (width, height)
roi = tuple(args.roi) if args.roi else None
all_output = os.path.join(output_dir, 'all')
    # Build the dataset in batch
    create_dataset_from_folder(image_folder, label_folder, block_size, all_output, roi)
split_dataset(all_output, output_dir, args.train_ratio, args.val_ratio, args.test_ratio)
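
Both stages can also be driven programmatically; a minimal sketch, assuming Labelme JSON files sit in a folder beside the images (all paths are hypothetical):

from model_cls.preprocess import create_dataset_from_folder, split_dataset

block_size = (170, 170)
create_dataset_from_folder('images/', 'labels/', block_size, 'runs/all', roi=None)
split_dataset('runs/all', 'runs', train_ratio=0.7, val_ratio=0.2, test_ratio=0.1)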

Modified file

@@ -1,6 +1,8 @@
import os
import argparse
import time
from datetime import datetime
import torch
import torch.nn as nn
import torch.optim as optim
@@ -10,7 +12,7 @@ import matplotlib.pyplot as plt
# Enable cudnn auto-tuning
torch.backends.cudnn.benchmark = True
plt.ion()
def get_next_model_name(save_dir, base_name, ext):
"""
@@ -131,37 +133,58 @@ def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_siz
return model
# Argument parsing
parser = argparse.ArgumentParser(description="Train a ResNet18 model on custom dataset.")
parser.add_argument('--data-dir', type=str, default='d_2', help='Path to the dataset directory.')
parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
parser.add_argument('--model-base-name', type=str, default='best_model', help='Base name for saved models.')
def main(args):
"""Main function to train the model based on command-line arguments."""
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
if args.model_name == 'resnet18':
model = models.resnet18(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
elif args.model_name == 'resnet50':
model = models.resnet50(weights='IMAGENET1K_V1')
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'alexnet':
        model = models.alexnet(weights='IMAGENET1K_V1')
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'vgg16':
        model = models.vgg16(weights='IMAGENET1K_V1')
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'densenet':
        model = models.densenet121(weights='IMAGENET1K_V1')
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'inceptionv3':
        model = models.inception_v3(weights='IMAGENET1K_V1')
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, args.num_classes)
    else:
        raise ValueError(f'Unknown model name: {args.model_name}')
model = model.to(device)
# training start time and model info
model_base_name = f"{args.model_name}_bs{args.batch_size}_ep{args.epochs}_{datetime.now().strftime('%y_%m_%d')}.pt"
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)
train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
args.epochs, args.model_save_dir, args.model_base_name)
args.epochs, args.model_save_dir, model_base_name)
if __name__ == '__main__':
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train an image classification model on a custom dataset.")
    parser.add_argument('--model-name', type=str, default='resnet18',
                        help='Backbone to fine-tune: resnet18, resnet50, alexnet, vgg16, densenet, or inceptionv3')
parser.add_argument('--data-dir', type=str, default='dataset', help='Path to the dataset directory.')
parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
args = parser.parse_args()
main(args)
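
The if/elif backbone selection above could be collapsed into a lookup table; a sketch of that refactor (build_model and the builders dict are illustrative, not part of this commit):

import torch.nn as nn
from torchvision import models

def build_model(name: str, num_classes: int) -> nn.Module:
    # Map each supported name to a pretrained constructor
    builders = {
        'resnet18': lambda: models.resnet18(weights='IMAGENET1K_V1'),
        'resnet50': lambda: models.resnet50(weights='IMAGENET1K_V1'),
        'vgg16': lambda: models.vgg16(weights='IMAGENET1K_V1'),
    }
    if name not in builders:
        raise ValueError(f'Unknown model name: {name}')
    model = builders[name]()
    # Replace the classification head with one sized for num_classes
    if hasattr(model, 'fc'):  # ResNet-style head
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:  # AlexNet/VGG-style classifier
        model.classifier[-1] = nn.Linear(model.classifier[-1].in_features, num_classes)
    return model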

DPL/model_cls/utils.py (new, empty file)

Modified file

@@ -2,9 +2,9 @@ from model import Model
# Initialize the Model object; masks are generated and saved automatically on instantiation
model = Model(
image_folder="runs/detect/exp6",
label_folder="runs/detect/labels",
output_folder="datasets/mask"
image_folder="",
label_folder="",
output_folder=""
)
# Read the mask corresponding to a given image
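
A plausible way to read a saved mask back, assuming the masks are written as grayscale PNGs named after their source images (the path and naming scheme are assumptions; the diff does not show them):

import cv2

# Hypothetical: the mask for img_001.jpg saved as <output_folder>/img_001.png
mask = cv2.imread('datasets/mask/img_001.png', cv2.IMREAD_GRAYSCALE)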