# A More Dynamic U-Net in PyTorch
Today we build a u-net!
Basically, a classic u-net structure contains:
- An "encoder" extracting features (even though nobody really knows what's got encoded 😄) to decrease the resolution and increase the channels.
- A "decoder" doing the opposite job of the encoder.
- Some "skip connections" between the encoded "feature maps" produced by encoder blocks and the decoder blocks at the same stage/level. The encoded feature maps are concatenated with the decoded features as the input of the next stage.
- There might be some "attention" or "gated" blocks/modules in the last few stages ── they usually improve performance.
I don’t want to go through all the details; it’s good and performs well across different tasks and datasets, so we all use this shit.
The goal today is to produce a better shit ── configurable, flexible, maintainable yet effective!
First we need some basic building blocks ── no explanation needed:
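A minimal sketch of what such a building block might look like ── a convolution followed by normalization and activation. The name `ConvLayer` and the exact layer choices here are my assumptions, not the post's actual code:

```python
import torch
import torch.nn as nn


class ConvLayer(nn.Sequential):
    """Basic building block: Conv2d + BatchNorm + ReLU.

    A sketch; the real building blocks may differ in norm/activation choices.
    """

    def __init__(self, in_ch: int, out_ch: int, ks: int = 3, stride: int = 1):
        super().__init__(
            # "same" padding so spatial size only changes via stride
            nn.Conv2d(in_ch, out_ch, ks, stride, padding=ks // 2, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
```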
We create an upsample layer used in each decoder block. Here we use the pixel shuffle upsampling technique (also called sub-pixel convolution); experiments have shown it works better, especially for image generation.
There are 3 typical structures of an upsample layer:
- Subpixel convolution
- Transposed convolution
- Bilinear interpolation + convolution
The latter two are easy to implement. Of course, you can use plain interpolation; it’s basically a special case of the last one ── bilinear interpolation only, without the convolution, so there are no parameters to learn during upsampling.
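A sketch of a sub-pixel convolution upsample layer, including the ICNR initialization from the references (which makes the layer start out as nearest-neighbor upsampling and reduces checkerboard artifacts). The internals are my assumptions about how such a layer is typically built:

```python
import torch
import torch.nn as nn


def icnr_(w: torch.Tensor, scale: int = 2) -> None:
    """ICNR init: copy one filter across each group of scale**2 output channels,
    so PixelShuffle initially behaves like nearest-neighbor upsampling."""
    out_ch, in_ch, kh, kw = w.shape
    sub = torch.zeros(out_ch // scale**2, in_ch, kh, kw)
    nn.init.kaiming_normal_(sub)
    # repeat_interleave matches PixelShuffle's channel ordering:
    # channels [c*s^2, (c+1)*s^2) all feed output channel c
    w.data.copy_(sub.repeat_interleave(scale**2, dim=0))


class UpsampleLayer(nn.Sequential):
    """Sub-pixel convolution: conv expands channels by scale**2,
    PixelShuffle trades those channels for spatial resolution."""

    def __init__(self, in_ch: int, out_ch: int, scale: int = 2):
        conv = nn.Conv2d(in_ch, out_ch * scale**2, kernel_size=3, padding=1)
        icnr_(conv.weight, scale)
        super().__init__(conv, nn.PixelShuffle(scale), nn.ReLU(inplace=True))
```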
Now we can build an upsample block using UpsampleLayer; the skip features x_skip from the encoder are processed together with the decoding input x_in.
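The idea can be sketched like this: upsample x_in, concatenate with x_skip, then refine with a convolution. This is a self-contained approximation (sub-pixel conv written inline); the channel arithmetic is an assumption, not the post's actual rule:

```python
import torch
import torch.nn as nn


class UNetUpBlock(nn.Module):
    """One decoder stage: upsample the decoding input x_in, concat the encoder
    skip features x_skip, then refine. A sketch, not the post's exact code."""

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int, scale: int = 2):
        super().__init__()
        # sub-pixel upsampling: conv expands channels, PixelShuffle
        # rearranges them into a scale-x larger feature map
        self.up = nn.Sequential(
            nn.Conv2d(in_ch, out_ch * scale**2, 3, padding=1),
            nn.PixelShuffle(scale),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(out_ch + skip_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_in: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        x = self.up(x_in)                       # match x_skip's resolution
        x = torch.cat([x, x_skip], dim=1)       # merge skip features
        return self.conv(x)
```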
UNetDecoder contains a list of decoding stages constructed from UNetUpBlock. It takes a list of feature maps as input and produces a single output from the last decoding stage.
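A compact sketch of that loop: start from the deepest encoder feature and merge each shallower skip feature on the way up. For brevity this version uses bilinear upsampling and an `'eq_skip'`-style reduction (output channels equal the skip channels); the constructor signature is an assumption:

```python
import torch
import torch.nn as nn


class UNetDecoder(nn.Module):
    """Chains decoding stages over a list of encoder features.
    A sketch; the post's version builds each stage from UNetUpBlock."""

    def __init__(self, enc_chs):  # e.g. [64, 128, 256], shallow -> deep
        super().__init__()
        chs = enc_chs[::-1]  # process deep -> shallow
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        # one conv per stage; output channels equal the skip channels ('eq_skip')
        self.convs = nn.ModuleList(
            nn.Conv2d(in_ch + skip_ch, skip_ch, 3, padding=1)
            for in_ch, skip_ch in zip(chs[:-1], chs[1:])
        )

    def forward(self, feats):  # feats: encoder outputs, shallow -> deep
        feats = feats[::-1]
        x = feats[0]  # deepest feature map
        for conv, skip in zip(self.convs, feats[1:]):
            x = torch.relu(conv(torch.cat([self.up(x), skip], dim=1)))
        return x  # single output from the last decoding stage
```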
UNetMidBlock is used to transform the last encoding feature. You may add an additional attention/gating module to UNetMidBlock; it will probably improve your model's performance by capturing some global features.
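A sketch of such a bottleneck block. The optional gate here is a squeeze-and-excitation-style channel attention ── my choice of "attention/gating module" for illustration, not necessarily what the post uses:

```python
import torch
import torch.nn as nn


class UNetMidBlock(nn.Module):
    """Transforms the deepest encoder feature before decoding begins.
    A sketch; mid_ch_scale controls the hidden width of the bottleneck."""

    def __init__(self, ch: int, mid_ch_scale: float = 2.0, use_gate: bool = True):
        super().__init__()
        mid_ch = int(ch * mid_ch_scale)
        self.conv = nn.Sequential(
            nn.Conv2d(ch, mid_ch, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, ch, 3, padding=1), nn.ReLU(inplace=True),
        )
        # optional channel attention over globally pooled features
        self.gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(ch, ch // 4, 1), nn.ReLU(inplace=True),
            nn.Conv2d(ch // 4, ch, 1), nn.Sigmoid(),
        ) if use_gate else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        if self.gate is not None:
            x = x * self.gate(x)  # reweight channels by global context
        return x
```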
Now wrap them up!
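A minimal sketch of the wiring: encoder → mid block → decoder → output head. Here the encoder is any module returning a list of feature maps from shallow to deep ── which is exactly what `timm.create_model(name, features_only=True)` gives you. The constructor signature and the stand-in mid/decoder internals are my simplifications of the real class:

```python
import torch
import torch.nn as nn


class MoreDynamicUNet(nn.Module):
    """Wires encoder -> mid block -> decoder -> head. A sketch: the real class
    has a richer signature (unet_mid_ch_scale, unet_dec_ch_reduc, ...)."""

    def __init__(self, encoder: nn.Module, enc_chs, out_ch: int = 3):
        super().__init__()
        self.encoder = encoder  # returns features shallow -> deep
        deep = enc_chs[-1]
        self.mid = nn.Sequential(  # stand-in for UNetMidBlock
            nn.Conv2d(deep, deep, 3, padding=1), nn.ReLU(inplace=True),
        )
        chs = enc_chs[::-1]
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False)
        self.dec = nn.ModuleList(  # stand-in for UNetDecoder stages
            nn.Conv2d(i + s, s, 3, padding=1) for i, s in zip(chs[:-1], chs[1:])
        )
        self.head = nn.Conv2d(enc_chs[0], out_ch, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = self.encoder(x)[::-1]  # deep -> shallow
        x = self.mid(feats[0])
        for conv, skip in zip(self.dec, feats[1:]):
            x = torch.relu(conv(torch.cat([self.up(x), skip], dim=1)))
        return self.head(x)
```

With timm, you would construct the encoder along the lines of `enc = timm.create_model("resnet50", features_only=True)` and read the channel list from `enc.feature_info.channels()`.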
I call it MoreDynamicUNet, because it’s even more flexible than those fastai guys' :D
WHY?

- We utilize the `timm` library to provide tons of state-of-the-art vision backbones. ── BUT fastai `DynamicUnet` can only use backbones from `torchvision`.
- It only depends on the `torch` and `timm` libraries, and the code is readable with type annotations, so you can easily modify and customize those modules. For example, you can replace the `UpsampleLayer` with bilinear interpolation plus convolution if you are working on a segmentation task rather than generation ── that simple. ── BUT fastai `DynamicUnet` is full of shitty code with non-descriptive short variable names and no type annotations; it looks like it was written by some old-school C guy whose computer has so little memory that even variable names could blow it up. I tried to modify it but gave up eventually ── I had to dive into that shitty library and print all the motherfucking suspicious variables just to learn their meanings. It's a pain in the ass.
- You can control the middle channel scale and the channel reduction method used in the decoder. The benefit is that you can prevent some `CUDA out of memory` errors by reducing `unet_mid_ch_scale` or choosing `unet_dec_ch_reduc` between `'eq_skip'` and `'halve'` ── you might then be able to train your model; it's a tradeoff. I cannot train fastai `DynamicUnet` with a ResNet50 encoder on a generation task, with an input size of 512 x 512 x 3 and a batch size of 8 on a single RTX 4090 24G. With MY `MoreDynamicUNet` under the same settings, but with `unet_dec_ch_reduc` set to `'eq_skip'`, although it fills up the GPU memory (like 23.2G), we can now train a ResNet50-Unet!
The full code (a single Python file) is available at `mdunet.py`.
Enjoy your shitty u-net, and until next time!
References:
- U-Net: Convolutional Networks for Biomedical Image Segmentation
- Checkerboard effects of transposed convolution: https://distill.pub/2016/deconv-checkerboard
- Sub-pixel convolution: Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
- ICNR initialization of sub-pixel convolution: Checkerboard artifact free sub-pixel convolution: A note on sub-pixel convolution, resize convolution and convolution resize
- fastai DynamicUnet: https://docs.fast.ai/vision.models.unet.html
- fastai Layers: https://docs.fast.ai/layers.html