It is well known that Attention-based Transformer models parallelize well, but their space and time complexity is $\mathcal{O}(n^2)$, where $n$ is the sequence length. When $n$ is large, the computation of a Transformer model therefore becomes prohibitive. Recently there has been a lot of effort to reduce the computational complexity of Transformer models, either through compression techniques such as pruning, quantization, and distillation, or by modifying the Attention structure so that the complexity can be reduced to $\mathcal{O}(n\log n)$ or even $\mathcal{O}(n)$.
I was intrigued by the Linear Attention method in the paper “Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention” and read a few blog posts about it. They contain some good insights, and in this article I summarize my own understanding of Linear Attention.
Attention
The most popular Attention mechanism at the moment is Scaled-Dot Attention.
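For reference, this is the familiar formula (the usual scaling factor $1/\sqrt{d_k}$ is omitted here, since it does not affect the complexity discussion):

$$\text{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = \text{softmax}\left(\boldsymbol{Q}\boldsymbol{K}^{\top}\right)\boldsymbol{V}\tag{1}$$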
where $\boldsymbol{Q}\in\mathbb{R}^{n\times d_k}$, $\boldsymbol{K}\in\mathbb{R}^{m\times d_k}$, $\boldsymbol{V}\in\mathbb{R}^{m\times d_v}$. In this article we focus on the Self Attention scenario, so for ease of exposition, let $\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}\in\mathbb{R}^{n\times d}$.
Removing the Softmax
Perhaps surprisingly, the key constraint on Attention's performance is the Softmax! This follows from a simple observation: at the $\boldsymbol{Q}\boldsymbol{K}^{\top}$ step we obtain an $n\times n$ matrix, to which we then have to apply Softmax.
Softmax for a $1\times n$ row vector is $\mathcal{O}(n)$, but Softmax for each row of an $n\times n$ matrix is $\mathcal{O}(n^2)$.
If there were no Softmax, the Attention formula would become the product of three matrices, $\boldsymbol{Q}\boldsymbol{K}^{\top}\boldsymbol{V}$. Matrix multiplication is associative, so we can first compute $\boldsymbol{K}^{\top}\boldsymbol{V}$ to obtain a $d\times d$ matrix (the time complexity of this step is $\mathcal{O}(d^2 n)$), and then left-multiply it by $\boldsymbol{Q}$ (the time complexity of this step is also $\mathcal{O}(d^2 n)$). Since $d \ll n$, the overall time complexity is effectively just $\mathcal{O}(n)$.
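As a quick sanity check of this associativity argument, here is a small PyTorch sketch (the sizes are illustrative choices of mine): both orderings give the same result, but $\boldsymbol{Q}(\boldsymbol{K}^{\top}\boldsymbol{V})$ never forms the $n\times n$ matrix.

import torch

n, d = 1024, 64                                   # illustrative sequence length and head size
Q = torch.randn(n, d, dtype=torch.float64)
K = torch.randn(n, d, dtype=torch.float64)
V = torch.randn(n, d, dtype=torch.float64)

out_quadratic = (Q @ K.t()) @ V                   # materializes an n x n matrix: O(n^2 d) time, O(n^2) memory
out_linear = Q @ (K.t() @ V)                      # only a d x d intermediate: O(n d^2) time, O(d^2) memory

print(torch.allclose(out_quadratic, out_linear))  # True (up to floating-point error)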
For BERT base, $d=64$ rather than 768. Why? Because 768 is obtained by concatenating the multiple heads, and each head has $d=64$.
In other words, removing the Softmax reduces the complexity of Attention to the optimal linear level $\mathcal{O}(n)$! This is obviously our ultimate goal: Linear Attention.
General definition
The question is: without Softmax, is it still Attention? Can it still achieve the effect of standard Attention? To answer this, let us first rewrite the definition of Scaled-Dot Attention in an equivalent element-wise form (the vectors in this article are column vectors):

$$\text{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}\boldsymbol{v}_j}{\sum\limits_{j=1}^n e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}}\tag{2}$$

To see why this holds, recall that for $\boldsymbol{Q},\boldsymbol{K}\in\mathbb{R}^{n\times d}$ and $\boldsymbol{M}=\boldsymbol{Q}\boldsymbol{K}^{\top}$, the $i$-th row of $\boldsymbol{M}$ is obtained by multiplying the $i$-th row of $\boldsymbol{Q}$ with every column of $\boldsymbol{K}^{\top}$. Here $\text{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i$ denotes the $i$-th row of the final output matrix, $\boldsymbol{q}_i^{\top}$ is the $i$-th row of $\boldsymbol{Q}\in\mathbb{R}^{n\times d}$, $\boldsymbol{k}_j$ is the $j$-th column of $\boldsymbol{K}^{\top}\in\mathbb{R}^{d\times n}$, and $\boldsymbol{v}_j$ is the $j$-th column of $\boldsymbol{V}^{\top}\in\mathbb{R}^{d\times n}$. The Softmax simply normalizes the weight $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$ attached to each $\boldsymbol{v}_j$. So we can come up with a general definition of Attention:

$$\text{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V})_i = \frac{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)\,\boldsymbol{v}_j}{\sum\limits_{j=1}^n \text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)}\tag{3}$$

That is, the specific weight $e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j}$ is replaced by a general similarity function $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)$, for which we require $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)\geq 0$. In other words, if we want to define a new kind of Attention, we must keep the form of (3) and ensure $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)\geq 0$.
In CV, this general form of Attention is also known as Non-local Networks, from the paper “Non-Local Neural Networks”.
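To make form (3) concrete, here is a naive $\mathcal{O}(n^2)$ reference sketch of mine for similarities of the form $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=f(\boldsymbol{q}_i^{\top}\boldsymbol{k}_j)$; with $f=\exp$ it reproduces standard Softmax Attention, i.e. equation (2).

import torch

def general_attention(Q, K, V, f):
    # Q, K: [n, d], V: [n, d_v]; f maps dot-product scores to non-negative weights
    weights = f(Q @ K.t())                                   # [n, n], must be >= 0
    weights = weights / weights.sum(dim=-1, keepdim=True)    # row-wise normalization, as in (3)
    return weights @ V                                       # [n, d_v]

Q, K, V = torch.randn(8, 4), torch.randn(8, 4), torch.randn(8, 6)
out = general_attention(Q, K, V, torch.exp)
print(torch.allclose(out, torch.softmax(Q @ K.t(), dim=-1) @ V))  # True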
A few examples
If we simply remove the Softmax, we get $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=\boldsymbol{q}_i^{\top}\boldsymbol{k}_j$, but a plain inner product cannot guarantee non-negativity, so this alone is not a valid choice. Here are some options that are.
It is worth mentioning that the first two kinds of Linear Attention introduced below come from the CV field, and the third was conceived by Su Jianlin (besides those introduced below, there is also other work on improving Attention in CV, such as EMANet).
Kernel form
A natural idea: if every element of $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ is non-negative, then their inner product is non-negative too. To achieve this, we can apply activation functions $\phi,\varphi$ to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ respectively:

$$\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j) = \phi(\boldsymbol{q}_i)^{\top}\varphi(\boldsymbol{k}_j)\tag{4}$$

where $\phi(\cdot),\varphi(\cdot)$ are activation functions with a non-negative range. The paper “Transformers are RNNs” mentioned at the beginning of this article chooses $\phi(x)=\varphi(x)=\text{elu}(x)+1$, where $\text{elu}(x)=x$ for $x>0$ and $\text{elu}(x)=\alpha(e^x-1)$ for $x\leq 0$. Common values of $\alpha$ lie in $[0.1, 0.3]$, so $\text{elu}(x)+1>1-\alpha>0$, which indeed satisfies the non-negativity requirement.
If you want to tell a story, formula (4) can be related to the “kernel method”: in particular, when $\phi=\varphi$, $\phi$ plays the role of a kernel function, and $\langle\phi(\boldsymbol{q}_i),\phi(\boldsymbol{k}_j)\rangle$ is the inner product defined by that kernel. For more on this perspective, see the paper “Transformer Dissection: An Unified Understanding for Transformer’s Attention via the Lens of Kernel”.
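Below is a minimal sketch of mine (non-causal, single head) of how choice (4) with $\phi(x)=\varphi(x)=\text{elu}(x)+1$ leads to linear complexity: the feature maps are applied first, and the $\boldsymbol{K},\boldsymbol{V}$ summaries are computed once and reused for every query.

import torch
import torch.nn.functional as F

def elu_feature_map(x):
    # phi(x) = elu(x) + 1 > 0, so phi(q)^T phi(k) >= 0 as required by form (3)
    return F.elu(x) + 1

def kernel_linear_attention(Q, K, V):
    # Q, K: [n, d], V: [n, d_v]; non-causal (bidirectional) attention
    Q, K = elu_feature_map(Q), elu_feature_map(K)
    KV = K.t() @ V                         # [d, d_v], computed once: O(n * d * d_v)
    Z = K.sum(dim=0)                       # [d], equals sum_j phi(k_j)
    numerator = Q @ KV                     # [n, d_v]
    denominator = (Q @ Z).unsqueeze(-1)    # [n, 1], equals sum_j phi(q_i)^T phi(k_j)
    return numerator / denominator

The causal (autoregressive) case additionally requires cumulative sums over the key positions $j\leq i$, which is the RNN view emphasized in the paper.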
Using Softmax
An earlier paper, “Efficient Attention: Attention with Linear Complexities”, offers a more interesting option. In $\boldsymbol{Q}\boldsymbol{K}^{\top}$, with $\boldsymbol{Q},\boldsymbol{K}\in\mathbb{R}^{n\times d}$, if $\boldsymbol{Q}$ is normalized along the dimension $d$ and $\boldsymbol{K}$ is normalized along the dimension $n$, then $\boldsymbol{Q}\boldsymbol{K}^{\top}$ is automatically normalized. So the choice it gives is:

$$\text{Attention}(\boldsymbol{Q},\boldsymbol{K},\boldsymbol{V}) = \text{softmax}_2(\boldsymbol{Q})\,\text{softmax}_1(\boldsymbol{K})^{\top}\boldsymbol{V}$$

where $\text{softmax}_1$ and $\text{softmax}_2$ denote the Softmax operation along the first dimension ($n$) and the second dimension ($d$), respectively. In other words, we apply Softmax to $\boldsymbol{Q}$ and $\boldsymbol{K}$ separately, rather than computing $\boldsymbol{Q}\boldsymbol{K}^{\top}$ first and then applying Softmax.
In fact, it can be shown that this form is also a special case of equation (4), corresponding to $\phi(\boldsymbol{q}_i)=\text{softmax}(\boldsymbol{q}_i)$ and $\varphi(\boldsymbol{k}_j)=e^{\boldsymbol{k}_j}$.
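A small sketch of this variant (my own naming, single head): Softmax is applied to $\boldsymbol{K}$ along the sequence dimension and to $\boldsymbol{Q}$ along the feature dimension, and associativity again gives linear complexity.

import torch

def efficient_attention(Q, K, V):
    # Q, K: [n, d], V: [n, d_v]
    Q = torch.softmax(Q, dim=1)   # softmax_2: each q_i becomes a distribution over the d features
    K = torch.softmax(K, dim=0)   # softmax_1: each feature of K becomes a distribution over the n positions
    KV = K.t() @ V                # [d, d_v], O(n * d * d_v)
    return Q @ KV                 # [n, d_v]; each implicit attention row already sums to 1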
Su Jianlin's idea
Here Su Jianlin offers another idea. Its starting point is no longer Eq. (4); instead, it comes from a Taylor expansion of the original definition (2). By Taylor expansion we have:

$$e^{\boldsymbol{q}_i^{\top}\boldsymbol{k}_j} \approx 1 + \boldsymbol{q}_i^{\top}\boldsymbol{k}_j$$

If $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1$, the right-hand side is guaranteed to be non-negative, so we could let $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=1+\boldsymbol{q}_i^{\top}\boldsymbol{k}_j$. As you may have guessed, to ensure $\boldsymbol{q}_i^{\top}\boldsymbol{k}_j\geq -1$ we only need to apply $L_2$ normalization to $\boldsymbol{q}_i$ and $\boldsymbol{k}_j$ respectively. Therefore, the final scheme proposed by Su Jianlin is:

$$\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=1+\left(\frac{\boldsymbol{q}_i}{\Vert\boldsymbol{q}_i\Vert}\right)^{\top}\left(\frac{\boldsymbol{k}_j}{\Vert\boldsymbol{k}_j\Vert}\right)$$

where, if $\boldsymbol{x}=[x_1,x_2,\ldots,x_n]$, then $\Vert\boldsymbol{x}\Vert=\sqrt{x_1^2+x_2^2+\cdots+x_n^2}$.
This differs from formula (4), but theoretically it stays closer to the original Scaled-Dot Attention.
Implementation
This section mainly implements the method proposed by Su Jianlin. However, due to my limited ability, the final implementation code still has some problems, mainly the following:
- From my test results, the modified Attention did not actually run any faster
- After masking, the attention weights of each row no longer sum to one
The code is based on the PyTorch implementation of BERT from the earlier article; more specifically, it only modifies the ScaledDotProductAttention class, so only that part of the code is shown below.
import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        # Q, K, V: [batch_size, n_heads, seq_len, d_k]; attn_mask is True at positions to mask
        Q = F.normalize(Q, dim=3)   # q_i / ||q_i||
        K = F.normalize(K, dim=3)   # k_j / ||k_j||
        # sim(q_i, k_j) = 1 + q_i^T k_j; scores: [batch_size, n_heads, seq_len, seq_len]
        M = torch.ones(Q.shape[0], Q.shape[1], Q.shape[2], K.shape[2], device=Q.device) + torch.matmul(Q, K.transpose(-1, -2))
        M_sum = torch.sum(M, dim=3)
        M = M / M_sum.unsqueeze(3).repeat(1, 1, 1, M.shape[3])  # normalize each row
        attn = M.masked_fill(attn_mask, 0)  # zero out masked positions (after normalization, so rows may no longer sum to 1)
        context = torch.matmul(attn, V)
        return context
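For what it is worth, here is a sketch of how the same choice $\text{sim}(\boldsymbol{q}_i,\boldsymbol{k}_j)=1+\hat{\boldsymbol{q}}_i^{\top}\hat{\boldsymbol{k}}_j$ could be computed without ever materializing the $n\times n$ matrix, which is where the speed-up is supposed to come from, and with the padding mask applied before normalization so that each implicit row still sums to one. This is my own sketch rather than Su Jianlin's code; it assumes a key_mask of shape [batch_size, 1, seq_len, 1] with 1 for real tokens and 0 for padding (a different convention from attn_mask above), and it handles only the non-causal case.

import torch
import torch.nn.functional as F

def linear_attention(Q, K, V, key_mask=None):
    # Q, K: [batch_size, n_heads, seq_len, d_k], V: [batch_size, n_heads, seq_len, d_v]
    # key_mask: [batch_size, 1, seq_len, 1], 1.0 for real tokens, 0.0 for padding (assumed convention)
    Q = F.normalize(Q, dim=-1)                  # q_i / ||q_i||
    K = F.normalize(K, dim=-1)                  # k_j / ||k_j||
    if key_mask is None:
        key_mask = Q.new_ones(Q.shape[0], 1, Q.shape[2], 1)
    K = K * key_mask                            # drop padded keys before any summation
    V = V * key_mask
    # numerator_i = sum_j v_j + q_i^T sum_j (k_j v_j^T), sums over unmasked keys only
    KV = torch.matmul(K.transpose(-1, -2), V)   # [batch_size, n_heads, d_k, d_v]
    numerator = V.sum(dim=-2, keepdim=True) + torch.matmul(Q, KV)
    # denominator_i = (number of valid keys) + q_i^T sum_j k_j
    K_sum = K.sum(dim=-2, keepdim=True)         # [batch_size, n_heads, 1, d_k]
    denominator = key_mask.sum(dim=-2, keepdim=True) + torch.matmul(Q, K_sum.transpose(-1, -2))
    return numerator / denominator              # rows of the implicit attention matrix sum to 1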
If you have a better way to do it, please let me know
Reference
- The Exploration of Linear Attention: Must Attention have a Softmax?