This article is reposted from the WeChat official account [Machine Learning Alchemy]. To reprint it or contact the author, use WeChat: CYX645016617

Unet is actually quite simple, so today’s post won’t be very long.

0 Overview

Semantic segmentation is an important branch of image processing and computer vision. Unlike classification, semantic segmentation has to determine the category of every pixel in the image and delineate each region accurately. It is widely used in autonomous driving, automatic image matting, medical imaging and other fields.

The image above shows a segmentation result from an autonomous driving scene: cars (dark blue), pedestrians (red), traffic lights (yellow), roads (light purple) and so on are all identified in a single picture.

Unet can be said to be the most common and simplest segmentation model. It is simple, efficient, easy to understand, easy to build, and can be trained on small datasets.

Unet is a fairly old segmentation model, proposed in the 2015 paper "U-Net: Convolutional Networks for Biomedical Image Segmentation".

Paper link: arxiv.org/abs/1505.04…

Before Unet there was the older FCN (Fully Convolutional Networks), which is essentially a framework rather than a specific model: even today, who would claim that a segmentation network does not use convolutional layers? However, FCN is less accurate than Unet. There are now SegNet, Mask R-CNN, DeepLabv3+ and other networks, but today I will introduce only Unet. After all, you can't get fat from one bite.

1 Unet

1.1 Motivation (not important)

  1. Unet was originally intended to solve the problem of medical image segmentation;
  2. It uses a U-shaped network structure to capture both contextual information and location information;
  3. It took first place in several categories of the 2015 ISBI cell tracking challenge; it was originally designed for cell-level segmentation tasks.

1.2 Network Structure

The structure first applies convolution and pooling to the image; the Unet paper pools four times. For example, if the input image is 224×224, the feature maps become 112×112, 56×56, 28×28 and 14×14 in turn. The 14×14 feature map is then upsampled (or deconvolved) to obtain a 28×28 feature map, which is concatenated along the channel dimension with the earlier 28×28 feature map. The concatenated feature map is convolved and upsampled to obtain a 56×56 feature map, which is again concatenated with the earlier 56×56 feature map, convolved, and upsampled. After four upsampling steps, the prediction is 224×224, the same size as the input image. (These sizes assume padded convolutions, as in the code later in this post; the paper itself uses unpadded convolutions, so its output is smaller than its input.)
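The size progression described above can be traced with a few lines of plain Python. This is only a sketch: `unet_size_trace` is a hypothetical helper name, and it assumes padded convolutions so that only pooling and upsampling change the height and width.

```python
# Trace the spatial size through a U-shaped network.
# Assumption: padded ("same") convolutions, so only pooling and
# upsampling change the height and width.

def unet_size_trace(input_size, num_levels=4):
    """Return (encoder_sizes, decoder_sizes) for a square input."""
    encoder_sizes = [input_size]
    size = input_size
    for _ in range(num_levels):
        if size % 2 != 0:
            raise ValueError(f"size {size} is not divisible by 2")
        size //= 2                 # 2x2 max pooling halves each side
        encoder_sizes.append(size)

    decoder_sizes = []
    for _ in range(num_levels):
        size *= 2                  # stride-2 upsampling doubles each side
        decoder_sizes.append(size)
    return encoder_sizes, decoder_sizes


encoder_sizes, decoder_sizes = unet_size_trace(224)
print(encoder_sizes)   # [224, 112, 56, 28, 14]
print(decoder_sizes)   # [28, 56, 112, 224]
```

Each decoder size matches one of the encoder sizes, which is what makes the concatenation at every level possible. An input whose side is not divisible by 16 would break this symmetry, which is why the helper raises an error in that case.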

Overall, this is an encoder-decoder structure, and some literature calls it exactly that. The Unet network is very simple: the first half extracts features and the second half upsamples. Because the network is shaped like a large letter U, it is called U-net.

  • Encoder: the left half. Each downsampling module consists of two 3×3 convolution layers (ReLU) followed by a 2×2 max-pooling layer (as can be seen in the code below);
  • Decoder: the right half. Each module consists of an upsampling convolution layer (deconvolution layer), feature concatenation, and two 3×3 convolution layers (ReLU), repeated several times (as can be seen in the code below);

Compared with the earlier FCN network, Unet uses concatenation to fuse feature maps:

  • FCN fuses features by adding the corresponding pixel values of the feature maps element-wise;
  • U-net concatenates along the channel dimension to form thicker features, which consumes more GPU memory.
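The difference between the two fusion modes is easiest to see in the tensor shapes. The snippet below is a minimal sketch with random tensors; the shapes (64 channels, 28×28) are example values chosen for illustration, not taken from either paper.

```python
import torch

# Two feature maps with the same shape: (batch, channels, height, width).
decoder_feat = torch.rand(1, 64, 28, 28)   # upsampled decoder features (example shape)
encoder_feat = torch.rand(1, 64, 28, 28)   # skip features from the encoder (example shape)

# FCN-style fusion: element-wise addition keeps the channel count.
fused_add = decoder_feat + encoder_feat

# U-net-style fusion: concatenation along the channel axis doubles it.
fused_cat = torch.cat([decoder_feat, encoder_feat], dim=1)

print(fused_add.shape)  # torch.Size([1, 64, 28, 28])
print(fused_cat.shape)  # torch.Size([1, 128, 28, 28])
```

The concatenated tensor holds twice as many values, which is where the extra memory cost comes from; in exchange, the following convolution can learn how to weight the two sources instead of having them summed in advance.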

In my opinion, the advantages of Unet are as follows. The deeper the layer, the larger its receptive field: shallow convolutions focus on texture features, while deep layers focus on more abstract, essential features, so deep and shallow features each have their own value. The second point is that the edges of the larger feature maps obtained by upsampling (deconvolution) lack information. Every downsampling step refines the features but inevitably loses some edge detail, and what has been lost cannot be recovered by upsampling alone. Concatenating the encoder features with the upsampled features brings some of that edge information back.
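The claim that deeper layers see a larger part of the image can be checked with the standard receptive-field recurrence. The sketch below is plain Python; `receptive_field` is a hypothetical helper, and the layer list assumes the encoder described earlier (two 3×3 convolutions followed by a 2×2 max pooling, four times, then two more convolutions at the bottom of the U).

```python
def receptive_field(layers):
    """layers: list of (kernel_size, stride) pairs.
    Returns the receptive field, in input pixels, after each layer."""
    rf, jump = 1, 1      # jump = distance in input pixels between adjacent units
    history = []
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump
        jump *= stride
        history.append(rf)
    return history


CONV = (3, 1)   # 3x3 convolution, stride 1
POOL = (2, 2)   # 2x2 max pooling, stride 2

# Encoder path: four (conv, conv, pool) stages, then two convs at the bottom.
encoder = [CONV, CONV, POOL] * 4 + [CONV, CONV]
history = receptive_field(encoder)

print(history[1])    # after the first two convolutions: 5
print(history[-1])   # at the bottom of the U: 140
```

Under these assumptions a unit in the first stage sees a 5×5 patch, enough for local texture, while a unit at the bottom of the U sees a 140×140 region, which is why the two kinds of features complement each other.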

2 Why Unet performs well in medical image segmentation

This is an open question. If you have any ideas, please reply and join the discussion.

Most medical image semantic segmentation tasks start with Unet as the baseline. The advantages of Unet explained in the previous section are certainly part of the answer. Here we look at the characteristics of medical images themselves.

Based on online discussions, the main points are as follows:

  1. The semantics of medical images are relatively simple and the structure is fairly fixed. Compared with autonomous driving, there is less irrelevant information to filter out. All features in a medical image matter, so both low-level features and high-level semantic features are important, and the skip connections (feature concatenation) of the U-shaped structure are put to good use.

  2. Medical image data is scarce and hard to obtain; a dataset may contain only a few hundred images, or even fewer than 100. A large network such as DeepLabv3+ therefore overfits easily. The advantage of a large network is stronger representational capacity, but simple and small medical image datasets do not have that much content to represent. As a result, people have found that at small data scales the SOTA segmentation models have no overwhelming advantage over the lightweight Unet.

  3. Medical imaging is often multimodal. In the ISLES stroke lesion segmentation challenge, for example, the organizers provide data in several modalities, including CBF, MTT and CBV. Medical imaging tasks therefore often require designing your own network to extract features from the different modalities, and the lightweight, structurally simple Unet leaves more room to do so.
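As an illustration of the third point, one simple way to feed several modalities to a Unet-style network is early fusion: stack the modality maps along the channel axis and widen the first convolution. This is an assumption made for the sake of the example, not a method prescribed by the discussion above. The tensors are random stand-ins for real scans, and the names `cbf`, `mtt` and `cbv` only mirror the modalities mentioned in the text.

```python
import torch
import torch.nn as nn

# Hypothetical data: three co-registered modality maps for one scan,
# each a single-channel 224x224 image (random values stand in for real scans).
cbf = torch.rand(1, 1, 224, 224)
mtt = torch.rand(1, 1, 224, 224)
cbv = torch.rand(1, 1, 224, 224)

# Early fusion: stack the modalities along the channel axis.
multimodal_input = torch.cat([cbf, mtt, cbv], dim=1)

# The only change to the network is the input channel count of the first conv.
first_conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1)
features = first_conv(multimodal_input)

print(multimodal_input.shape)  # torch.Size([1, 3, 224, 224])
print(features.shape)          # torch.Size([1, 8, 224, 224])
```

Other designs, such as a separate encoder per modality, are also possible; with a small network like Unet such variations stay cheap to try.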

PyTorch model code

This is my own code, so it is not especially concise, but it should be easy to understand and it matches the explanation above exactly (please feel free to contact me if you have any questions: CYx645016617):

import torch
import torch.nn as nn
import torch.nn.functional as F

class double_conv2d_bn(nn.Module):
    """Two (3x3 conv -> batch norm -> ReLU) steps; padding=1 keeps the spatial size."""
    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super(double_conv2d_bn, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels,
                               kernel_size=kernel_size,
                               stride=strides, padding=padding, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        return out

class deconv2d_bn(nn.Module):
    """Transposed conv (kernel 2, stride 2) -> batch norm -> ReLU; doubles height and width."""
    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super(deconv2d_bn, self).__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels,
                                        kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        return out
    
class Unet(nn.Module) :
    def __init__(self) :
        super(Unet,self).__init__()
        # encoder: channels 1 -> 8 -> 16 -> 32 -> 64 -> 128
        self.layer1_conv = double_conv2d_bn(1,8)
        self.layer2_conv = double_conv2d_bn(8,16)
        self.layer3_conv = double_conv2d_bn(16,32)
        self.layer4_conv = double_conv2d_bn(32,64)
        self.layer5_conv = double_conv2d_bn(64,128)
        # decoder: each input is the concat of an upsampled map and a skip map
        self.layer6_conv = double_conv2d_bn(128,64)
        self.layer7_conv = double_conv2d_bn(64,32)
        self.layer8_conv = double_conv2d_bn(32,16)
        self.layer9_conv = double_conv2d_bn(16,8)
        self.layer10_conv = nn.Conv2d(8,1,kernel_size=3,
                                     stride=1,padding=1,bias=True)

        # upsampling layers: halve the channels, double the spatial size
        self.deconv1 = deconv2d_bn(128,64)
        self.deconv2 = deconv2d_bn(64,32)
        self.deconv3 = deconv2d_bn(32,16)
        self.deconv4 = deconv2d_bn(16,8)
        
        self.sigmoid = nn.Sigmoid()
        
    def forward(self,x) :
        conv1 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1,2)
        
        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2,2)
        
        conv3 = self.layer3_conv(pool2)
        pool3 = F.max_pool2d(conv3,2)
        
        conv4 = self.layer4_conv(pool3)
        pool4 = F.max_pool2d(conv4,2)
        
        conv5 = self.layer5_conv(pool4)
        
        convt1 = self.deconv1(conv5)
        concat1 = torch.cat([convt1,conv4],dim=1)
        conv6 = self.layer6_conv(concat1)
        
        convt2 = self.deconv2(conv6)
        concat2 = torch.cat([convt2,conv3],dim=1)
        conv7 = self.layer7_conv(concat2)
        
        convt3 = self.deconv3(conv7)
        concat3 = torch.cat([convt3,conv2],dim=1)
        conv8 = self.layer8_conv(concat3)
        
        convt4 = self.deconv4(conv8)
        concat4 = torch.cat([convt4,conv1],dim=1)
        conv9 = self.layer9_conv(concat4)
        outp = self.layer10_conv(conv9)
        outp = self.sigmoid(outp)
        return outp
    

model = Unet()
inp = torch.rand(10,1,224,224)   # batch of 10 single-channel 224x224 images
outp = model(inp)
print(outp.shape)
# ==> torch.Size([10, 1, 224, 224])

The upsampling block and the double-convolution block are defined separately first so that they can be reused when building the Unet model. The model's output has the same size as its input, which shows that the model runs.
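Because the model ends in a sigmoid, each output value can be read as a per-pixel foreground probability. The sketch below shows one possible way to use such an output. It is not part of the original code: random tensors stand in for the model output and the ground-truth mask, and the 0.5 threshold is an assumed, commonly used value.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-ins for the model output and a ground-truth mask (random data).
probs = torch.rand(10, 1, 224, 224)                     # sigmoid output, values in [0, 1)
target = (torch.rand(10, 1, 224, 224) > 0.5).float()    # binary mask of 0.0 / 1.0

# Binary cross-entropy compares per-pixel probabilities with the mask.
loss = F.binary_cross_entropy(probs, target)

# Thresholding at 0.5 (an assumed, commonly used value) gives a hard mask.
pred_mask = (probs > 0.5).float()

print(loss.item())
print(pred_mask.shape)
```

In a real training loop `probs` would be `model(inp)` and `loss.backward()` would follow; those steps are left out here to keep the sketch independent of the model definition.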
