# Source: supermachine--tomato-passio.../20240410RGBtest1/super-tomato/U_net.py
# 170 lines · 5.3 KiB · Python

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np
# Down-sampling (encoder) stage of the U-Net.
class UNetDown(nn.Module):
    """Encoder stage: two 3x3 conv + ReLU layers followed by 2x2 max pooling.

    Args:
        in_size: Number of input channels.
        out_size: Number of channels produced by both convolutions.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        layers = []
        # First conv maps in_size -> out_size, second keeps out_size.
        for channels_in in (in_size, out_size):
            layers.append(nn.Conv2d(channels_in, out_size, kernel_size=3, padding=1))
            layers.append(nn.ReLU(inplace=True))
        # Pooling is the last layer of the stage, halving height and width.
        layers.append(nn.MaxPool2d(kernel_size=2))
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Return the pooled feature map for input ``x``."""
        pooled = self.conv(x)
        return pooled
# Up-sampling (decoder) stage of the U-Net.
class UNetUp(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, two convs.

    Args:
        in_size: Channels of the incoming (low-resolution) feature map; also
            the channel count of the concatenated tensor fed to the convs.
        out_size: Channels produced by the upsampling and by both convs.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_size, out_size, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_size, out_size, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x1, x2):
        """Upsample ``x1``, fit it to ``x2``'s spatial size and fuse the two.

        Args:
            x1: Feature map to upsample.
            x2: Skip-connection feature map that defines the target size.
        """
        upsampled = self.up(x1)
        # Size gap between the skip tensor and the upsampled tensor; a
        # negative gap makes `pad` crop instead of pad.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        # `pad` expects (left, right, top, bottom): last dimension first.
        padding = (left, gap_w - left, top, gap_h - top)
        upsampled = nn.functional.pad(upsampled, padding)
        merged = torch.cat((x2, upsampled), dim=1)
        return self.conv(merged)
# Full U-Net assembled from the encoder and decoder stages.
class UNet(nn.Module):
    """U-Net whose output has the same height and width as its input.

    Four encoder stages (64, 128, 256, 512 channels), a 1024-channel
    bottleneck, four decoder stages and a 1x1 convolution followed by a
    sigmoid, so every output value lies in (0, 1).

    The skip connections carry each encoder stage's features *before* its
    max pooling. Using the pooled outputs instead leaves every skip one level
    too coarse: the first decoder stage upsamples and then crops everything
    back, and the final prediction ends up at half the input resolution,
    which cannot be compared against full-size masks.

    Args:
        in_channels: Channels of the input image (default 3).
        out_channels: Channels of the prediction (default 1). The sigmoid is
            applied to each output channel independently.

    Note:
        The input is halved four times, so height and width should be at
        least 16 pixels.
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.down1 = UNetDown(in_channels, 64)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512)
        self.middle_conv = nn.Sequential(
            nn.Conv2d(512, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, 3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.up1 = UNetUp(1024, 512)
        self.up2 = UNetUp(512, 256)
        self.up3 = UNetUp(256, 128)
        self.up4 = UNetUp(128, 64)
        self.final_conv = nn.Conv2d(64, out_channels, 1)

    @staticmethod
    def _encode(stage, x):
        """Run one encoder stage and return ``(pre_pool, pooled)`` features.

        The encoder stage keeps its pooling layer as the last element of its
        ``conv`` Sequential, so the layers are applied one by one here to get
        hold of the features right before pooling. No parameters are added or
        renamed, so existing checkpoints still load.
        """
        *conv_layers, pool = stage.conv
        for layer in conv_layers:
            x = layer(x)
        return x, pool(x)

    def forward(self, x):
        """Return the sigmoid prediction for the image batch ``x``."""
        skip1, x = self._encode(self.down1, x)
        skip2, x = self._encode(self.down2, x)
        skip3, x = self._encode(self.down3, x)
        skip4, x = self._encode(self.down4, x)
        x = self.middle_conv(x)
        # Each decoder stage doubles the resolution and fuses the skip of
        # matching resolution (deepest skip first).
        x = self.up1(x, skip4)
        x = self.up2(x, skip3)
        x = self.up3(x, skip2)
        x = self.up4(x, skip1)
        return torch.sigmoid(self.final_conv(x))
# Create the model and print its layer structure.
# NOTE(review): this instance is only used for the printout; `model` is
# rebound to a fresh `UNet()` further down, before training starts.
model = UNet()
print(model)
class ImageDataset(Dataset):
    """Paired image / binary-mask dataset read from two directories.

    Every regular, non-hidden file in ``image_dir`` is one sample. Its mask
    is looked up in ``mask_dir`` under the same file name, except that a
    ``.bmp`` image maps to a ``.png`` mask.

    Args:
        image_dir: Directory containing the input images.
        mask_dir: Directory containing the label masks.
        transform: Callable applied to the resized RGB PIL image only. When
            ``None``, the image is converted with ``transforms.ToTensor()``.
            The mask never goes through ``transform``, so it must not contain
            geometric augmentations (they would misalign image and mask).
        size: ``(height, width)`` both image and mask are resized to.
    """

    def __init__(self, image_dir, mask_dir, transform=None, size=(320, 458)):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.size = size
        # Sorted so that an index always maps to the same file; hidden files
        # (e.g. macOS ".DS_Store") and sub-directories are not images.
        self.images = sorted(
            name for name in os.listdir(image_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(image_dir, name))
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        ``image`` is whatever ``transform`` returns (a tensor by default);
        ``mask`` is a float32 tensor of shape ``(1, H, W)`` holding 0.0 / 1.0.
        """
        img_name = self.images[idx]
        img_path = os.path.join(self.image_dir, img_name)
        # Only swap the real extension, not a ".bmp" inside the file name.
        stem, ext = os.path.splitext(img_name)
        mask_name = stem + '.png' if ext == '.bmp' else img_name
        mask_path = os.path.join(self.mask_dir, mask_name)

        resize = transforms.Resize(self.size)
        with Image.open(img_path) as img_file:
            image = resize(img_file.convert("RGB"))
        with Image.open(mask_path) as mask_file:
            # Single-channel greyscale label.
            mask = resize(mask_file.convert("L"))

        if self.transform is not None:
            image = self.transform(image)
        else:
            image = transforms.ToTensor()(image)

        # Binarise on the raw 0-255 grey levels. Thresholding at 128 *after*
        # ToTensor (which rescales to [0, 1]) would yield an all-zero mask.
        binary = (np.array(mask) > 128).astype(np.float32)
        mask = torch.from_numpy(binary).unsqueeze(0)
        return image, mask
