S2-MLP V1 & V2: Spatial-Shift MLP Architecture for Vision

The original document: www.yuque.com/lart/papers…

Two papers on S2-MLP are summarized here. The core idea of both is the same: replace the spatial MLP with a spatial-shift operation.

Understanding the papers through their abstracts

V1

Recently, visual Transformer (ViT) and its following works abandon the convolution and exploit the self-attention operation, attaining a comparable or even higher accuracy than CNNs. More recently, MLP-Mixer abandons both the convolution and the self-attention operation, proposing an architecture containing only MLP layers. To achieve cross-patch communications, it devises an additional token-mixing MLP besides the channel-mixing MLP. It achieves promising results when training on an extremely large-scale dataset. But it cannot achieve as outstanding performance as its CNN and ViT counterparts when training on medium-scale datasets such as ImageNet1K and ImageNet21K. The performance drop of MLP-Mixer motivates us to rethink the token-mixing MLP.

This leads to the main focus of the paper: improving the spatial (token-mixing) MLP.

We discover that the token-mixing MLP is a variant of the depthwise convolution with a global reception field and spatial-specific configuration. But the global reception field and the spatial-specific property make token-mixing MLP prone to over-fitting.

This points out the problem with the spatial MLP: because of its global receptive field and spatial-specific weights, the model is prone to overfitting.

In this paper, we propose a novel pure MLP architecture, spatial-shift MLP (S2-MLP). Different from MLP-Mixer, our S2-MLP only contains channel-mixing MLP.

Only the channel MLP is mentioned here, which indicates that a new way has been found to enlarge the receptive field of the channel MLP while keeping purely point-wise operations.

We utilize a spatial-shift operation for communications between patches. It has a local reception field and is spatial-agnostic. It is parameter-free and efficient for computation.

This leads to the core of the paper, the spatial-shift operation named in the title. The operation is parameter-free and is simply a procedure for rearranging features. For more on spatial-shift operations, see: www.yuque.com/lart/archit…

The proposed S2-MLP attains higher recognition accuracy than MLP-Mixer when training on ImageNet-1K dataset. Meanwhile, S2-MLP accomplishes as excellent performance as ViT on ImageNet-1K dataset with considerably simpler architecture and fewer FLOPs and parameters.

V2

Recently, MLP-based vision backbones emerge. MLP-based vision architectures with less inductive bias achieve competitive performance in image recognition compared with CNNs and vision Transformers. Among them, spatial-shift MLP (S2-MLP), adopting the straightforward spatial-shift operation, achieves better performance than the pioneering works including MLP-mixer and ResMLP. More recently, using smaller patches with a pyramid structure, Vision Permutator (ViP) and Global Filter Network (GFNet) achieve better performance than S2-MLP.

This leads to the pyramid structure, and it looks like the V2 version will use a similar structure.

In this paper, we improve the S2-MLP vision backbone. We expand the feature map along the channel dimension and split the expanded feature map into several parts. We conduct different spatial-shift operations on split parts.

The spatial-shift strategy is still used, but how does it differ from the one in V1?

Meanwhile, we exploit the split-attention operation to fuse these split parts.

Split-attention (from ResNeSt) is also introduced to fuse the split groups. Does this mean parallel branches are used here?

Moreover, like the counterparts, we adopt smaller-scale patches and use a pyramid structure for boosting the image recognition accuracy. We term the improved spatial-shift MLP vision backbone as S2-MLPv2. Using 55M parameters, our medium-scale model, S2-MLPv2-Medium, achieves 83.6% top-1 accuracy on the ImageNet-1K benchmark using 224×224 images, without self-attention and without external training data.

In my opinion, compared with V1, V2 mainly borrows ideas from CycleFC and adapts them. The overall changes cover two aspects:

  1. A multi-branch processing scheme is introduced, and split-attention is applied to merge the different branches.
  2. Inspired by existing work, smaller patches and a hierarchical pyramid structure are used.

The main content

Core structure comparison

In V1, the overall design continues the MLP-Mixer idea and keeps a straight, non-pyramid columnar structure.

Structure diagram of MLP-Mixer:

It can be seen from the figure that S2-MLP uses a post-norm structure, which differs from the pre-norm structure of MLP-Mixer. In addition, the changes in S2-MLP mainly concern the spatial MLP: the spatial MLP (Linear -> GELU -> Linear) is replaced with a shift-augmented channel MLP (Linear -> GELU -> spatial-shift -> Linear). The core pseudocode of the spatial shift is as follows:
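
Below is a minimal PyTorch sketch of it, assuming a (B, H, W, C) feature layout with four channel groups; this is not the authors' exact code.

import torch

def spatial_shift(x):
    # Split the channels into four groups and shift each group one pixel in one direction.
    B, H, W, C = x.shape
    g = C // 4
    out = x.clone()
    out[:, :, 1:, 0 * g:1 * g] = x[:, :, :W - 1, 0 * g:1 * g]  # group 1: shift right along W
    out[:, :, :W - 1, 1 * g:2 * g] = x[:, :, 1:, 1 * g:2 * g]  # group 2: shift left along W
    out[:, 1:, :, 2 * g:3 * g] = x[:, :H - 1, :, 2 * g:3 * g]  # group 3: shift down along H
    out[:, :H - 1, :, 3 * g:] = x[:, 1:, :, 3 * g:]            # group 4: shift up along H
    return out  # border rows/columns keep their original values, so values are duplicated at the boundary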

As you can see, the input channels are divided into four groups, and each group is shifted along a different axis and direction (H or W). For implementation reasons, there are duplicate values at the boundary. The number of groups depends on the number of shift directions; by default four directions are used. Although a single spatial-shift module only relates adjacent patches, an approximate long-range interaction emerges once many such blocks are stacked.

Compared with V1, V2 introduces a multi-branch processing strategy and switches to the pre-norm form.

The multi-branch construction is very similar to CycleFC: different branches apply different processing strategies, and split-attention is used to fuse them.

Split-Attention: Vision Permutator (Hou et al., 2021) adopts the split attention proposed in ResNeSt (Zhang et al., 2020) to enhance multiple feature maps coming from different operations. S2-MLPv2 follows this reference to fuse its multiple branches. The main steps are as follows (a code sketch follows the list):

  1. The input is $K$ feature maps (one per branch) $\mathbf{X} = \{X_k \in \mathbb{R}^{N \times C}\}_{k=1}^{K}$, with $N = HW$.
  2. Sum all feature maps over all branches and all rows (tokens): $a \in \mathbb{R}^{C} = \sum_{k=1}^{K} \sum_{n=1}^{N} X_k[n, :]$.
  3. Pass the result through stacked fully connected layers to obtain channel-attention logits for the different feature maps: $\hat{a} \in \mathbb{R}^{KC} = \sigma(a W_1) W_2$, where $W_1 \in \mathbb{R}^{C \times \bar{C}}$ and $W_2 \in \mathbb{R}^{\bar{C} \times KC}$.
  4. Reshape the attention vector: $\hat{a} \in \mathbb{R}^{KC} \rightarrow \hat{a} \in \mathbb{R}^{K \times C}$.
  5. Apply softmax along the index $k$ to obtain normalized attention weights for each channel: $\bar{a}[:, c] \in \mathbb{R}^{K} = \mathrm{softmax}(\hat{a}[:, c])$.
  6. The weighted sum of the $K$ input feature maps gives the output $Y$; one row can be written as $Y[n, :] \in \mathbb{R}^{C} = \sum_{k=1}^{K} X_k[n, :] \odot \bar{a}[k, :]$.
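
A minimal PyTorch sketch of these steps follows; the class and parameter names (SplitAttention, reduction) are my own, GELU is assumed for the activation σ, and the inputs are K feature maps of shape (B, N, C).

import torch
import torch.nn as nn

class SplitAttention(nn.Module):
    # Fuse K feature maps of shape (B, N, C) with channel-wise attention over the K branches.
    def __init__(self, channels, k=3, reduction=4):
        super().__init__()
        self.k = k
        hidden = channels // reduction
        self.fc1 = nn.Linear(channels, hidden, bias=False)      # W1
        self.fc2 = nn.Linear(hidden, k * channels, bias=False)  # W2
        self.act = nn.GELU()                                    # sigma (assumed)

    def forward(self, xs):
        x = torch.stack(xs, dim=1)                     # step 1: (B, K, N, C)
        a = x.sum(dim=1).sum(dim=1)                    # step 2: sum over branches and tokens -> (B, C)
        a_hat = self.fc2(self.act(self.fc1(a)))        # step 3: (B, K*C)
        a_hat = a_hat.reshape(-1, self.k, x.size(-1))  # step 4: (B, K, C)
        a_bar = a_hat.softmax(dim=1)                   # step 5: softmax over the K branches
        return (x * a_bar.unsqueeze(2)).sum(dim=1)     # step 6: weighted sum -> (B, N, C)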

However, note that the third branch is an identity branch that directly takes a portion of the input channels. This continues GhostNet’s idea, unlike CycleFC, whose third branch is a separate channel MLP.

The core structure of GhostNet:

The core pseudocode for the multi-branch structure is as follows:
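
Below is a rough PyTorch sketch consistent with the description above; it reuses the spatial_shift and SplitAttention sketches from earlier. The names S2MLPv2Block, spatial_shift_swapped, and expansion are my own, and normalization and residual connections are omitted.

import torch.nn as nn

def spatial_shift_swapped(x):
    # Variant of the spatial_shift sketch above with the H and W group assignments swapped,
    # used for the second branch.
    B, H, W, C = x.shape
    g = C // 4
    out = x.clone()
    out[:, 1:, :, 0 * g:1 * g] = x[:, :H - 1, :, 0 * g:1 * g]  # shift down along H
    out[:, :H - 1, :, 1 * g:2 * g] = x[:, 1:, :, 1 * g:2 * g]  # shift up along H
    out[:, :, 1:, 2 * g:3 * g] = x[:, :, :W - 1, 2 * g:3 * g]  # shift right along W
    out[:, :, :W - 1, 3 * g:] = x[:, :, 1:, 3 * g:]            # shift left along W
    return out

class S2MLPv2Block(nn.Module):
    # Expand channels, split into three branches (two shifted, one identity),
    # fuse with SplitAttention, then project back to the original width.
    def __init__(self, channels, expansion=3):
        super().__init__()
        self.mlp1 = nn.Sequential(nn.Linear(channels, channels * expansion), nn.GELU())
        self.split_attention = SplitAttention(channels, k=3)
        self.mlp2 = nn.Linear(channels, channels)

    def forward(self, x):
        # x: (B, H, W, C)
        B, H, W, C = x.shape
        x = self.mlp1(x)                             # (B, H, W, 3C)
        x1 = spatial_shift(x[..., :C])               # branch 1: shift (sketch from the V1 section)
        x2 = spatial_shift_swapped(x[..., C:2 * C])  # branch 2: shift with swapped group order
        x3 = x[..., 2 * C:]                          # branch 3: identity (GhostNet-style cheap branch)
        branches = [t.reshape(B, H * W, C) for t in (x1, x2, x3)]
        out = self.split_attention(branches)         # (B, HW, C)
        return self.mlp2(out).reshape(B, H, W, C)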

Other details

Relation between spatial-shift and Depthwise Convolution

In fact, shifts in the four directions can be realized by constructing specific convolution kernels:

Therefore, the grouped spatial-shift operation can be realized by assigning the corresponding kernels above to different groups of a depthwise convolution.

In fact, there are many ways to implement the shift. Besides slice indexing and depthwise convolution, it can also be implemented with a grouped torch.roll or with deform_conv2d using custom offsets.

import torch
import torch.nn.functional as F
from torchvision.ops import deform_conv2d

# Toy input: an H-index map and a W-index map, repeated four times -> 8 channels of size 5x5.
xs = torch.meshgrid(torch.arange(5), torch.arange(5))
x = torch.stack(xs, dim=0)
x = x.unsqueeze(0).repeat(1, 4, 1, 1).float()

# 1) Direct slice indexing: each pair of channels is shifted one pixel in one direction.
direct_shift = torch.clone(x)
direct_shift[:, 0:2, :, 1:] = torch.clone(direct_shift[:, 0:2, :, :4])  # shift right along W
direct_shift[:, 2:4, :, :4] = torch.clone(direct_shift[:, 2:4, :, 1:])  # shift left along W
direct_shift[:, 4:6, 1:, :] = torch.clone(direct_shift[:, 4:6, :4, :])  # shift down along H
direct_shift[:, 6:8, :4, :] = torch.clone(direct_shift[:, 6:8, 1:, :])  # shift up along H
print(direct_shift)

pad_x = F.pad(x, pad=[1, 1, 1, 1], mode="replicate")  # padding is needed to preserve the boundary data

# 2) Grouped torch.roll on the padded tensor, then crop back to 5x5.
roll_shift = torch.cat(
    [
        torch.roll(pad_x[:, c * 2 : (c + 1) * 2], shifts=(shift_h, shift_w), dims=(2, 3))
        for c, (shift_h, shift_w) in enumerate([(0, 1), (0, -1), (1, 0), (-1, 0)])
    ],
    dim=1,
)
roll_shift = roll_shift[..., 1:6, 1:6]
print(roll_shift)

# 3) Depthwise convolution with hand-crafted kernels, one kernel per direction.
k1 = torch.FloatTensor([[0, 0, 0], [1, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k2 = torch.FloatTensor([[0, 0, 0], [0, 0, 1], [0, 0, 0]]).reshape(1, 1, 3, 3)
k3 = torch.FloatTensor([[0, 1, 0], [0, 0, 0], [0, 0, 0]]).reshape(1, 1, 3, 3)
k4 = torch.FloatTensor([[0, 0, 0], [0, 0, 0], [0, 1, 0]]).reshape(1, 1, 3, 3)
weight = torch.cat([k1, k1, k2, k2, k3, k3, k4, k4], dim=0)  # each output channel corresponds to one input channel
conv_shift = F.conv2d(pad_x, weight=weight, groups=8)
print(conv_shift)

# 4) deform_conv2d with a 1x1 identity weight and per-channel constant offsets, then crop back to 5x5.
offset = torch.empty(1, 2 * 8 * 1 * 1, 1, 1)
for c, (rel_offset_h, rel_offset_w) in enumerate([(0, -1), (0, -1), (0, 1), (0, 1), (-1, 0), (-1, 0), (1, 0), (1, 0)]):
    offset[0, c * 2 + 0, 0, 0] = rel_offset_h
    offset[0, c * 2 + 1, 0, 0] = rel_offset_w
offset = offset.repeat(1, 1, 7, 7).float()
weight = torch.eye(8).reshape(8, 8, 1, 1).float()
deconv_shift = deform_conv2d(pad_x, offset=offset, weight=weight)
deconv_shift = deconv_shift[..., 1:6, 1:6]
print(deconv_shift)

""" tensor([[[[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.], [2., 2., 2., 2., 2.], [3., 3., 3., 3., 3.], [4., 4., 4., 4., 4.]], [[0, 0), 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3.]]. [[... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[1, (2), (3), 4, 4], [1, 2, 3., 4, 4], [1, 2, 3., 4, 4], [1, 2, 3., 4, 4], [1, 2, 3., 4, 4]], [[... 0, 0, 0, . 0, 0.], [... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)]], [[0. 1, 2, 3., 4], [0, 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4.]], [[1, 1, 1, 1, 1]. [(2), (2), (2), (2), (2)], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]], [[0. 1., 2., 3., 4], [0., 1., 2., 3., 4.], [0., 1., 2., 3., 4.], [0., 1., 2., 3., 4.], [0., 1., 2., 3., 4.]]]]) tensor([[[[0., 0., 0., 0., 0.], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[0., 0. 1, 2, 3.], [0., 0. 1., 2., 3.], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3]], [[... 0, 0, 0, 0), and 0.], [1, 1), (1), 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[1, 2, 3., 4, 4], [1, 2, 3., 4), 4.], [1, 2, 3, 4, 4], [1, 2, 3, 4, 4], [1, 2, 3, 4, 4]], [[... 0, 0, 0, 0), and 0.], [... 0, 0, 0, 0), and 0.]. [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)]], [[0. 1, 2, 3, 4], [0. 1, 2, 3., 4], [0., (1), (2), (3), 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4.]], [[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]], [[0. 1, 2, 3., 4], [0. 1, 2, 3., 4], [0. 1, 2, 3., 4.], [0., 1., 2., 3., 4.], [0., 1., 2., 3., 4.]]]]) tensor([[[[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.], [2., 2., 2., (2), (2)], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, . 3], [0, 0, 1, 2, 3], [0., 0. 1, 2, 3]], [[... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [(2), (2), (2), (2), (2)]. (3), (3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[1, 2, 3., 4, 4], [1, 2, 3., 4, 4], [1, 2, 3., 4, 4], [1, (2), (3), 4, 4], [1, 2, 3., 4, 4]], [[... 0, 0, 0, 0), and 0.], [... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [(2), (2), (2), (2), (2)], [3, 3), (3), (3), (3)]], [[0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4], [0. 1, 2, 3., 4], [0, 1, 2, 3, 4]], [[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]. [4, 4, 4, 4, 4]], [[0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4], [0., 1., 2., 3., 4.]]]]) tensor([[[[0., 0., 0., 0., 0.], [1., 1., 1., 1., 1.], [2., 2., 2., 2., 2.], [3., 3., 3., 3., 3.], [4, 4, 4, 4, 4]], [[0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3], [0., 0. 1, 2, 3.], [0., 0. 1., 2., 3.]], [[... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4]], [[1, 2, 3, 4, 4], [1, 2, 3, 4, 4], [1, 2, 3, 4, 4], [1, 2, 3., 4, 4], [1, 2, 3., 4), 4.]], [[... 0, 0, 0, 0), and 0.], [... 0, 0, 0, 0), and 0.], [1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3), (3), (3), (3), (3)]]. [[0, 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4.]], [[1, 1, 1, 1, 1], [2, 2, 2, 2, 2], [3, 3), (3), (3), (3)], [4, 4, 4, 4, 4], [4, 4, 4, 4, 4]], [[0. 1, 2, 3, 4], [0, 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3, 4], [0. 1, 2, 3., 4.]]]]) "" "

Impact of offset direction

The experiment was run on a subset of ImageNet.

In V1, ablation experiments were conducted on different shift directions; in these models, the channels are grouped according to the number of directions. From the results:

  • Offsets do provide performance gains.
  • A and B: There is not much difference between the four directions and the eight directions.
  • E and F: Horizontal offset works better.
  • C and E/F: The offset of two axes is better than that of a single axis.

Effect of input size and patch size

The experiment was run on a subset of ImageNet.

With the patch size fixed in V1, performance varies with the input size (W×H). An overly large patch size also performs poorly, since it discards more detail, although it does improve inference speed.

The effectiveness of pyramid structures

In V2, two different structures are compared: one with smaller patches and a pyramid structure, the other with larger patches and no pyramid. The former achieves better performance, thanks to the finer detail preserved by the small patch size and the better computational efficiency brought by the pyramid structure.

The effect of split-attention

V2 compares split-attention against a simple average of the branch features, and split-attention performs better. However, the parameter counts of the two settings differ; a fairer comparison would use at least a few parameterized layers to fuse the features of the three branches.

The validity of the three-branch structure

The paper states, “In this section, we evaluate the influence of removing one of them,” but it does not explain how the remaining structure is adjusted after a specific branch is removed.

The experimental results

The experimental results can be seen directly in the tables of the V2 paper:

Links

  • Paper:
    • V1:arxiv.org/pdf/2106.07…
    • V2:arxiv.org/pdf/2108.01…
  • Reference code:
    • CycleFC code can be used for reference: github.com/ShoufaChen/…