U-Net

Posted by     "Hao Xu" on Saturday, January 4, 2025

U-Net复现

本篇内容是通过PyTorch来复现U-Net.

模型总览

Source U-Net

Source U-Net

如上图(蓝色方块上方显示的是通道数,左下角显示的是数据的高宽)所示,U-Net 的模型结构符合我们前面说的 编码器/解码器结构 (Encoder/Decoder structure)

左边的 contracting path 就是编码器,从图片提取出特征;右边的 expansive path 就是解码器。

编码器结构

左边的编码器和典型的卷积网络结构相似,它由两个 3×3 没有填充卷积操作和 2×2 步长为2max pooling 不断重复组成。并且每个卷积操作后面都有一个 ReLU 激活函数。

由于 3×3 卷积操作没有进行 padding,所以每次卷积操作之后数据的宽高都会减少 (k-1),k 是卷积核的大小。如图,最初是的输入数据的宽高为 572×572,经过一次 3×3 没有填充的卷积之后变成了 570×570 。

在每次 max pooling 的下采样中,数据的通道数会翻倍,但是宽高变为 $\frac {i - k} {s} +1$ 。$i$ 表示输入形状,$k$ 是卷积核大小,$s$ 是步长。将 $k$ 与 $s$ 带入,可以知道,每次下采样数据的高宽都会减半

解码器结构

右边的解码器与编码器相比有两点差异。

其一,编码器中max pooling的下采样改成了步长为22×2转置卷积来进行上采样。这里数据的通道数会减半,同时数据的宽高都会变为 $s(i-1)+k$ 。这里 $s$ 步长,$i$ 表示输入形状,$k$是卷积核大小。将 $k$ 与 $s$ 带入,可以得知,每次上采样数据的高宽都会翻倍

其二,在每次上采样之后有一个名为 skip connection 的操作,即图中的 copy and crop 。即将左侧对应的特征图与上采样的输出进行 concatenation 。

注意:这里由于 padding、stride与 kernel size 的选择,每次卷积操作,边界像素都会有损失。所以左侧的特征图高宽是大于右侧对应特征图的,所以这里论文中对左侧特征图先进行了 crop,然后再与右侧特征图进行连接。而最后输出结果的形状远小于输入数据形状的原因也是因为卷积操作中边界像素的损失。

[同时,也可以考虑对解码器的特征图做线性插值或者padding操作后再进行concatenation 。或者在每次卷积操作中加入为1 的padding,即可使卷积操作不损失边界且左右编码器解码器对应的特征图高宽一致(但是由于四次下采样每次数据高宽都减半,所以使用这种方法需要确保模型输入数据高宽是 $2^4$ 的倍数)]

输入与输出

U-Net论文中的数据是单通道的灰度图,所以输入数据的通道数为1(如果是RGB图像即为3)输入后经过第一个卷积操作直接转换成了64通道的特征图,与后面的通道数翻倍增加不同。

最后得到的输出会经过 1×1 的卷积操作将 64 通道的特征图映射成所需的类别数。

代码复现

如图所示,U-Net主要由连续的两个 conv 3×3 + ReLu,copy and crop,max pool 下采样,up-conv 转置卷积上采样和conv 1×1组成。

下面我们将分别实现连续的两个 conv 3×3 + ReLu,下采样和上采样。

首先,我们导入必要的库

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

Conv Block

这里实现连续的两个 conv 3×3 + ReLu。

class conv_block(nn.Module):
    def __init__(self, in_channels, out_channels, padding=0):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3,stride=1,padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3,stride=1,padding=padding),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x

DownSample

这里的下采样包括 max pool 下采样和连续的两个 conv 3×3 + ReLu 。

class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels, padding=0):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            conv_block(in_channels, out_channels, padding=padding)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

UpSample

这里的上采样包括转置卷积上采样,并与左侧对应编码器的特征图 concatenation 。之后进行连续的两个 conv 3×3 + ReLu 。

class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels, concat=0):
        super().__init__()
        """
        concat=0 -> do center crop
        concat=1 -> padding decoder feature map
        concat=2 -> padding=1 in conv_block
        """
        self.concat = concat
        if self.concat not in [0, 1, 2]:
            raise Exception('concat not in list of [0, 1, 2]')
        if self.concat == 2:
            padding = 1

        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = conv_block(in_channels, out_channels, padding=padding)

    def forward(self, x, x_copy):
        x = self.up(x)

        if self.concat == 0:
            B, C, H, W = x.shape
            x_copy = torchvision.transforms.CenterCrop([H, W])(x_copy)
            
        elif self.concat == 1:
            diffY = x_copy.size()[2] - x.size()[2]
            diffX = x_copy.size()[3] - x.size()[3]
            x = F.pad(x, [
                diffX // 2, diffX - diffX // 2, 
                diffY // 2, diffY - diffY // 2
                ])

        x = torch.cat([x_copy, x], dim=1)
        return self.conv(x)

U-Net模型

前面通过 PyTorch 构造了 U-Net 模型编码器与解码器的各个模块,只需要将其拼接在一起就可以组成 U-Net 模型了。

class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, concat=0):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.concat = concat

        if concat == 2:
            padding = 1
        else:
            padding = 0

        expansion = 2
        inplanes = 64
        chns = [inplanes, inplanes*expansion, inplanes*expansion**2, inplanes*expansion**3, inplanes*expansion**4]

        self.inc = conv_block(n_channels, chns[0], padding)
        self.down1 = DownSample(chns[0], chns[1], padding)
        self.down2 = DownSample(chns[1], chns[2], padding)
        self.down3 = DownSample(chns[2], chns[3], padding)
        self.down4 = DownSample(chns[3], chns[4], padding)

        self.up1 = UpSample(chns[-1], chns[-2], concat)
        self.up2 = UpSample(chns[-2], chns[-3], concat)
        self.up3 = UpSample(chns[-3], chns[-4], concat)
        self.up4 = UpSample(chns[-4], chns[-5], concat)
        self.outc = nn.Conv2d(chns[-5], n_classes, kernel_size=1)

    def forward(self, x):
        e1 = self.inc(x)
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        e4 = self.down3(e3)
        e5 = self.down4(e4)
        
        x = self.up1(e5, e4)
        x = self.up2(x, e3)
        x = self.up3(x, e2)
        x = self.up4(x, e1)
        logits = self.outc(x)

        return logits

以上就是U-Net模型PyTorch的实现。

Reference

  1. Ronneberger, Olaf, Philipp Fischer, and Thomas Brox. “U-net: Convolutional networks for biomedical image segmentation.” International Conference on Medical image computing and computer-assisted intervention. Springer, Cham, 2015.APA
  2. Milesial. “U-Net: Semantic segmentation with PyTorch” https://github.com/milesial/Pytorch-UNet

「如果这篇文章对你有用,请随意打赏」

Hao Xu Blog

如果这篇文章对你有用,请随意打赏

使用微信扫描二维码完成支付