mirror of https://github.com/NanjingForestryUniversity/supermachine--tomato-passion_fruit.git
synced 2025-11-09 14:54:07 +00:00
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np

# U-Net downsampling (encoder) block: two 3x3 convolutions with ReLU.
# Pooling is applied between blocks in UNet.forward so the pre-pooling feature
# map can be reused as the skip connection; with MaxPool2d inside this block the
# decoder would receive half-resolution skips and the network output would be
# half the input size, which breaks the BCELoss shape match against the masks.
class UNetDown(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetDown, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_size, out_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_size, out_size, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

# U-Net upsampling (decoder) block: a transposed convolution doubles the spatial
# size, the result is concatenated with the skip connection, then two 3x3 convs.
class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, 2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_size, out_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_size, out_size, 3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Pad the upsampled tensor so its spatial size matches the skip
        # connection x2 (needed when the input size is not divisible by 16).
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [
            diffX // 2, diffX - diffX // 2,
            diffY // 2, diffY - diffY // 2
        ])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

# Assemble the full U-Net model.
class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.down1 = UNetDown(3, 64)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.pool = nn.MaxPool2d(2)
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.ReLU(inplace=True)
        )
        self.up1 = UNetUp(1024, 512)
        self.up2 = UNetUp(512, 256)
        self.up3 = UNetUp(256, 128)
        self.up4 = UNetUp(128, 64)
        self.final_conv = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        # Encoder: keep the pre-pooling feature maps as skip connections.
        x1 = self.down1(x)
        x2 = self.down2(self.pool(x1))
        x3 = self.down3(self.pool(x2))
        x4 = self.down4(self.pool(x3))
        x_middle = self.middle_conv(self.pool(x4))
        # Decoder: upsample and merge with the skip of matching resolution,
        # so the final output has the same spatial size as the input.
        x = self.up1(x_middle, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.final_conv(x)
        return torch.sigmoid(x)

# Create the model and print its layer structure as a quick sanity check.
model = UNet()
print(model)

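# Optional sanity check: a dummy batch at the training resolution used below
# (320x458) should come back with the same spatial size, which BCELoss requires
# when comparing the prediction against the masks.
with torch.no_grad():
    _dummy = torch.randn(1, 3, 320, 458)
    _out = model(_dummy)
    print(_out.shape)  # expected: torch.Size([1, 1, 320, 458])
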
class ImageDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None, size=(320, 458)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.size = size
        self.images = os.listdir(image_dir)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        mask_path = os.path.join(self.mask_dir, img_name.replace('.bmp', '.png'))
        image = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")  # ensure a single-channel grayscale mask

        # Resize image and mask to a common size.
        resize = transforms.Resize(self.size)
        image = resize(image)
        mask = resize(mask)

        # The same transform is applied to both image and mask, so it should
        # only contain value/geometry-preserving steps such as ToTensor.
        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        # ToTensor scales pixel values to [0, 1], so binarize the label by
        # thresholding at 0.5 (a threshold of 128 would always be False here).
        mask = (mask > 0.5).float()
        return image, mask

transform = transforms.Compose([
    transforms.ToTensor(),
])

def test_and_save(model, test_dir, save_dir):
    model.eval()
    transform = transforms.Compose([transforms.ToTensor()])
    os.makedirs(save_dir, exist_ok=True)
    test_images = os.listdir(test_dir)
    for img_name in test_images:
        img_path = os.path.join(test_dir, img_name)
        # Skip sub-directories (e.g. save_dir when it sits inside test_dir).
        if not os.path.isfile(img_path):
            continue
        image = Image.open(img_path).convert("RGB")
        image = transform(image).unsqueeze(0)  # add a batch dimension

        with torch.no_grad():
            prediction = model(image)

        prediction = prediction.squeeze(0).squeeze(0)  # drop the batch and channel dimensions
        prediction = (prediction > 0.5).type(torch.uint8)  # binarize
        save_image = Image.fromarray(prediction.numpy() * 255, 'L')
        save_image.save(os.path.join(save_dir, img_name))

def train(model, loader, optimizer, criterion, epochs):
    model.train()
    for epoch in range(epochs):
        for images, masks in loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")

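# Optional GPU variant: the script runs entirely on the CPU as written, so to
# use a GPU both the model and every batch must be moved to the same device.
# This is a sketch; train_on_device is a hypothetical helper not called below.
def train_on_device(model, loader, optimizer, criterion, epochs, device=None):
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.train()
    for epoch in range(epochs):
        for images, masks in loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}, Loss: {loss.item()}")
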
train_dataset = ImageDataset('/Users/xs/PycharmProjects/super-tomato/datasets_green/train/img', '/Users/xs/PycharmProjects/super-tomato/datasets_green/train/label', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)

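# Optional sanity check: pull one batch and confirm the shapes match what the
# model expects, roughly (B, 3, 320, 458) for images and (B, 1, 320, 458) for
# the binarized masks.
_images, _masks = next(iter(train_loader))
print(_images.shape, _masks.shape)
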
# Instantiate the UNet class defined above, together with the loss and optimizer.
model = UNet()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model.
train(model, train_loader, optimizer, criterion, epochs=50)

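# Optionally persist the trained weights so inference can be rerun later
# without retraining; 'unet_tomato.pth' is an arbitrary example filename.
torch.save(model.state_dict(), 'unet_tomato.pth')
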
# Run inference on the test images and save the predicted masks.
test_and_save(model, '/Users/xs/PycharmProjects/super-tomato/tomato_img_25', '/Users/xs/PycharmProjects/super-tomato/tomato_img_25/pre')