mirror of
https://github.com/Karllzy/cotton_color.git
synced 2025-11-08 18:53:53 +00:00
commit
96efe117b8
1
.gitignore
vendored
@@ -366,3 +366,4 @@ FodyWeavers.xsd
cmake-build-*

.DS_Store
DPL/dataset/*
@@ -1,62 +0,0 @@
import torch
from PIL import Image
from torch import nn
import onnx
import onnxruntime as ort  # for ONNX inference

from torchvision import datasets, models, transforms
import os
import matplotlib.pyplot as plt

from train import device, class_names, imshow


# Load the trained ONNX model
def load_onnx_model(model_path='model/best_model_11.14.19.30.onnx'):
    # Load the model with ONNX Runtime
    session = ort.InferenceSession(model_path)
    return session


# Prediction function
def visualize_model_predictions(onnx_session, img_path):
    img = Image.open(img_path)
    img = img.convert('RGB')  # convert to RGB mode

    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    img = data_transforms(img)
    img = img.unsqueeze(0)
    img = img.to(device)

    # Convert the input to an ONNX-compatible format (numpy array)
    img = img.cpu().numpy()

    # Run inference with ONNX Runtime
    inputs = {onnx_session.get_inputs()[0].name: img}
    outputs = onnx_session.run(None, inputs)
    preds = outputs[0]

    # Get the predicted class
    _, predicted_class = torch.max(torch.tensor(preds), 1)

    # Visualize the result
    ax = plt.subplot(2, 2, 1)
    ax.axis('off')
    ax.set_title(f'Predicted: {class_names[predicted_class[0]]}')
    imshow(img[0])


# Run prediction with the trained ONNX model
if __name__ == '__main__':
    # Load the ONNX model
    model_path = 'model/best_model_11.14.19.30.onnx'
    onnx_session = load_onnx_model(model_path)

    # Image path
    img_path = 'd_2/train/dgd/transformed_1.jpg'  # change to your image path
    visualize_model_predictions(onnx_session, img_path)
@@ -1,183 +0,0 @@
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image

# Initialize cudnn optimization
cudnn.benchmark = True
plt.ion()  # interactive mode

data_transforms = {
    'train': transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

data_dir = 'd_2'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes

device = torch.device("cuda:0")


def imshow(inp, title=None):
    """Display image for Tensor."""
    if isinstance(inp, np.ndarray):
        inp = inp.transpose((1, 2, 0))
    # inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


def train_model(model, criterion, optimizer, scheduler, num_epochs=25, model_save_dir='model', save_onnx=False):
    since = time.time()

    # Create the model save directory if it does not exist
    os.makedirs(model_save_dir, exist_ok=True)

    # Set the model save paths
    best_model_params_path = os.path.join(model_save_dir, 'best_model_params_11.14.19.30.pt')
    best_model_onnx_path = os.path.join(model_save_dir, 'best_model_11.14.19.30.onnx')

    # Save the initial model
    torch.save(model.state_dict(), best_model_params_path)
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                # Save the best model
                torch.save(model.state_dict(), best_model_params_path)

                # Optionally export an ONNX version of the model here
                if save_onnx:
                    # Export the model to ONNX format
                    model.eval()
                    dummy_input = torch.randn(1, 3, 224, 224).to(device)  # dummy input for inference (same size as the training input)
                    torch.onnx.export(model, dummy_input, best_model_onnx_path,
                                      export_params=True, opset_version=11,
                                      do_constant_folding=True, input_names=['input'], output_names=['output'])
                    print(f"ONNX model saved at {best_model_onnx_path}")

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(torch.load(best_model_params_path, weights_only=True))
    return model


# def visualize_model(model, num_images=6):
#     was_training = model.training
#     model.eval()
#     images_so_far = 0
#     fig = plt.figure()
#
#     with torch.no_grad():
#         for i, (inputs, labels) in enumerate(dataloaders['train']):
#             inputs = inputs.to(device)
#             labels = labels.to(device)
#
#             outputs = model(inputs)
#             _, preds = torch.max(outputs, 1)
#
#             for j in range(inputs.size()[0]):
#                 images_so_far += 1
#                 ax = plt.subplot(num_images // 2, 2, images_so_far)
#                 ax.axis('off')
#                 ax.set_title(f'predicted: {class_names[preds[j]]}')
#                 plt.imshow(inputs.cpu().data[j])
#
#                 if images_so_far == num_images:
#                     model.train(mode=was_training)
#                     return
#         model.train(mode=was_training)


if __name__ == '__main__':
    model_ft = models.resnet18(weights='IMAGENET1K_V1')
    num_ftrs = model_ft.fc.in_features
    model_ft.fc = nn.Linear(num_ftrs, 2)

    model_ft = model_ft.to(device)

    criterion = nn.CrossEntropyLoss()

    optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

    # Directory for saving models
    model_save_dir = 'model'  # change to the path where you want models saved

    # Train the model and also save it in ONNX format
    model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=25,
                           model_save_dir=model_save_dir, save_onnx=True)
    # visualize_model(model_ft)
39
DPL/main.py
Normal file
@@ -0,0 +1,39 @@
import argparse
import os
import sys

import torch

from model_cls.models import Model as ClsModel

from pathlib import Path


FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]  # YOLOv5 root directory
if str(ROOT) not in sys.path:
    sys.path.append(str(ROOT))  # add ROOT to PATH
ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")

    # Defaults can be overridden from the command line
    parser.add_argument('--weights', type=str, required=True, help='Path to ONNX model file')  # added: args.weights is used below but was never defined
    parser.add_argument('--input-dir', type=str, default='dataset/', help='Directory to input images')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')

    args = parser.parse_args()

    # Select the device
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    if args.gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Load the model
    model = ClsModel(model_path=args.weights, device=device)


if __name__ == '__main__':
    main()
98
DPL/model_cls/export.py
Normal file
@@ -0,0 +1,98 @@
import argparse
import torch
from torchvision import models
from torch import nn
import os
import re


def load_pytorch_model(model_path, num_classes, device):
    """
    Load a PyTorch model and set up the final classification layer.
    """
    print(f"Loading PyTorch model from {model_path}...")
    model = models.resnet18(weights=None)  # ResNet18 is used as an example
    model.fc = nn.Linear(model.fc.in_features, num_classes)  # replace the last layer
    model.load_state_dict(torch.load(model_path, map_location=device, weights_only=True))
    model.to(device)  # move the model to the target device
    model.eval()
    print("PyTorch model loaded successfully.")
    return model


def export_to_onnx(model, onnx_path, img_size, batch_size, device):
    """
    Export a PyTorch model to ONNX format with an auto-incrementing file name; GPU is supported.
    """
    # Make sure `onnx_path` resolves to a concrete file path
    if not onnx_path.endswith('.onnx'):
        os.makedirs(onnx_path, exist_ok=True)  # create the folder
        base_dir = onnx_path
    else:
        base_dir = os.path.dirname(onnx_path) or '.'  # extract the folder part
        os.makedirs(base_dir, exist_ok=True)

    # Auto-increment the file name
    base_name = "model"
    extension = ".onnx"
    existing_files = [f for f in os.listdir(base_dir) if f.endswith(extension)]

    # Match existing file names with a regular expression
    pattern = re.compile(rf"^{base_name}_(\d+){extension}$")
    numbers = [
        int(match.group(1)) for f in existing_files if (match := pattern.match(f))
    ]
    next_number = max(numbers, default=0) + 1  # compute the next index
    final_name = f"{base_name}_{next_number}{extension}"
    final_path = os.path.join(base_dir, final_name)

    print(f"Exporting model to ONNX format at {final_path}...")

    # Create a dummy input tensor on the target device
    dummy_input = torch.randn(batch_size, 3, img_size, img_size, device=device)

    # Export to ONNX
    torch.onnx.export(
        model,
        dummy_input,
        final_path,
        input_names=['input'],
        output_names=['output'],
        opset_version=11,  # ONNX opset version
        dynamic_axes={
            'input': {0: 'batch_size'},  # dynamic batch dimension
            'output': {0: 'batch_size'}
        }
    )
    print(f"Model exported successfully to {final_path}.")


def main():
    parser = argparse.ArgumentParser(description="Export PyTorch model to ONNX format.")
    parser.add_argument('--weights', type=str, required=True, help='Path to PyTorch model weights (.pt file)')
    parser.add_argument('--onnx-path', type=str, default='onnx', help='Output path for ONNX model')
    parser.add_argument('--img-size', type=int, default=224, help='Input image size (default: 224)')
    parser.add_argument('--batch-size', type=int, default=8, help='Input batch size (default: 8)')
    parser.add_argument('--num-classes', type=int, default=2, help='Number of classes in the model (default: 2)')
    parser.add_argument('--use-gpu', action='store_true', help='Enable GPU support during export')

    args = parser.parse_args()

    # Select the device
    device = torch.device('cuda' if args.use_gpu and torch.cuda.is_available() else 'cpu')
    if args.use_gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Check that the weights file exists
    if not os.path.isfile(args.weights):
        raise FileNotFoundError(f"Model weights file not found: {args.weights}")

    # Load the model
    model = load_pytorch_model(args.weights, args.num_classes, device)

    # Export to ONNX
    export_to_onnx(model, args.onnx_path, args.img_size, args.batch_size, device)


if __name__ == "__main__":
    main()
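Not part of the diff: a minimal sanity check one could run after export.py, comparing the PyTorch model's output with the exported ONNX model's output on the same dummy input. The weights path and ONNX path below are placeholders, and the import assumes export.py is importable as `export`.

import numpy as np
import onnxruntime as ort
import torch

from export import load_pytorch_model  # assumes export.py is on the import path

device = torch.device('cpu')
model = load_pytorch_model('model/best_1.pt', num_classes=2, device=device)  # placeholder weights path

dummy = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    torch_out = model(dummy).numpy()

session = ort.InferenceSession('onnx/model_1.onnx')  # placeholder ONNX path
onnx_out = session.run(None, {session.get_inputs()[0].name: dummy.numpy()})[0]

print('max abs diff:', np.abs(torch_out - onnx_out).max())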
57
DPL/model_cls/models.py
Normal file
@@ -0,0 +1,57 @@
from typing import Optional
from torchvision import transforms
import numpy as np
import onnxruntime as ort
import torch

# Model class
class Model:
    def __init__(self, model_path: str, device: torch.device, block_size: Optional[tuple] = None):
        """
        Initialize the model: load the ONNX model and set the device (CPU or GPU).
        """
        self.device = device
        self.session = ort.InferenceSession(model_path)
        self.block_size = block_size
        self.data_transforms = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def predict(self, img) -> torch.Tensor:
        """
        Run inference with the ONNX model and return the predicted class indices.
        """
        img = self.data_transforms(img)
        img = img.unsqueeze(0)

        # Convert to the ONNX input format
        img_numpy = img.cpu().numpy()

        inputs = {self.session.get_inputs()[0].name: img_numpy}
        outputs = self.session.run(None, inputs)
        pred = torch.tensor(outputs[0])
        _, predicted_class = torch.max(pred, 1)

        return predicted_class

    def load(self, model_path: str):
        """
        Reload the model.
        """
        self.session = ort.InferenceSession(model_path, providers=['TensorrtExecutionProvider'])

    def preprocess_predict_img(self, img: np.ndarray, block_size: Optional[tuple] = None) -> np.ndarray:
        if block_size is None:
            block_size = self.block_size
        if block_size:
            img = img.reshape((self.block_size[0], -1, self.block_size[1], 3))
            img = img.transpose((1, 0, 2, 3))
        else:
            img = img[np.newaxis, ...]

        pred = self.predict(img)
        pred = pred.squeeze().numpy()
        return pred
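Not part of the diff: a minimal usage sketch for the wrapper above, assuming an exported ONNX file exists. The model path and image path are placeholders, and the import path is assumed from the repo layout.

import torch
from PIL import Image

from model_cls.models import Model  # assumed import path

model = Model(model_path='onnx/model_1.onnx', device=torch.device('cpu'))  # placeholder ONNX path
img = Image.open('block_0.jpg').convert('RGB')                             # placeholder image
pred = model.predict(img)   # tensor of predicted class indices, shape (1,)
print(int(pred[0]))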
114
DPL/model_cls/predict.py
Normal file
@@ -0,0 +1,114 @@
import argparse
from models import *

from PIL import Image
import os
import matplotlib.pyplot as plt
from PIL.ImageDraw import Draw
from PIL import ImageDraw, ImageFont
from torchvision import transforms
import time  # for timing


# Prediction function
def visualize_model_predictions(model: Model, img_path: str, save_dir: str, class_names: list):
    """
    Predict an image, visualize the result, and save the annotated image to the given folder.
    """
    start_time = time.time()  # start time

    img = Image.open(img_path)
    img = img.convert('RGB')  # convert to RGB mode

    # Run prediction with the model
    preds = model.predict(img)
    print(preds)

    # Add the prediction result as text on the image
    predicted_label = class_names[preds[0]]

    # Draw the text on the image
    img_with_text = img.copy()
    draw = ImageDraw.Draw(img_with_text)
    font = ImageFont.load_default()  # change the font if needed
    text = f'Predicted: {predicted_label}'
    text_position = (10, 10)  # text position, adjust as needed
    draw.text(text_position, text, font=font, fill=(0, 255, 0))  # green text

    # Show the annotated image
    img_with_text.show()

    # Save the annotated image
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    # Build the file name and save path
    img_name = os.path.basename(img_path)
    save_path = os.path.join(save_dir, f"pred_{img_name}")

    # Save the image
    img_with_text.save(save_path)
    print(f"Prediction saved at: {save_path}")

    end_time = time.time()  # end time
    processing_time = (end_time - start_time) * 1000  # convert to milliseconds
    print(f"Time taken to process {img_name}: {processing_time:.2f} ms")  # per-image processing time (ms)


def process_image_folder(model: Model, folder_path: str, save_dir: str, class_names: list):
    """
    Process all images in a folder and predict the class of each image.
    """
    start_time = time.time()  # total start time

    # Collect all image files in the folder
    image_files = [f for f in os.listdir(folder_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]

    # Iterate over all images in the folder
    for image_file in image_files:
        img_path = os.path.join(folder_path, image_file)
        print(f"Processing image: {img_path}")
        visualize_model_predictions(model, img_path, save_dir, class_names)

    end_time = time.time()  # total end time
    total_time = (end_time - start_time) * 1000  # convert to milliseconds
    print(f"Total time taken to process all images: {total_time:.2f} ms")  # total processing time (ms)


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description="Use an ONNX model for inference.")

    # Defaults can be overridden from the command line
    parser.add_argument('--weights', type=str, default=r'..\onnxs\dgd_class_11.14.onnx',
                        help='Path to ONNX model file')
    parser.add_argument('--img-path', type=str, default='../dataset/val/dgd',
                        help='Path to image or folder for inference')
    parser.add_argument('--save-dir', type=str, default='detect', help='Directory to save output images')
    parser.add_argument('--gpu', action='store_true', help='Use GPU for inference')

    args = parser.parse_args()

    # Select the device
    device = torch.device('cuda' if args.gpu and torch.cuda.is_available() else 'cpu')
    if args.gpu and not torch.cuda.is_available():
        print("GPU not available, switching to CPU.")

    # Load the model
    model = Model(model_path=args.weights, device=device)

    # Stand-in class_names
    # assumes the model has two classes
    class_names = ['class_0', 'class_1']

    # Check whether the input path is a folder
    if os.path.isdir(args.img_path):
        # If it is a folder, process every image in it
        process_image_folder(model, args.img_path, args.save_dir, class_names)
    else:
        # If it is a single image file, predict it directly
        visualize_model_predictions(model, args.img_path, args.save_dir, class_names)


if __name__ == "__main__":
    main()
196
DPL/model_cls/preprocess.py
Normal file
@@ -0,0 +1,196 @@
import argparse
import os
import random
import shutil

import cv2
import numpy as np
import json
from shapely.geometry import Polygon, box
from shapely.affinity import translate


def create_dataset_from_folder(image_folder, label_folder, block_size, output_dir, roi=None):
    """
    Batch-process the images and matching labels in a folder to build a classification dataset.

    Args:
        image_folder (str): Path to the image folder.
        label_folder (str): Path to the label folder containing Labelme JSON files.
        block_size (tuple): Block size (width, height).
        output_dir (str): Root directory of the generated dataset.
        roi (tuple, optional): ROI (x_min, y_min, x_max, y_max) to crop before splitting.
    """
    # Create the output folders
    has_label_dir = os.path.join(output_dir, "has_label")
    no_label_dir = os.path.join(output_dir, "no_label")
    os.makedirs(has_label_dir, exist_ok=True)
    os.makedirs(no_label_dir, exist_ok=True)

    # Walk through the image folder
    for filename in os.listdir(image_folder):
        if filename.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.tif')):
            image_path = os.path.join(image_folder, filename)
            label_path = os.path.join(label_folder, os.path.splitext(filename)[0] + ".json")

            # Check that the label file exists
            if not os.path.exists(label_path):
                print(f"Label file not found for image: {filename}")
                continue

            print(f"Processing {filename}...")
            process_single_image(image_path, label_path, block_size, has_label_dir, no_label_dir, roi)


def process_single_image(image_path, label_path, block_size, has_label_dir, no_label_dir, roi=None):
    """
    Process a single image, split it into blocks, and save each block to the matching folder.

    Args:
        image_path (str): Image path.
        label_path (str): Label path.
        block_size (tuple): Block size (width, height).
        has_label_dir (str): Output directory for blocks that contain annotations.
        no_label_dir (str): Output directory for blocks without annotations.
        roi (tuple, optional): ROI (x_min, y_min, x_max, y_max). Previously `roi` was read from
            module scope; it is now passed in explicitly.
    """
    # Load the image
    image = cv2.imread(image_path)
    img_height, img_width, _ = image.shape

    # Load the Labelme JSON file
    with open(label_path, 'r', encoding='utf-8') as f:
        label_data = json.load(f)

    # Extract the polygon annotations
    polygons = []
    for shape in label_data['shapes']:
        if shape['shape_type'] == 'polygon':
            points = shape['points']
            polygons.append(Polygon(points))

    if roi:
        x_min, y_min, x_max, y_max = roi
        x_min, y_min = max(0, x_min), max(0, y_min)
        x_max, y_max = min(img_width, x_max), min(img_height, y_max)
        image = image[y_min:y_max, x_min:x_max]
        img_height, img_width = y_max - y_min, x_max - x_min
        # Shift the annotated polygons into the cropped coordinate frame
        polygons = [translate(poly.intersection(box(x_min, y_min, x_max, y_max)), -x_min, -y_min) for poly in polygons]

    # Split the image into blocks and save them to the matching folders
    block_width, block_height = block_size
    block_id = 0
    base_name = os.path.splitext(os.path.basename(image_path))[0]
    for y in range(0, img_height, block_height):
        for x in range(0, img_width, block_width):
            # Bounding box of the current block
            block_polygon = box(x, y, x + block_width, y + block_height)

            # Check whether the block intersects any annotated polygon
            contains_label = any(poly.intersects(block_polygon) for poly in polygons)

            # Crop the current block
            block = image[y:y + block_height, x:x + block_width]

            # Save it to the matching folder
            folder = has_label_dir if contains_label else no_label_dir
            block_filename = os.path.join(folder, f"{base_name}_block_{block_id}.jpg")
            cv2.imwrite(block_filename, block)
            block_id += 1

            print(f"Saved {block_filename} to {'has_label' if contains_label else 'no_label'} folder.")


def split_dataset(input_dir, output_dir, train_ratio=0.7, val_ratio=0.2, test_ratio=0.1):
    """
    Split the labelled dataset into train, val, and test sets while keeping the source folder structure.

    Args:
        input_dir (str): Root directory of the labelled data, containing one subfolder per class.
        output_dir (str): Output dataset directory; train, val, and test folders are created inside it.
        train_ratio (float): Train split as a fraction, or an absolute file count.
        val_ratio (float): Val split as a fraction, or an absolute file count.
        test_ratio (float): Test split as a fraction, or an absolute file count.
    """
    # Define the output subdirectories
    train_dir = os.path.join(output_dir, "train")
    val_dir = os.path.join(output_dir, "val")
    test_dir = os.path.join(output_dir, "test")

    # Iterate over all subfolders (class folders)
    for category in os.listdir(input_dir):
        category_dir = os.path.join(input_dir, category)
        if not os.path.isdir(category_dir):  # skip non-directories
            continue

        # Create the same subfolder structure for this class under train/val/test
        os.makedirs(os.path.join(train_dir, category), exist_ok=True)
        os.makedirs(os.path.join(val_dir, category), exist_ok=True)
        os.makedirs(os.path.join(test_dir, category), exist_ok=True)

        # Collect all files for this class
        files = os.listdir(category_dir)
        random.shuffle(files)

        # Compute the split points
        total_files = len(files)
        if train_ratio < 1:
            train_count = int(total_files * train_ratio)
            val_count = int(total_files * val_ratio)
            test_count = total_files - train_count - val_count
        else:
            train_count = int(train_ratio)
            val_count = int(val_ratio)
            test_count = int(test_ratio)

        # Make sure the counts do not exceed the number of files
        train_count = min(train_count, total_files)
        val_count = min(val_count, total_files - train_count)
        test_count = min(test_count, total_files - train_count - val_count)

        print(f"Category {category}: {train_count} train, {val_count} val, {test_count} test files")

        # Split the dataset
        train_files = files[:train_count]
        val_files = files[train_count:train_count + val_count]
        test_files = files[train_count + val_count:]

        # Copy the files into the matching subdirectories
        for file in train_files:
            shutil.copy(os.path.join(category_dir, file), os.path.join(train_dir, category, file))
        for file in val_files:
            shutil.copy(os.path.join(category_dir, file), os.path.join(val_dir, category, file))
        for file in test_files:
            shutil.copy(os.path.join(category_dir, file), os.path.join(test_dir, category, file))

        print(f"Category {category} processed. Train: {len(train_files)}, Val: {len(val_files)}, Test: {len(test_files)}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Preprocess images to required shapes")

    # Defaults can be overridden from the command line
    parser.add_argument('--img-dir', type=str, default=r'..\dataset\dgd\test_img', help='Directory to input images')
    parser.add_argument('--label-dir', type=str, default=r'..\dataset\dgd\test_img', help='Directory to input labels')

    parser.add_argument('--output-dir', type=str, default=r'..\dataset\dgd\runs', help='Directory to save output images')
    parser.add_argument('--roi', type=int, nargs=4, default=None, help='ROI region (x_min y_min x_max y_max)')
    parser.add_argument('--train-ratio', type=float, default=0.7, help='Train set ratio or count')
    parser.add_argument('--val-ratio', type=float, default=0.2, help='Validation set ratio or count')
    parser.add_argument('--test-ratio', type=float, default=0.1, help='Test set ratio or count')

    args = parser.parse_args()

    # Input folder paths
    image_folder = args.img_dir
    label_folder = args.label_dir
    # Output path
    output_dir = args.output_dir

    # Block size
    block_size = (170, 170)  # replace with the desired block size (width, height)
    roi = tuple(args.roi) if args.roi else None
    all_output = os.path.join(output_dir, 'all')

    # Generate the dataset in batch, then split it
    create_dataset_from_folder(image_folder, label_folder, block_size, all_output, roi)
    split_dataset(all_output, output_dir, args.train_ratio, args.val_ratio, args.test_ratio)
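Not part of the diff: process_single_image above only reads the "shapes", "shape_type", and "points" fields of a Labelme file. A minimal label written from Python for illustration (real files are produced by Labelme and contain more fields):

import json

label = {
    "shapes": [
        {
            "label": "dgd",  # class name (not used by the block splitter)
            "shape_type": "polygon",
            "points": [[120.0, 80.0], [260.0, 90.0], [250.0, 210.0], [130.0, 200.0]],
        }
    ]
}

with open("example.json", "w", encoding="utf-8") as f:
    json.dump(label, f, indent=2)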
190
DPL/model_cls/train.py
Normal file
@@ -0,0 +1,190 @@
import os
import argparse
import time
from datetime import datetime

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt

# Enable cudnn optimization
torch.backends.cudnn.benchmark = True


def get_next_model_name(save_dir, base_name, ext):
    """
    Scan the save directory and build a model file name with an incrementing index.

    Args:
        save_dir (str): Model save directory
        base_name (str): Base name of the model
        ext (str): Model file extension (e.g. ".pt" or ".onnx")

    Returns:
        str: New model file name with the next index
    """
    os.makedirs(save_dir, exist_ok=True)
    existing_files = [f for f in os.listdir(save_dir) if f.startswith(base_name) and f.endswith(ext)]
    existing_numbers = []
    for f in existing_files:
        try:
            num = int(f[len(base_name):].split('.')[0].strip('_'))
            existing_numbers.append(num)
        except ValueError:
            continue
    next_number = max(existing_numbers, default=0) + 1
    return os.path.join(save_dir, f"{base_name}_{next_number}{ext}")


def get_data_loaders(data_dir, batch_size):
    """Prepare data loaders for training and validation."""
    data_transforms = {
        'train': transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
        'val': transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ]),
    }

    image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                      for x in ['train', 'val']}
    dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                                  shuffle=True, num_workers=4)
                   for x in ['train', 'val']}
    dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    class_names = image_datasets['train'].classes

    return dataloaders, dataset_sizes, class_names


def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
                num_epochs, model_save_dir, model_base_name):
    """Train the model and save the best weights."""
    since = time.time()
    os.makedirs(model_save_dir, exist_ok=True)

    best_acc = 0.0
    best_model_path = None  # path of the best model saved so far

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc

                # Save the new best model
                best_model_path = get_next_model_name(model_save_dir, model_base_name, '.pt')

                torch.save(model.state_dict(), best_model_path)
                print(f"Best model weights saved at {best_model_path}")

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    if best_model_path:
        # Load only the best model saved last
        model.load_state_dict(torch.load(best_model_path, weights_only=True))
    else:
        print("No best model was saved during training.")
    return model


def main(args):
    """Main function to train the model based on command-line arguments."""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    dataloaders, dataset_sizes, class_names = get_data_loaders(args.data_dir, args.batch_size)
    if args.model_name == 'resnet18':
        model = models.resnet18(weights='IMAGENET1K_V1')
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'resnet50':
        model = models.resnet50(weights='IMAGENET1K_V1')
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'alexnet':
        model = models.alexnet(pretrained=True)
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'vgg16':
        model = models.vgg16(pretrained=True)
        num_ftrs = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'densenet':
        model = models.densenet121(pretrained=True)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, args.num_classes)
    elif args.model_name == 'inceptionv3':
        model = models.inception_v3(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Linear(num_ftrs, args.num_classes)
    else:
        raise ValueError(f"Unsupported model name: {args.model_name}")
    model = model.to(device)

    # training start time and model info
    # (".pt" is appended later by get_next_model_name, so it is not part of the base name)
    model_base_name = f"{args.model_name}_bs{args.batch_size}_ep{args.epochs}_{datetime.now().strftime('%y_%m_%d')}"
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=args.step_size, gamma=args.gamma)

    train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device,
                args.epochs, args.model_save_dir, model_base_name)


if __name__ == '__main__':
    # Argument parsing
    parser = argparse.ArgumentParser(description="Train an image classification model on a custom dataset.")
    parser.add_argument('--model_name', type=str, default='resnet18', help='Which model to train as base model')
    parser.add_argument('--data-dir', type=str, default='dataset', help='Path to the dataset directory.')
    parser.add_argument('--model-save-dir', type=str, default='model', help='Directory to save the trained model.')
    parser.add_argument('--batch-size', type=int, default=8, help='Batch size for training.')
    parser.add_argument('--epochs', type=int, default=5, help='Number of training epochs.')
    parser.add_argument('--lr', type=float, default=0.0005, help='Learning rate.')
    parser.add_argument('--momentum', type=float, default=0.95, help='Momentum for SGD.')
    parser.add_argument('--step-size', type=int, default=5, help='Step size for LR scheduler.')
    parser.add_argument('--gamma', type=float, default=0.2, help='Gamma for LR scheduler.')
    parser.add_argument('--num-classes', type=int, default=2, help='Number of output classes.')
    args = parser.parse_args()
    main(args)
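Not part of the diff: a small illustration of the incrementing naming scheme implemented by get_next_model_name above; the import path is an assumption based on the repo layout.

import os
import tempfile

from model_cls.train import get_next_model_name  # assumed import path

with tempfile.TemporaryDirectory() as d:
    print(get_next_model_name(d, "resnet18_bs8_ep5", ".pt"))  # .../resnet18_bs8_ep5_1.pt
    open(os.path.join(d, "resnet18_bs8_ep5_1.pt"), "w").close()
    open(os.path.join(d, "resnet18_bs8_ep5_2.pt"), "w").close()
    print(get_next_model_name(d, "resnet18_bs8_ep5", ".pt"))  # .../resnet18_bs8_ep5_3.pt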
0
DPL/model_cls/utils.py
Normal file
Binary file not shown.
Binary file not shown.
14
DPL/yolov5/mask.py
Normal file
@@ -0,0 +1,14 @@
from model import Model

# Initialize a Model object; masks are generated and saved automatically on instantiation
model = Model(
    image_folder="",
    label_folder="",
    output_folder="",
    block_size_x=24,
    block_size_y=24
)

# Read the mask for a specific image
mask_array = model.get_mask_array("image1")  # replace "image1" with your image name
print(mask_array)
122
DPL/yolov5/model.py
Normal file
@@ -0,0 +1,122 @@
import os
import cv2
import numpy as np
import time

class Model:
    def __init__(self, image_folder, label_folder, output_folder, block_size_x=24, block_size_y=24):
        """
        Initialize the Model class; masks are generated and saved automatically.
        :param image_folder: path to the input image folder
        :param label_folder: path to the YOLOv5 label folder
        :param output_folder: path to the folder where mask matrices are saved
        :param block_size_x: number of blocks along the image width
        :param block_size_y: number of blocks along the image height
        """
        self.image_folder = image_folder
        self.label_folder = label_folder
        self.output_folder = output_folder
        self.width_blocks = block_size_x
        self.height_blocks = block_size_y

        # Make sure the output folder exists
        if not os.path.exists(self.output_folder):
            os.makedirs(self.output_folder)

        # Generate and save the masks automatically
        self._process_and_save_masks()

    def _read_yolov5_labels(self):
        """
        Read the label files in the YOLOv5 label folder and extract each box.
        :return: a dict of the form {image_name: [(x_center, y_center, width, height), ...]}
        """
        labels = {}
        for filename in os.listdir(self.label_folder):
            if filename.endswith('.txt'):
                image_name = filename.replace('.txt', '')
                file_path = os.path.join(self.label_folder, filename)

                with open(file_path, 'r') as f:
                    boxes = []
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 5:
                            continue

                        x_center = float(parts[1])
                        y_center = float(parts[2])
                        width = float(parts[3])
                        height = float(parts[4])
                        boxes.append([x_center, y_center, width, height])
                    labels[image_name] = boxes
        return labels

    def _generate_mask(self, image_shape, boxes):
        """
        Generate a boolean (True/False) mask from the detection boxes.
        :param image_shape: image shape (height, width)
        :param boxes: detection boxes in the form [(x_center, y_center, width, height), ...]
        :return: mask matrix
        """
        height, width = image_shape
        mask = np.zeros((height, width), dtype=bool)

        for box in boxes:
            x_center, y_center, width_box, height_box = box
            x1 = int((x_center - width_box / 2) * width)
            y1 = int((y_center - height_box / 2) * height)
            x2 = int((x_center + width_box / 2) * width)
            y2 = int((y_center + height_box / 2) * height)

            x1 = max(0, x1)
            y1 = max(0, y1)
            x2 = min(width, x2)
            y2 = min(height, y2)

            mask[y1:y2, x1:x2] = True

        return mask

    def _process_and_save_masks(self):
        """
        Process the image folder, generate the masks, and save them as True/False matrices.
        """
        labels = self._read_yolov5_labels()

        for filename in os.listdir(self.image_folder):
            if filename.endswith(('.jpg', '.png', '.bmp')):
                image_path = os.path.join(self.image_folder, filename)
                image_name = filename.split('.')[0]

                boxes = labels.get(image_name, [])
                if not boxes:
                    print(f"No detection boxes found for: {image_name}")
                    continue

                image = cv2.imread(image_path)
                if image is None:
                    print(f"Could not read image: {image_path}")
                    continue

                height, width = image.shape[:2]
                start_time = time.time()
                mask = self._generate_mask((height, width), boxes)
                processing_time = time.time() - start_time
                print(f"Processed image {image_name} in {processing_time:.4f} s")

                mask_filename = f"{image_name}_mask.npy"
                mask_path = os.path.join(self.output_folder, mask_filename)
                np.save(mask_path, mask)
                print(f"Saved mask: {mask_filename}")

    def get_mask_array(self, image_name):
        """
        Return the mask array for the given image.
        :param image_name: image name (without extension)
        :return: mask array
        """
        mask_path = os.path.join(self.output_folder, f"{image_name}_mask.npy")
        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask file not found: {mask_path}")
        return np.load(mask_path)
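Not part of the diff: one quick way to inspect a saved mask is to overlay it on the source image with OpenCV. The paths below are placeholders.

import cv2
import numpy as np

image = cv2.imread("images/image1.jpg")   # placeholder image path
mask = np.load("masks/image1_mask.npy")   # mask saved by the Model class above

overlay = image.copy()
overlay[mask] = (0, 0, 255)               # paint masked pixels red (BGR)
blended = cv2.addWeighted(image, 0.6, overlay, 0.4, 0)
cv2.imwrite("image1_mask_overlay.jpg", blended)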