Introduction

There are two major challenges in applying Transformer to the vision domain:

  • Visual entities vary greatly in scale, so a vision Transformer may not perform consistently well across different scenarios

  • Images have high resolution and many pixels, so Transformer's global self-attention is computationally heavy

To solve these two problems, Swin Transformer is proposed, combining a shifted window operation with a hierarchical design.

The shifted window operation includes non-overlapping local windows and overlapping cross-window connections. Restricting the attention computation to a single window both introduces the locality of CNN convolutions and saves computation.
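To make the saving concrete, here are the complexity formulas from the paper for global multi-head self-attention (MSA) versus window-based MSA (W-MSA), as a rough back-of-the-envelope sketch. The helper functions msa_flops and wmsa_flops are just for illustration (not from the repo); h=w=56, C=96, M=7 match the first stage of Swin-T.

# Complexity formulas from the Swin Transformer paper.
# h, w: feature map height/width; C: channels; M: window size.
def msa_flops(h, w, C):
    # Global self-attention: quadratic in the number of tokens h*w.
    return 4 * h * w * C**2 + 2 * (h * w)**2 * C

def wmsa_flops(h, w, C, M):
    # Window self-attention: linear in h*w for a fixed window size M.
    return 4 * h * w * C**2 + 2 * M**2 * h * w * C

print(msa_flops(56, 56, 96) / 1e9)      # -> about 2.0 (GFLOPs)
print(wmsa_flops(56, 56, 96, 7) / 1e9)  # -> about 0.15 (GFLOPs)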

Swin-T vs. ViT

Swin Transformer performs well on all kinds of vision tasks.

This article is long and walks through the official open-source code (github.com/microsoft/S…).

Overall architecture

Let’s take a look at the overall architecture of Swin Transformer

Swin Transformer overall architecture

The whole model adopts a hierarchical design comprising 4 stages in total. Each Stage reduces the resolution of the input feature map, expanding the receptive field layer by layer, just like a CNN.

  • At the input, a Patch Embedding is applied: the image is cut into patches and embedded into vectors.

  • Each Stage consists of a Patch Merging module and multiple Blocks.

  • The Patch Merging module reduces the feature map resolution at the beginning of each Stage.

  • The Block structure is shown in the figure on the right; it mainly consists of LayerNorm, MLP, Window Attention, and Shifted Window Attention.

class SwinTransformer(nn.Module):
    def __init__(...):
        super().__init__()
        ...
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))

        self.pos_drop = nn.Dropout(p=drop_rate)

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(...)
            self.layers.append(layer)

        self.norm = norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x)                     # B L C
        x = self.avgpool(x.transpose(1, 2))  # B C 1
        x = torch.flatten(x, 1)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)
        return x
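A minimal usage sketch, assuming the full official SwinTransformer class with its default arguments (e.g. img_size=224 and num_classes=1000) rather than the abbreviated version above:

model = SwinTransformer()
logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)   # torch.Size([1, 1000])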

There are a few areas that are handled differently from ViT:

  • ViT adds an absolute position encoding to the embeddings right after the input, while in Swin-T this is optional (self.ape); Swin-T instead applies a relative position encoding when computing Attention.

  • ViT uses a separate learnable parameter as the classification token, while Swin-T simply averages the features and outputs the classification, similar to the global average pooling layer at the end of a CNN.

Let’s take a look at the components

Patch Embedding

Before entering the Blocks, we need to cut the image into patches and embed them into vectors.

Specifically, the original image is cut into patch_size * patch_size patches, which are then embedded.

This is done with a 2D convolution layer whose stride and kernel_size are both set to patch_size; the number of output channels determines the size of the embedding vector. Finally, the H and W dimensions are flattened and swapped with the channel dimension, giving (N, H*W, C).

import torch
import torch.nn as nn
from timm.models.layers import to_2tuple


class PatchEmbed(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)      # -> (img_size, img_size)
        patch_size = to_2tuple(patch_size)  # -> (patch_size, patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        x = self.proj(x)              # (N, 96, 224/4, 224/4)
        x = torch.flatten(x, 2)       # (N, 96, 56*56)
        x = torch.transpose(x, 1, 2)  # (N, 56*56, 96)
        if self.norm is not None:
            x = self.norm(x)
        return x
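Assuming the PatchEmbed class above, a quick shape check with the default arguments (224x224 input, patch_size=4, embed_dim=96):

embed = PatchEmbed()
img = torch.randn(1, 3, 224, 224)
tokens = embed(img)
print(tokens.shape)   # torch.Size([1, 3136, 96])  -> (N, 56*56, 96)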

Patch Merging

This module performs downsampling before the start of each Stage. It reduces the resolution and adjusts the number of channels to form the hierarchical design, while also saving some computation.

In CNN, the convolution/pooling layer with stride=2 is used to reduce the resolution before the start of each Stage.

Each downsampling halves the resolution, so elements are selected at an interval of 2 along both the row and column directions.

These are then concatenated into a single tensor and flattened. At this point the channel dimension is 4 times the original (because H and W are each halved); a fully connected layer then adjusts the channel dimension to twice the original.

class PatchMerging(nn.Module):
    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x


Here is a schematic diagram (input tensor with N=1, H=W=8, C=1; the final fully connected layer adjustment is not shown)

Patch Merge

Personally, this feels like the inverse of PixelShuffle
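As a quick sanity check of the shapes in the schematic, a minimal sketch using the PatchMerging class above:

x = torch.randn(1, 8 * 8, 1)                          # (B, H*W, C)
merge = PatchMerging(input_resolution=(8, 8), dim=1)
out = merge(x)
print(out.shape)   # torch.Size([1, 16, 2])  -> (B, H/2*W/2, 2*C)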

Window Partition/Reverse

The window_partition function divides a tensor into windows of a given window size: it reshapes the original tensor from (B, H, W, C) into (num_windows*B, window_size, window_size, C), where num_windows = H*W / (window_size*window_size) is the number of windows. The window_reverse function is the corresponding inverse. Both functions will be used later in Window Attention.

def window_partition(x, window_size):
    B, H, W, C = x.shape
    # (B, H, W, C) -> (B, H//ws, ws, W//ws, ws, C)
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    # -> (num_windows*B, window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    # recover the batch size from the number of windows
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    # -> (B, H, W, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

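A small round-trip sketch: partitioning and then reversing recovers the original tensor (shapes chosen just for illustration):

x = torch.randn(2, 8, 8, 3)                    # (B, H, W, C)
windows = window_partition(x, window_size=4)   # (num_windows*B, 4, 4, 3) = (8, 4, 4, 3)
x_back = window_reverse(windows, 4, 8, 8)      # back to (2, 8, 8, 3)
assert torch.equal(x, x_back)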

Window Attention

This is the key to this article. Traditional Transformer computes attention on a global basis, so the computing complexity is high. Swin Transformer, on the other hand, reduces computation by limiting attention to each window.

Let’s look at the formula briefly
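For reference, the attention formula from the paper, with $B$ the relative position bias, is:

$$\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\!\left(\frac{QK^{T}}{\sqrt{d}} + B\right)V$$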

The main difference from the original Attention formula is the relative position bias B added to the QK^T term. The paper's experiments show that adding this relative position encoding improves the model's performance.

class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads      # nH
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias:
        # a learnable variable with shape (2*(Wh-1) * 2*(Ww-1), nH)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

Let me walk through the relative position encoding logic separately, since it is a bit convoluted.

First, the Attention tensor computed from Q and K has shape (num_windows*B, num_heads, window_size*window_size, window_size*window_size).

Within a window, taking different elements as the origin gives different relative coordinates for the other elements. Taking window_size=2 as an example, the relative position encoding is shown in the figure below

Example of relative position coding

First we use torch.arange and torch.meshgrid to generate the corresponding coordinates, using window_size=2 as an example

coords_h = torch.arange(self.window_size[0])
coords_w = torch.arange(self.window_size[1])
coords = torch.meshgrid([coords_h, coords_w]) # -> 2*(wh, ww)
"""
  (tensor([[0, 0],
           [1, 1]]), 
   tensor([[0, 1],
           [0, 1]]))
"""


Then we stack the two and flatten them into a 2D tensor

coords = torch.stack(coords)  # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
"""
tensor([[0, 0, 1, 1],
        [0, 1, 0, 1]])
"""


Using broadcasting, we insert a dimension into each copy at a different position (dim 2 for one, dim 1 for the other) and subtract, obtaining a tensor of shape (2, Wh*Ww, Wh*Ww)

relative_coords_first = coords_flatten[:, :, None]   # 2, Wh*Ww, 1
relative_coords_second = coords_flatten[:, None, :]  # 2, 1, Wh*Ww
relative_coords = relative_coords_first - relative_coords_second  # 2, Wh*Ww, Wh*Ww

Since we are subtracting, the resulting indices can be negative, so we add an offset to make them start at 0.

relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += self.window_size[0] - 1
relative_coords[:, :, 1] += self.window_size[1] - 1


We then need to flatten this into a one-dimensional offset. But consider two coordinates such as (1, 2) and (2, 1): they are different in 2D, yet simply summing the x and y coordinates would give them the same 1D offset.

Expand to a one-dimensional offset

So we multiply the row coordinate by (2 * window_size - 1) to distinguish them

relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1


offset multiply

Then we sum over the last dimension to get a one-dimensional index, and register it as a buffer that does not participate in learning

relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
self.register_buffer("relative_position_index", relative_position_index)

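Putting these steps together for window_size=(2, 2) gives the following index table; this is a small self-contained check of the logic above, not part of the original code:

import torch

window_size = (2, 2)
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w]))       # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                        # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
relative_coords[:, :, 0] += window_size[0] - 1
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
# Each entry indexes into the (2*2-1) * (2*2-1) = 9-row relative_position_bias_table.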

Now let’s look at the forward code

def forward(self, x, mask=None):
    """
    Args:
        x: input features with shape of (num_windows*B, N, C)
        mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
    """
    B_, N, C = x.shape

    qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
    q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
        self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww, Wh*Ww, nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    attn = attn + relative_position_bias.unsqueeze(0)  # broadcast over (1, nH, Wh*Ww, Wh*Ww)

    if mask is not None:
        ...  # will be analyzed later
    else:
        attn = self.softmax(attn)

    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x
  • First, the input is a tensor of shape (num_windows*B, window_size*window_size, C) (explained later).

  • It then passes through the fully connected layer self.qkv, is reshaped, and its axes are reordered to get a tensor of shape (3, num_windows*B, num_heads, window_size*window_size, C // num_heads), which is split into q, k, v.

  • Following the formula, q is multiplied by a scale factor and then by k (whose last two dimensions are swapped to satisfy matrix multiplication). This gives the attn tensor of shape (num_windows*B, num_heads, window_size*window_size, window_size*window_size).

  • Earlier we defined a learnable parameter of shape ((2*window_size - 1) * (2*window_size - 1), num_heads) for the position encoding. Using the precomputed relative position index self.relative_position_index, we gather a bias of shape (window_size*window_size, window_size*window_size, num_heads) and add it to the attn tensor.

  • Leaving aside the mask case, all that remains is softmax and dropout as in the standard Transformer, matrix multiplication with v, and finally a fully connected layer and dropout. A minimal usage sketch follows this list.
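A minimal usage sketch, assuming the full official WindowAttention class (i.e. with relative_position_index registered in __init__ as walked through above); dim=96, window_size=7 and num_heads=3 match the first stage of Swin-T:

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3)
x = torch.randn(64, 7 * 7, 96)   # (num_windows*B, window_size*window_size, C)
out = attn(x)                    # no mask -> plain Window Attention
print(out.shape)                 # torch.Size([64, 49, 96])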

Shifted Window Attention

To allow information to be exchanged between windows, Swin Transformer introduces the shifted window operation.

Shift Window

On the left is Window Attention with non-overlapping windows, and on the right is Shifted Window Attention, where the windows are shifted. You can see that a shifted window contains elements from originally adjacent windows. But this also introduces a new problem: the number of windows increases from four to nine.

In the actual code, this is implemented indirectly by cyclically shifting the feature map and applying a mask to the Attention computation, which keeps the original number of windows while producing an equivalent result.

Feature map shift operation

In the code, the shift of the feature map is implemented with torch.roll, as shown below

The shift operation

If you need to reverse cyclic shift, simply set the shifts parameter to the corresponding positive value.
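A tiny illustration of the cyclic shift and its reverse on a 4x4 feature map (a sketch; the values are arbitrary):

x = torch.arange(16).view(1, 4, 4, 1)                        # (B, H, W, C)
shifted = torch.roll(x, shifts=(-1, -1), dims=(1, 2))        # shift up and left by 1
restored = torch.roll(shifted, shifts=(1, 1), dims=(1, 2))   # reverse cyclic shift
assert torch.equal(x, restored)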

Attention Mask

I think this is the essence of Swin Transformer: by setting the right mask, Shifted Window Attention can be computed with the same number of windows as Window Attention while achieving an equivalent result.

Let's start with the index of each window after the shift and roll operations (window_size=2, shift_size=1)

Shift window index

What we want is: when computing Attention, only query-key pairs with the same index should contribute, and QK results between different indices should be ignored.

The correct result is shown in the figure below

Shift Attention

To get the correct results under the original four Windows, we must add a mask to the Attention result (shown on the far right of the image above)

The relevant code is as follows:

if self.shift_size > 0:
    # calculate attention mask for SW-MSA
    H, W = self.input_resolution
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    w_slices = (slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

For the example in the figure above, this code produces the following mask:

tensor([[[   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.],
         [   0.,    0.,    0.,    0.]],

        [[   0., -100.,    0., -100.],
         [-100.,    0., -100.,    0.],
         [   0., -100.,    0., -100.],
         [-100.,    0., -100.,    0.]],

        [[   0.,    0., -100., -100.],
         [   0.,    0., -100., -100.],
         [-100., -100.,    0.,    0.],
         [-100., -100.,    0.,    0.]],

        [[   0., -100., -100., -100.],
         [-100.,    0., -100., -100.],
         [-100., -100.,    0., -100.],
         [-100., -100., -100.,    0.]]])

In the Window Attention forward code above, the mask branch we skipped looks like this:

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)


The mask is added to the attention scores before softmax. Positions where the mask value is -100 receive a vanishingly small weight after softmax, so they are effectively ignored.
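A tiny numerical check of why -100 is enough (the exact printed value may vary slightly across PyTorch versions):

scores = torch.tensor([1.0, 1.0 - 100.0])
print(torch.softmax(scores, dim=-1))   # ~tensor([1., 0.]) -- the masked entry is on the order of e^-100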

Transformer Block overall architecture

The Transformer Block architecture

The structure of two consecutive Blocks is shown in the figure above. Note that a Stage must contain an even number of Blocks, because it alternates between a Block with Window Attention and a Block with Shifted Window Attention; a sketch of this alternation follows.
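In the official BasicLayer, the alternation is driven by the block index. The sketch below illustrates the pattern with a hypothetical helper build_blocks (not from the repo; other constructor arguments are omitted):

def build_blocks(dim, input_resolution, num_heads, window_size, depth):
    # Sketch of how a stage alternates W-MSA and SW-MSA blocks
    # (mirrors the official pattern; remaining constructor args left at their defaults).
    return nn.ModuleList([
        SwinTransformerBlock(
            dim=dim,
            input_resolution=input_resolution,
            num_heads=num_heads,
            window_size=window_size,
            # even blocks: W-MSA (shift_size=0); odd blocks: SW-MSA (shift by half a window)
            shift_size=0 if (i % 2 == 0) else window_size // 2,
        )
        for i in range(depth)
    ])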

Let’s look at the forward code for the Block

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x


The overall process is as follows

  • LayerNorm is applied to the feature map first

  • Use self.shift_size to determine whether the feature map needs to be shifted

  • The feature map is then cut into Windows

  • Calculate Attention, using self.attn_mask to distinguish between Window Attention and Shift Window Attention

  • Merge the Windows back

  • If the shift operation was performed before, reverse shift is performed to restore the previous shift operation

  • Apply drop_path and the residual connection

  • This is followed by LayerNorm + MLP, again with drop_path and a residual connection

Experimental results

The experimental results

With ImageNet-22K pre-training, the accuracy reaches a staggering 86.4%. Performance on detection, segmentation, and other tasks is also excellent; if you are interested, see the experiments section of the paper.

Conclusion

This paper is very innovative. It introduces the concept of windows and the locality of CNNs, which keeps the overall computation of the model under control. In the Shifted Window Attention part, the mask and shift operations implement the equivalent computation very cleverly. The authors' code is also very readable and well worth studying!


Welcome to follow GiantPandaCV, where you will find exclusive deep learning content; we insist on original work and share the new things we learn every day.

If you have questions about this article, or want to join the communication group, please add BBuf wechat:

Qr code
