
# A More Dynamic U-Net in PyTorch
Today we build a u-net!
Basically, a classic u-net structure contains:
- An "encoder" extracting features (even though nobody really knows what's actually got encoded 😄) while decreasing the resolution and increasing the channels.
- A "decoder" doing the opposite job of the encoder: increasing the resolution and decreasing the channels.
- Some "skip connections" between the encoded "feature maps" produced by encoder blocks and those decoder blocks in the same stage/level. The encoded feature maps are concatenated with decoded features as the input of the next stage.
- There might be some "attention" or "gated" blocks/modules in the last few stages ── they usually improve performance.
I don’t want to go through the details ── it’s good, it performs well across different tasks and datasets, so we all use this shit.
The goal today is to produce a better shit ── configurable, flexible, maintainable yet effective!
First we need some basic building blocks ── no explanation needed:
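As a rough sketch, such a building block boils down to a conv → norm → activation stack (the name `ConvLayer` and its signature here are illustrative, not necessarily the exact ones in mdunet.py):

```python
import torch
import torch.nn as nn


class ConvLayer(nn.Sequential):
    """Illustrative conv -> batch norm -> ReLU building block."""

    def __init__(self, in_ch: int, out_ch: int, ks: int = 3) -> None:
        super().__init__(
            # bias is redundant right before batch norm, so disable it
            nn.Conv2d(in_ch, out_ch, ks, padding=ks // 2, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )
```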
We create an upsample layer used in each decoder block. Here we use the pixel shuffle upsampling technique (also called sub-pixel convolution); experiments have shown it works better, especially for image generation.
There are 3 typical structures of an upsample layer:
- Subpixel convolution
- Transposed convolution
- Bilinear interpolation + convolution
The latter two are easy to implement. Of course you can use plain interpolation; basically it’s a special case of the last one ── only bilinear interpolation and no convolution, so there are no parameters to learn during upsampling.
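A sub-pixel convolution layer can be sketched like this: a conv expands the channels by `scale²`, then `nn.PixelShuffle` rearranges them into spatial resolution; the ICNR init (see references) is included too. Names and signatures are illustrative:

```python
import torch
import torch.nn as nn


def icnr_(w: torch.Tensor, scale: int = 2) -> None:
    """ICNR: initialize so the sub-pixel conv starts out as nearest-neighbor
    upsampling, avoiding checkerboard artifacts at initialization."""
    out_ch, in_ch, kh, kw = w.shape
    sub = torch.empty(out_ch // scale**2, in_ch, kh, kw)
    nn.init.kaiming_normal_(sub)
    with torch.no_grad():
        # repeat each filter scale^2 times consecutively, matching PixelShuffle's layout
        w.copy_(sub.repeat_interleave(scale**2, dim=0))


class UpsampleLayer(nn.Sequential):
    """Sub-pixel convolution: conv expands channels by scale**2,
    PixelShuffle rearranges them into spatial resolution."""

    def __init__(self, in_ch: int, out_ch: int, scale: int = 2) -> None:
        conv = nn.Conv2d(in_ch, out_ch * scale**2, 3, padding=1)
        icnr_(conv.weight, scale)
        super().__init__(conv, nn.PixelShuffle(scale))
```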
Now we can build an upsample block using `UpsampleLayer`; the skip features `x_skip` from the encoder are processed together with the decoding input `x_in`.
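The shape of such a block, sketched with an inline sub-pixel upsample and a fusing conv (the signature is my guess at what `UNetUpBlock` roughly does, not the exact one):

```python
import torch
import torch.nn as nn


def up_layer(in_ch: int, out_ch: int, scale: int = 2) -> nn.Sequential:
    # sub-pixel convolution, standing in for UpsampleLayer
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch * scale**2, 3, padding=1),
        nn.PixelShuffle(scale),
    )


class UNetUpBlock(nn.Module):
    """Upsample x_in, concatenate x_skip, then fuse with a 3x3 conv."""

    def __init__(self, in_ch: int, skip_ch: int, out_ch: int) -> None:
        super().__init__()
        self.up = up_layer(in_ch, in_ch // 2)
        self.fuse = nn.Sequential(
            nn.Conv2d(in_ch // 2 + skip_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x_in: torch.Tensor, x_skip: torch.Tensor) -> torch.Tensor:
        # upsampled decoding features and encoder skip features are concatenated
        return self.fuse(torch.cat([self.up(x_in), x_skip], dim=1))
```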
`UNetDecoder` contains a list of decoding stages constructed from `UNetUpBlock`. It takes a list of feature maps as input and produces a single output from the last decoding stage.
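A simplified sketch of the decoder loop ── it walks the encoder features from deepest to shallowest, upsampling and fusing at each stage (here with a toy nearest-upsample + conv stage instead of the real up block):

```python
import torch
import torch.nn as nn


class UNetDecoder(nn.Module):
    """Chain of decoding stages consuming encoder features deep -> shallow.
    Each stage here is a toy stand-in for a full UNetUpBlock."""

    def __init__(self, enc_chs: list[int]) -> None:
        super().__init__()
        # enc_chs: channels per encoder stage, shallow -> deep, e.g. [64, 128, 256]
        chs = enc_chs[::-1]  # deep -> shallow
        self.stages = nn.ModuleList(
            nn.Conv2d(in_ch + skip_ch, skip_ch, 3, padding=1)
            for in_ch, skip_ch in zip(chs[:-1], chs[1:])
        )

    def forward(self, feats: list[torch.Tensor]) -> torch.Tensor:
        # feats: encoder outputs, shallow -> deep; start from the deepest one
        x = feats[-1]
        for stage, skip in zip(self.stages, feats[-2::-1]):
            x = nn.functional.interpolate(x, scale_factor=2, mode="nearest")
            x = stage(torch.cat([x, skip], dim=1))
        return x
```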
`UNetMidBlock` is used to transform the last encoding feature. You may add an additional attention/gating module to `UNetMidBlock`; it will probably improve your model's performance by capturing some global features.
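Roughly, the mid block widens the deepest feature, optionally runs it through an attention module, then brings it back ── a sketch with assumed names (`attn` defaults to a no-op placeholder):

```python
from typing import Optional

import torch
import torch.nn as nn


class UNetMidBlock(nn.Module):
    """Transforms the deepest encoder feature before decoding starts.
    `attn` is a slot for an optional attention/gating module."""

    def __init__(self, in_ch: int, mid_ch: int, attn: Optional[nn.Module] = None) -> None:
        super().__init__()
        self.conv_in = nn.Conv2d(in_ch, mid_ch, 3, padding=1)
        self.attn = attn if attn is not None else nn.Identity()
        self.conv_out = nn.Conv2d(mid_ch, in_ch, 3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.conv_out(self.attn(torch.relu(self.conv_in(x))))
```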
Now wrap them up!
I call it `MoreDynamicUNet`, because it’s even more flexible than those fastai guys' :D
WHY?
- We utilize the `timm` library to provide tons of state-of-the-art vision backbones ── BUT fastai `DynamicUnet` can only use backbones from `torchvision`.
- It only depends on the `torch` and `timm` libraries, and the code is readable with type annotations, so you can easily modify and customize those modules. For example, you can replace the `UpsampleLayer` with bilinear interpolation plus convolution if you are working on a segmentation task rather than generation ── that simple. ── BUT fastai `DynamicUnet` is full of shitty code with non-descriptive short variable names and no type annotations; it looks like it was written by some old-school C guy whose computer has so little memory that even some variable names could blow it up. I tried to modify it, but eventually gave up ── I had to dive into this shitty library and print all the motherfucking suspicious variables to figure out their meaning; it’s a pain in the ass.
- You can control the middle channel scale, as well as the channel reduction method used in the decoder. The benefit is that you can prevent some `CUDA out of memory` errors by reducing `unet_mid_ch_scale` or choosing `unet_dec_ch_reduc` between `'eq_skip'` and `'halve'` ── you might then be able to train your model; it’s a tradeoff. I cannot train fastai `DynamicUnet` using ResNet50 as the encoder on some generation task, with an input size of 512 x 512 x 3 and a batch size of 8, on a single RTX 4090 24G. Instead, I use MY `MoreDynamicUNet` with the same settings but with `unet_dec_ch_reduc` set to `'eq_skip'`; although it fills up the GPU memory (about 23.2 GB), we can now train a ResNet50-U-Net!
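As a rough illustration of what two such reduction modes could mean ── the semantics below are my assumption for the sketch, not necessarily what mdunet.py does ── here is a channel plan per decoder stage:

```python
def dec_out_channels(enc_chs: list[int], reduc: str = "halve") -> list[int]:
    """Hypothetical decoder channel plan, deep -> shallow.
    'halve': halve the running channel count each stage;
    'eq_skip': match the skip connection's channels (smaller activations, less memory).
    These semantics are assumed for illustration."""
    chs = enc_chs[::-1]  # deep -> shallow
    outs, cur = [], chs[0]
    for skip in chs[1:]:
        cur = skip if reduc == "eq_skip" else cur // 2
        outs.append(cur)
    return outs
```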
The full code (a single python file) is available at mdunet.py.
Enjoy your shitty u-net, and until next time!
References:
- U-Net: Convolutional Networks for Biomedical Image Segmentation
- Checkerboard effects of transposed convolution: https://distill.pub/2016/deconv-checkerboard
- Sub-pixel convolution: Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network
- ICNR initialization of sub-pixel convolution: Checkerboard artifact free sub-pixel convolution: A note on sub-pixel convolution, resize convolution and convolution resize
- fastai DynamicUnet: https://docs.fast.ai/vision.models.unet.html
- fastai Layers: https://docs.fast.ai/layers.html