After BERT burst onto the scene in 2018, BERT-derivative models such as ALBERT, SpanBERT, DistilBERT, SesameBERT, SemBERT, SciBERT, BioBERT, MobileBERT, TinyBERT, and CamemBERT sprang up like mushrooms. What do all of these BERT variants have in common? Quite a lot!

The answer is self-attention. What is self-attention? What is the math behind it? That is what we are going to talk about today. The main purpose of this article is to walk through the overall mathematical calculation involved in the self-attention model.

I. What is the self-attention model? Why do we need a self-attention model?

If you have already seen the attention model and you suspect that self-attention is a close relative of it, you are right: the two share many of the same concepts and much of the same math.

Human language is a sequence of vectors of varying length. When we do machine translation, we usually encode context with a recurrent neural network (RNN), a model that passes word information along in a linear fashion (that is, each word is processed as an input one at a time). This linear approach of the RNN brings two problems:

1. Due to the natural sequential structure of the RNN, training proceeds step by step and cannot be parallelized, so training speed is limited.

2. **RNNs are weak at processing long text.** When a word is processed, the state carrying the current word's information is passed on to the next word, and a word's information decays as the distance grows. But in a long text you need context to determine a word's meaning. Take "I arrived at the bank after crossing the river": once you see "river", you should realize that "bank" most likely refers to the river bank. In an RNN, all the words between "bank" and "river" have to be processed one by one, and RNNs are often much less effective when the related words are far apart.

To solve these two problems, we need to use the self-attention model.

II. What is the calculation process of self-attention?

In self-attention, every word is first embedded to obtain a word vector x_i. For each input x_i, the first thing we do is map it linearly into three different spaces, which gives three vectors: the query q_i, the key k_i, and the value v_i (stacked across the sequence, these form the matrices Q, K, and V). The linear mappings use three parameter matrices, W^Q, W^K, and W^V, which are learned during training. For now, just keep this concept in mind; later we will look at the code to see how these parameter matrices are initialized.

Image: the main parameters of self-attention (source: jalammar.github.io/illustrated…)

So let’s see how the self-attention model works.

Self-attention

Suppose the input sequence is X = [x_1, x_2, …, x_n] and the output sequence is H = [h_1, h_2, …, h_n]. The self-attention calculation can then be divided into three steps:

1. Calculate the query vectors Q, the key vectors K, and the value vectors V:

Q = X·W^Q, K = X·W^K, V = X·W^V

2. Calculate the attention weights, using the scaled dot product as the attention scoring function:

α_nj = softmax(q_n · k_j / √d_k)   (softmax taken over j)

In matrix form, this can be shortened to:

A = softmax(Q·K^T / √d_k)

where d_k is the dimension of the query vector q (and of the key vector k).

3. Calculate the output vector sequence:

h_n = Σ_j α_nj · v_j

where n and j index positions in the output and input vector sequences respectively, and α_nj represents the attention weight from the j-th input to the n-th output. (A short code sketch of these three steps follows below.)
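To make the three steps concrete, here is a minimal numpy sketch of single-layer self-attention under the definitions above; the function name and the use of numpy are my own choices for illustration, not code from any of the libraries discussed later.

```python
import numpy as np

def self_attention(X, W_Q, W_K, W_V):
    """Minimal sketch of the three steps above for a sequence X of word vectors."""
    # Step 1: project the inputs into query, key and value vectors.
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    # Step 2: scaled dot-product scores, normalized row-wise with softmax.
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)
    weights = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
    # Step 3: every output h_n is a weighted sum of the value vectors.
    return weights @ V
```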

The description above may still be a bit abstract, so let's use a concrete example to see how it works.

In our example, we initialize the parameter matrices W^Q, W^K, and W^V to the following values:

Parameter initialization:

Parameter matrix W^K

```
[[0, 0, 1],
 [1, 1, 0],
 [0, 1, 0],
 [1, 1, 0]]
```

Parameter matrix W^Q

```
[[1, 0, 1],
 [1, 0, 0],
 [0, 0, 1],
 [0, 1, 1]]
```

Parameter matrix W^V

```
[[0, 2, 0],
 [0, 3, 0],
 [1, 0, 3],
 [1, 1, 0]]
```

Note: the parameter matrices are initialized with a random distribution such as Gaussian, Xavier, or Kaiming, and this initialization is done before training begins. Tensor2Tensor is an open-source framework from Google that implements attention, self-attention, BERT, and other models; however, it is now outdated, and Google developed Trax as its replacement. Let's see how Trax handles Q, K, and V in its self-attention code.

Some of the code has been removed for ease of understanding, leaving only the main skeleton.

```python
# trax/layers/attention.py
def AttentionQKV(d_feature, n_heads=1, dropout=0.0, mode='train',
                 cache_KV_in_predict=False, q_sparsity=None,
                 result_sparsity=None):
  # Construct the q, k, v processing layers: each is a dense layer that
  # projects its input into a d_feature-dimensional space.
  k_processor = core.Dense(d_feature)
  v_processor = core.Dense(d_feature)
  if q_sparsity is None:
    q_processor = core.Dense(d_feature)
  if result_sparsity is None:
    result_processor = core.Dense(d_feature)

  return cb.Serial(
      cb.Parallel(
          q_processor,
          k_processor,
          v_processor,
      ),
      PureAttention(  # pylint: disable=no-value-for-parameter
          n_heads=n_heads, dropout=dropout, mode=mode),
      result_processor
  )
```

Here, core.Dense constructs a fully connected layer, which calls the init_weights_and_state function to initialize its weights.
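As a rough usage sketch (assuming a recent Trax version and its standard layer API; the toy shapes here are made up for illustration), the weight initialization is triggered the first time the layer is given an input signature:

```python
import numpy as np
from trax import layers as tl
from trax import shapes

dense = tl.Dense(3)                       # fully connected layer with 3 output units
x = np.ones((2, 4), dtype=np.float32)     # a toy batch: 2 inputs of dimension 4
dense.init(shapes.signature(x))           # calls init_weights_and_state under the hood
y = dense(x)                              # y has shape (2, 3)
```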

Initialize the hidden layer weights

```python
# trax/layers/core.py
def init_weights_and_state(self, input_signature):
  shape_w = (input_signature.shape[-1], self._n_units)
  shape_b = (self._n_units,)
  rng_w, rng_b = fastmath.random.split(self.rng, 2)
  # the kernel initializer builds the weight matrix w
  w = self._kernel_initializer(shape_w, rng_w)
  if self._use_bfloat16:
    w = w.astype(jnp.bfloat16)
  if self._use_bias:
    b = self._bias_initializer(shape_b, rng_b)
    if self._use_bfloat16:
      b = b.astype(jnp.bfloat16)
    self.weights = (w, b)
  else:
    self.weights = w
```
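The note above mentioned Gaussian, Xavier, and Kaiming initialization. As a hedged illustration of what a kernel initializer does (this is not Trax's actual initializer code, and glorot_uniform here is a hypothetical helper), a Xavier/Glorot-style scheme for a shape like shape_w might look like this:

```python
import numpy as np

def glorot_uniform(shape, seed=0):
    # Xavier/Glorot uniform: sample from U(-limit, limit), limit = sqrt(6 / (fan_in + fan_out))
    fan_in, fan_out = shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return np.random.default_rng(seed).uniform(-limit, limit, size=shape)

# e.g. a 4 x 3 weight matrix, the same shape as the parameter matrices above
W_init = glorot_uniform((4, 3))
```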

Input X:

```
x1 = [1, 0, 1, 0]
x2 = [0, 2, 0, 2]
x3 = [1, 1, 1, 1]
```

Step 1: Calculate Q, K, and V

```
# Calculate Q
[1, 0, 1, 0]   [1, 0, 1]   [1, 0, 2]
[0, 2, 0, 2] x [1, 0, 0] = [2, 2, 2]
[1, 1, 1, 1]   [0, 0, 1]   [2, 1, 3]
               [0, 1, 1]

# Calculate K
[1, 0, 1, 0]   [0, 0, 1]   [0, 1, 1]
[0, 2, 0, 2] x [1, 1, 0] = [4, 4, 0]
[1, 1, 1, 1]   [0, 1, 0]   [2, 3, 1]
               [1, 1, 0]

# Calculate V
[1, 0, 1, 0]   [0, 2, 0]   [1, 2, 3]
[0, 2, 0, 2] x [0, 3, 0] = [2, 8, 0]
[1, 1, 1, 1]   [1, 0, 3]   [2, 6, 3]
               [1, 1, 0]
```

Query, key, and value calculation
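The same step can be checked with a few lines of numpy (a sketch using the input X and the parameter matrices defined above):

```python
import numpy as np

X = np.array([[1, 0, 1, 0],
              [0, 2, 0, 2],
              [1, 1, 1, 1]], dtype=float)
W_Q = np.array([[1, 0, 1], [1, 0, 0], [0, 0, 1], [0, 1, 1]], dtype=float)
W_K = np.array([[0, 0, 1], [1, 1, 0], [0, 1, 0], [1, 1, 0]], dtype=float)
W_V = np.array([[0, 2, 0], [0, 3, 0], [1, 0, 3], [1, 1, 0]], dtype=float)

Q = X @ W_Q   # [[1, 0, 2], [2, 2, 2], [2, 1, 3]]
K = X @ W_K   # [[0, 1, 1], [4, 4, 0], [2, 3, 1]]
V = X @ W_V   # [[1, 2, 3], [2, 8, 0], [2, 6, 3]]
```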

Step 2: Calculate attention weights

We calculate the attention weights by means of the scaled dot product; the formula is:

A = softmax(Q·K^T / √d_k)

First, the raw attention scores are obtained by multiplying Q with the transpose of K:

```
[1, 0, 2]    [0, 4, 2]   [2,  4,  4]
[2, 2, 2] x  [1, 4, 3] = [4, 16, 12]
[2, 1, 3]    [1, 0, 1]   [4, 12, 10]
```

where d_k is the dimension of the query vector q (and of the key vector k); here d_k = 3. For convenience we keep just one decimal place, so √3 ≈ 1.7.

Dividing the scores by √d_k, we obtain:

```
[1.2, 2.4, 2.4]
[2.4, 9.4, 7.1]
[2.4, 7.1, 5.9]
```

Finally, we apply the softmax function row by row to obtain the attention weight matrix:

```
# attention weight matrix
[[0.1, 0.4, 0.4],
 [0.0, 0.9, 0.0],
 [0.0, 0.7, 0.2]]
```

Attention weight calculation
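Again, this can be checked with numpy (a sketch; because the example above rounds aggressively to one decimal place, the exact values differ slightly from the hand-computed matrix):

```python
import numpy as np

Q = np.array([[1, 0, 2], [2, 2, 2], [2, 1, 3]], dtype=float)
K = np.array([[0, 1, 1], [4, 4, 0], [2, 3, 1]], dtype=float)

scores = Q @ K.T                  # raw dot-product scores
scaled = scores / np.sqrt(3)      # divide by sqrt(d_k) with d_k = 3
weights = np.exp(scaled) / np.exp(scaled).sum(axis=-1, keepdims=True)
print(weights.round(2))
# approximately:
# [[0.14 0.43 0.43]
#  [0.   0.91 0.09]
#  [0.01 0.75 0.24]]
```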

For a given query, the different keys receive different attention weights. For example, for the query q1 (computed from input x1), the attention weights over the keys k1, k2, and k3 are 0.1, 0.4, and 0.4 respectively.

Step 3: Calculate the output vector sequence

The formula for the output vector sequence is:

h_n = Σ_j α_nj · v_j

where n and j index positions in the output and input vector sequences respectively, and α_nj represents the attention weight from the j-th input to the n-th output.

```
h1 = [1, 2, 3] * 0.1 + [2, 8, 0] * 0.4 + [2, 6, 3] * 0.4 = [1.7, 5.8, 1.5]
h2 = [1, 2, 3] * 0.0 + [2, 8, 0] * 0.9 + [2, 6, 3] * 0.0 = [1.8, 7.2, 0.0]
h3 = [1, 2, 3] * 0.0 + [2, 8, 0] * 0.7 + [2, 6, 3] * 0.2 = [1.8, 6.8, 0.6]
```

Self-attention
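As a final check, the outputs can be reproduced in numpy from the rounded attention weights and the value vectors (a sketch):

```python
import numpy as np

weights = np.array([[0.1, 0.4, 0.4],
                    [0.0, 0.9, 0.0],
                    [0.0, 0.7, 0.2]])
V = np.array([[1, 2, 3],
              [2, 8, 0],
              [2, 6, 3]], dtype=float)

H = weights @ V          # each row h_n is a weighted sum of the value vectors
print(H.round(1))
# approximately:
# [[1.7 5.8 1.5]
#  [1.8 7.2 0. ]
#  [1.8 6.8 0.6]]
```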

III. Multi-head attention mechanism

The self-attention model can be regarded as establishing interactions among the different vectors of the input X within a single linear projection space. To extract more interaction information, we can use multi-head self-attention, which captures different interactions in multiple different projection spaces.

The multi-head attention mechanism is an extension of self-attention. For an input X, we set the number of heads to N; the projected vectors are split into N independent parts, self-attention is computed on each part, and the N head outputs are then combined, as the sketch below shows.
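A minimal numpy sketch of this idea, assuming the projection dimension is divisible by the number of heads (real implementations usually also apply an output projection matrix, which is omitted here):

```python
import numpy as np

def multi_head_self_attention(X, W_Q, W_K, W_V, n_heads):
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d_head = Q.shape[-1] // n_heads          # dimension of each head's subspace
    outputs = []
    for h in range(n_heads):
        # take this head's slice of the query, key and value projections
        q = Q[:, h * d_head:(h + 1) * d_head]
        k = K[:, h * d_head:(h + 1) * d_head]
        v = V[:, h * d_head:(h + 1) * d_head]
        scores = q @ k.T / np.sqrt(d_head)
        w = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
        outputs.append(w @ v)
    # concatenate the per-head results back into one vector per position
    return np.concatenate(outputs, axis=-1)
```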

The Transformer model, which we will cover in the next post, uses exactly this multi-head attention mechanism.

The weights computed by the self-attention model depend only on the queries and keys, and ignore the position information of the inputs. Therefore, when self-attention is used on its own, positional encoding information generally needs to be added to compensate; one common choice is sketched below.
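For example, one common scheme is the sinusoidal positional encoding from the Transformer paper (see the arxiv reference below); a minimal numpy sketch, written here for illustration only:

```python
import numpy as np

def sinusoidal_positional_encoding(seq_len, d_model):
    pos = np.arange(seq_len)[:, None]        # positions 0 .. seq_len - 1
    i = np.arange(d_model)[None, :]          # embedding dimension indices
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])    # sine on even dimensions
    pe[:, 1::2] = np.cos(angles[:, 1::2])    # cosine on odd dimensions
    return pe

# The encoding is simply added to the word embeddings:
# X = X + sinusoidal_positional_encoding(X.shape[0], X.shape[1])
```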

Reference:

zhuanlan.zhihu.com/p/47282410

jalammar.github.io/illustrated…

arxiv.org/abs/1706.03…

towardsdatascience.com/illustrated…

google/trax

Neural Networks and Deep Learning

Machine Reading Comprehension: Algorithms and Practices