# Image preprocessing pipeline: a single ToTensor step (PIL image -> tensor).
transform = transforms.Compose([
transforms.ToTensor(),
])
def test_and_save(model, test_dir, save_dir):
    """Predict a binary mask for every image in ``test_dir`` and save it.

    Each regular, non-hidden file in ``test_dir`` is opened as an RGB image,
    passed through ``model`` and thresholded at 0.5. The result is written to
    ``save_dir`` under the same file name as an 8-bit image with values
    0 / 255.

    Args:
        model: Network mapping a ``(1, 3, H, W)`` batch to a ``(1, 1, H', W')``
            prediction; it is switched to eval mode.
        test_dir: Directory with the input images.
        save_dir: Output directory; created if missing. It may be located
            inside ``test_dir``.
    """
    model.eval()
    os.makedirs(save_dir, exist_ok=True)
    to_tensor = transforms.ToTensor()

    for img_name in sorted(os.listdir(test_dir)):
        img_path = os.path.join(test_dir, img_name)
        # Skip sub-directories (such as an output folder nested in test_dir)
        # and hidden files like ".DS_Store"; neither can be opened as images.
        if img_name.startswith('.') or not os.path.isfile(img_path):
            continue

        with Image.open(img_path) as img_file:
            image = to_tensor(img_file.convert("RGB"))
        batch = image.unsqueeze(0)  # add the batch dimension

        with torch.no_grad():
            prediction = model(batch)

        # Drop batch and channel dimensions, then binarise to 0 / 255.
        prediction = prediction.squeeze(0).squeeze(0)
        binary = (prediction > 0.5).to(torch.uint8) * 255
        # A 2-D uint8 array is stored as an 8-bit greyscale ("L") image.
        Image.fromarray(binary.numpy()).save(os.path.join(save_dir, img_name))
def train(model, loader, optimizer, criterion, epochs):
    """Train ``model`` for ``epochs`` passes over ``loader``.

    After every epoch the mean loss over that epoch's batches is printed.
    Reporting only the final batch's loss is noisy and breaks on an empty
    loader.

    Args:
        model: Network to optimise; it is switched to train mode.
        loader: Iterable yielding ``(images, masks)`` batches.
        optimizer: Optimiser built over ``model``'s parameters.
        criterion: Loss callable invoked as ``criterion(outputs, masks)``.
        epochs: Number of passes over ``loader``.

    Returns:
        list[float]: Mean batch loss of each epoch, in order.

    Raises:
        ValueError: If ``loader`` yields no batches during an epoch.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        running_loss = 0.0
        batch_count = 0
        for images, masks in loader:
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            batch_count += 1
        if batch_count == 0:
            raise ValueError(
                f"loader yielded no batches in epoch {epoch + 1}; nothing to train on"
            )
        mean_loss = running_loss / batch_count
        history.append(mean_loss)
        print(f"Epoch {epoch+1}, Loss: {mean_loss}")
    return history
# Build the training dataset and its loader (batches of 4, shuffled).
# NOTE(review): the dataset and output locations below are hard-coded absolute
# paths; adjust them for the machine this runs on.
train_dataset = ImageDataset('/Users/xs/PycharmProjects/super-tomato/datasets_green/train/img', '/Users/xs/PycharmProjects/super-tomato/datasets_green/train/label', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
# Instantiate the UNet class defined above, with a BCE loss and an Adam optimizer.
model = UNet()
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Train the model.
train(model, train_loader, optimizer, criterion, epochs=50)
# Run inference on the test images and save the predicted masks.
# NOTE(review): the output directory `pre` is located inside the input
# directory, so the inference loop must not try to open it as an image.
test_and_save(model, '/Users/xs/PycharmProjects/super-tomato/tomato_img_25', '/Users/xs/PycharmProjects/super-tomato/tomato_img_25/pre')