Today we build a u-net!

Basically, a classic u-net structure contains:

  • An "encoder" to decrease the resolution and increase the channels ── even though nobody really knows what’s got encoded 😄.
  • A "decoder" doing the opposite job of the encoder.
  • Some "skip connections" between the encoded "feature maps" produced by the encoder blocks and the decoder blocks at the same stage/level. The encoded feature maps are concatenated with the decoded features as the input of the next stage.
  • Maybe some "attention" or "gated" blocks/modules in the last few stages ── they usually improve performance.

I don’t want to go through the details ── it’s good, it performs well across different tasks and datasets, so we all use this shit.

The goal today is to produce a better shit ── configurable, flexible, maintainable yet effective!

First we need some basic building blocks ── no explanations here:

Basic building blocks
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from typing import List, Tuple, Optional
from collections import OrderedDict

class ResBlock(nn.Module):
    def __init__(self, channels, has_bn=True):
        super(ResBlock, self).__init__()
        self.has_bn = has_bn

        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        if has_bn:
            self.bn1 = nn.BatchNorm2d(channels)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        if has_bn:
            self.bn2 = nn.BatchNorm2d(channels)
        self.relu2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self.conv1(x)
        if self.has_bn:
            z = self.bn1(z)
        z = self.relu1(z)

        z = self.conv2(z)
        if self.has_bn:
            z = self.bn2(z)
        y = self.relu2(x + z)  # residual add before the final ReLU
        return y


class ConvBNReLU(nn.Sequential):
    def __init__(self,
                 in_ch: int,
                 out_ch: int,
                 ksize: int,
                 stride: int = 1,
                 pad: int = 0,
                 pad_mode: str = 'zeros',
                 bias: bool = True,
                 inplace: bool = False):
        super(ConvBNReLU, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_ch, out_ch, kernel_size=ksize, stride=stride, padding=pad, padding_mode=pad_mode, bias=bias)),
            ('bn', nn.BatchNorm2d(out_ch)),
            ('relu', nn.ReLU(inplace=inplace))
        ]))


class ConvReLU(nn.Sequential):
    def __init__(self, in_ch: int, out_ch: int, ksize: int, stride: int = 1, pad: int = 0, pad_mode: str = 'zeros',
                 bias: bool = True, inplace: bool = False):
        super(ConvReLU, self).__init__(OrderedDict([
            ('conv', nn.Conv2d(in_ch, out_ch, kernel_size=ksize, stride=stride, padding=pad, padding_mode=pad_mode, bias=bias)),
            ('relu', nn.ReLU(inplace=inplace))
        ]))
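
A quick shape sanity check (the example values here are my own, not from the original code):

blk = ResBlock(channels=32, has_bn=True)
x = torch.rand(1, 32, 64, 64)
print(blk(x).shape)   # torch.Size([1, 32, 64, 64]) ── channels and resolution preserved

cbr = ConvBNReLU(32, 64, ksize=3, stride=1, pad=1)
print(cbr(x).shape)   # torch.Size([1, 64, 64, 64]) ── 3x3 conv with pad=1 keeps the resolution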

We create an upsample layer used in each decoder block. Here we use the pixel shuffle technique (also called subpixel convolution), which tends to work better in practice, especially for image generation.

There are 3 typical structures of an upsample layer:

  • Subpixel convolution
  • Transposed convolution
  • Bilinear interpolation + convolution

The latter two are easy to implement ── a sketch of the interpolation-plus-convolution variant is below. Of course you can also use plain interpolation; it’s basically a special case of the last one ── only bilinear interpolation and no convolution, so there are no parameters to learn during upsampling.
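
For reference, here is a minimal sketch of the interpolation-plus-convolution variant (my own naming, not part of the final model below):

# bilinear interpolation followed by a convolution ── the third upsampling option above
class InterpUpsampleLayer(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, scale: int = 2):
        super(InterpUpsampleLayer, self).__init__()
        self.scale = scale
        self.conv = ConvReLU(in_ch, out_ch, ksize=3, stride=1, pad=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # upsample first, then let the convolution clean up interpolation artifacts
        x = F.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
        return self.conv(x)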

UpsampleLayer
# stolen from fastai
def icnr_init(x, scale=2, init=nn.init.kaiming_normal_):
    "ICNR init of `x`, with `scale` and `init` function"
    ni,nf,h,w = x.shape
    ni2 = int(ni/(scale**2))
    k = init(x.new_zeros([ni2,nf,h,w])).transpose(0, 1)
    k = k.contiguous().view(ni2, nf, -1)
    k = k.repeat(1, 1, scale**2)
    return k.contiguous().view([nf,ni,h,w]).transpose(0, 1)


# pixel shuffle upsampling
class UpsampleLayer(nn.Sequential):
    def __init__(self, in_ch, out_ch, scale=2, blur=False):
        layers = [
            ConvReLU(in_ch, out_ch * scale ** 2, ksize=1, stride=1, pad=0, bias=False, inplace=True),
            nn.PixelShuffle(scale)
        ]
        if blur:
            layers += [
                nn.ReflectionPad2d((1, 0, 1, 0)),
                nn.AvgPool2d(kernel_size=2, stride=1)
            ]
        super(UpsampleLayer, self).__init__(*layers)
        self._icnr_init()

    def _icnr_init(self):
        self[0][0].weight.data.copy_(icnr_init(self[0][0].weight.data))
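
A quick check (my own example numbers) that the layer really scales the resolution by 2: the blur branch reflection-pads by one pixel and average-pools with stride 1, so it keeps the upsampled size while smoothing out checkerboard artifacts.

up = UpsampleLayer(in_ch=64, out_ch=32, scale=2, blur=True)
x = torch.rand(1, 64, 16, 16)
print(up(x).shape)   # torch.Size([1, 32, 32, 32])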

Now we can build an upsample block using UpsampleLayer. The skip features x_skip from the encoder are processed together with the decoding input x_in.

UNetUpBlock
class UNetUpBlock(nn.Module):
    def __init__(self,
                 in_ch: int,
                 skip_ch: int,
                 out_ch: int,
                 upscale_factor: int = 2):
        super(UNetUpBlock, self).__init__()
        self.upsample = UpsampleLayer(in_ch, out_ch, scale=upscale_factor, blur=True)
        self.double_conv = nn.Sequential(OrderedDict([
            ('conv1', ConvReLU(skip_ch + out_ch, skip_ch + out_ch, ksize=3, stride=1, pad=1)),
            ('conv2', ConvReLU(skip_ch + out_ch, out_ch, ksize=3, stride=1, pad=1)),
        ]))

        self.bn = nn.BatchNorm2d(skip_ch)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x_in: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        x_up = self.upsample(x_in)
        # align spatial sizes in case odd input resolutions caused a mismatch
        if x_skip.shape[-2:] != x_up.shape[-2:]:
            x_up = F.interpolate(x_up, size=x_skip.shape[-2:], mode='nearest')
        # normalize the skip features, then fuse them with the upsampled features
        z = self.relu(torch.cat([x_up, self.bn(x_skip)], dim=1))
        z = self.double_conv(z)
        return z
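
A quick shape check with made-up numbers: the upsampled input lands on the skip features’ resolution, and the output channel count equals out_ch.

block = UNetUpBlock(in_ch=128, skip_ch=64, out_ch=64)
x_in = torch.rand(1, 128, 8, 8)      # low-resolution decoding input
x_skip = torch.rand(1, 64, 16, 16)   # skip features from the encoder
print(block(x_in, x_skip).shape)     # torch.Size([1, 64, 16, 16])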

UNetDecoder contains a list of decoding stages built from UNetUpBlock. It takes a list of feature maps as input and produces a single output from the last decoding stage.

UNetDecoder
class UNetDecoder(nn.Module):
    def __init__(self,
                 enc_chs: List[int],
                 ch_reduc: str = 'eq_skip'):
        super(UNetDecoder, self).__init__()

        if ch_reduc not in ['eq_skip', 'halve']:
            raise RuntimeError(f"unknown channel reduction method: {ch_reduc}")

        dec_blocks = []
        in_ch = enc_chs[-1]
        for skip_ch in reversed(enc_chs[:-1]):
            if ch_reduc == 'eq_skip':
                # reduce the channels to skip_ch
                out_ch = skip_ch
            elif ch_reduc == 'halve':
                # halve the channels
                out_ch = in_ch // 2

            dec_blocks.append(UNetUpBlock(in_ch=in_ch,
                                          skip_ch=skip_ch,
                                          out_ch=out_ch))
            in_ch = out_ch

        self.dec_blocks = nn.ModuleList(dec_blocks)
        # at this point `in_ch` is the output channel count of the last decoder block

        self.up_layer = UpsampleLayer(in_ch, in_ch, scale=2, blur=True)
        self.out_conv = ConvReLU(in_ch, in_ch, ksize=3, stride=1, pad=1)

    def forward(self, feats: List[torch.Tensor]) -> torch.Tensor:
        feat_in = feats[-1]
        feat_out = None

        for i, feat_skip in enumerate(reversed(feats[:-1])):
            feat_out = self.dec_blocks[i](feat_in, feat_skip)
            feat_in = feat_out

        out = self.up_layer(feat_out)
        out = self.out_conv(out)
        return out
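
A usage sketch with dummy feature maps (channel and stride values are my own, roughly resnet-like): the decoder returns a single map at the input resolution, with the channel count of the first encoder stage when ch_reduc='eq_skip'.

enc_chs = [64, 128, 256, 512]
strides = [2, 4, 8, 16]
feats = [torch.rand(1, c, 224 // s, 224 // s) for c, s in zip(enc_chs, strides)]

dec = UNetDecoder(enc_chs=enc_chs, ch_reduc='eq_skip')
print(dec(feats).shape)   # torch.Size([1, 64, 224, 224])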

UNetMidBlock is used to transform the last encoder feature. You may add an attention/gating module to UNetMidBlock ── it will probably improve your model's performance by capturing some global context (a sketch of one such variant follows the code block below).

UNetMidBlock
class UNetMidBlock(nn.Sequential):
    def __init__(self,
                 last_enc_ch: int,
                 mid_ch_scale: int = 2):
        unet_mid_ch = last_enc_ch * mid_ch_scale
        super(UNetMidBlock, self).__init__(OrderedDict([
            ('conv1', ConvReLU(last_enc_ch, unet_mid_ch, ksize=3, stride=1, pad=1)),
            ('conv2', ConvReLU(unet_mid_ch, last_enc_ch, ksize=3, stride=1, pad=1)),
        ]))
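
If you do want some attention/gating in the middle, here is a hedged sketch (my own variant, not part of the original code) that inserts a simple squeeze-and-excitation style gate between the two convolutions:

class SEGate(nn.Module):
    """Channel-wise gating computed from globally pooled features."""
    def __init__(self, ch: int, reduction: int = 16):
        super(SEGate, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(ch, ch // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(ch // reduction, ch),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # global average pool -> per-channel gate in [0, 1]
        w = self.fc(x.mean(dim=(2, 3)))
        return x * w[:, :, None, None]


class GatedUNetMidBlock(nn.Sequential):
    def __init__(self, last_enc_ch: int, mid_ch_scale: int = 2):
        unet_mid_ch = last_enc_ch * mid_ch_scale
        super(GatedUNetMidBlock, self).__init__(OrderedDict([
            ('conv1', ConvReLU(last_enc_ch, unet_mid_ch, ksize=3, stride=1, pad=1)),
            ('gate', SEGate(unet_mid_ch)),
            ('conv2', ConvReLU(unet_mid_ch, last_enc_ch, ksize=3, stride=1, pad=1)),
        ]))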

Now wrap them up!

MoreDynamicUNet
class MoreDynamicUNet(nn.Module):
    def __init__(self,
                 in_ch: int = 3,
                 out_ch: int = 3,
                 backbone: str = 'resnet34',
                 pretrained: bool = False,
                 unet_mid_ch_scale: int = 2,
                 unet_dec_ch_reduc: str = 'eq_skip'):
        super(MoreDynamicUNet, self).__init__()

        self.encoder = timm.create_model(backbone, pretrained=pretrained, features_only=True)

        # dummy forward pass to discover the encoder's feature channels ── this is what makes the model "dynamic"
        dummy_input = torch.rand(1, in_ch, 224, 224)
        dummy_feats = self.encoder(dummy_input)
        enc_chs = [f.shape[1] for f in dummy_feats]
        last_enc_ch = enc_chs[-1]

        self.mid_block = UNetMidBlock(last_enc_ch=last_enc_ch, mid_ch_scale=unet_mid_ch_scale)

        dummy_feats[-1] = self.mid_block(dummy_feats[-1])
        enc_chs[-1] = dummy_feats[-1].shape[1]

        self.decoder = UNetDecoder(enc_chs=enc_chs, ch_reduc=unet_dec_ch_reduc)

        dec_out = self.decoder(dummy_feats)
        dec_out_ch = dec_out.shape[1]

        self.out_block = ResBlock(channels=dec_out_ch + in_ch, has_bn=False)
        self.out_proj = nn.Conv2d(dec_out_ch + in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(x)
        feats[-1] = self.mid_block(feats[-1])
        z = self.decoder(feats)
        y = self.out_proj(self.out_block(torch.cat([z, x], dim=1)))
        return y
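
A usage sketch, assuming timm's resnet34 with features_only=True returns five feature maps (strides 2 to 32); the output resolution matches the input:

model = MoreDynamicUNet(in_ch=3, out_ch=3, backbone='resnet34', pretrained=False)
x = torch.rand(1, 3, 256, 256)
print(model(x).shape)   # torch.Size([1, 3, 256, 256])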

I call it MoreDynamicUNet, because it’s even more flexible than those fastai guys' :D

WHY?

  1. We utilize the timm library to provide tons of state-of-the-art vision backbones.

    ── BUT fastai DynamicUnet can only use backbones in torchvision.

  2. It only depends on torch and timm, and the code is readable with type annotations, so you can easily modify and customize these modules. For example, you can replace the UpsampleLayer with bilinear interpolation plus convolution if you are working on a segmentation task rather than generation ── that simple.

    ── BUT fastai DynamicUnet is full of shitty code with non-descriptive short variable names and no type annotations. It looks like it was written by some old-school C guy whose computer has so little memory that even a few longer variable names could blow it up.

    I tried to modify it, but eventually gave up ── I had to dive into this shitty library and print every motherfucking suspicious variable just to figure out what it meant. It’s a pain in the ass.

  3. You can control the middle channel scale, as well as the channel reduction method used in the decoder. The benefit is that you can avoid some CUDA out-of-memory errors by reducing unet_mid_ch_scale or by choosing between 'eq_skip' and 'halve' for unet_dec_ch_reduc ── you trade some capacity for being able to train your model at all (see the configuration sketch after this list).

    I could not train fastai DynamicUnet with a ResNet50 encoder on a generation task, with an input size of 512 x 512 x 3 and a batch size of 8 on a single RTX 4090 (24 GB). With MY MoreDynamicUNet under the same settings, but with unet_dec_ch_reduc set to 'eq_skip', it nearly fills up the GPU memory (around 23.2 GB) ── but we can now train a ResNet50 u-net!
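
A hedged configuration sketch (example values are mine; actual memory usage depends on your data, resolution and batch size):

# memory-friendly setup: keep the mid block narrow and reduce decoder channels
# down to each skip width instead of just halving them
model = MoreDynamicUNet(in_ch=3, out_ch=3,
                        backbone='resnet50',
                        pretrained=True,
                        unet_mid_ch_scale=1,
                        unet_dec_ch_reduc='eq_skip')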

The full code (a single Python file) is available at mdunet.py.

Enjoy your shitty u-net, and until next time!
