I had heard of Data Augmentation and Virtual Adversarial Training, but I did not expect to see them combined as Virtual Data Augmentation. This article focuses on Virtual Data Augmentation, an EMNLP 2021 paper: Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models. The paper proposes a robust and general data augmentation method, and the source code is at github.com/RUCAIBox/VD…
At the very beginning, the paper states the main problem of current data augmentation: how do we make the generated data diverse while keeping it in the same semantic space as the original? Simply put, it is easy to increase diversity; the core recipe is a single word: "mess things up". For example, many augmentation methods randomly shuffle the positions of tokens in a sentence, randomly delete some tokens, or randomly insert some tokens. This does increase sample diversity, but the semantics may also change drastically, to the point where the augmented sample no longer means the same thing as the original. Conversely, it is also easy to keep the semantics unchanged, that is, to guarantee that the augmented sample stays in the same semantic space as the original; the core recipe there is "don't mess things up too much". Synonym substitution, for example, changes the semantics almost not at all, but the diversity is insufficient, because the result is essentially still the same sentence.
These two requirements are inherently in tension, and all we can do is strike some kind of balance. Concretely, the method proposed by the authors has two important parts: Embedding Augmentation and Regularized Training.
Embedding Augmentation
Suppose we have the sentence "Time is enough for test". For each token position, we can replace the token with [MASK] and then use the MLM head to predict the probability of every token in the Vocabulary at that position, e.g.
[MASK] is enough for test
The tokens predicted at the [MASK] position, with their probabilities, are
Time p=0.5    Day p=0.3    Hours p=0.15 ...
Or, as another example:
Time is enough for [MASK]
The tokens predicted at the [MASK] position, with their probabilities, are
Test p=0.5    evaluation p=0.3    experiment p=0.1 ...
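To make this MLM step concrete, here is a minimal sketch using the Hugging Face `transformers` fill-mask pipeline; the checkpoint `bert-base-uncased` and the `top_k=5` setting are my own illustrative choices, not necessarily the paper's setup.

```python
# Minimal sketch: predict candidate tokens and their probabilities for a
# masked position with an MLM. The model choice is illustrative only.
from transformers import pipeline

unmasker = pipeline("fill-mask", model="bert-base-uncased")

for pred in unmasker("[MASK] is enough for test.", top_k=5):
    print(f"{pred['token_str']}  p={pred['score']:.3f}")
```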
Seeing this, you may already have an augmentation scheme in mind: use the MLM task to predict the tokens at each position in the sentence, then randomly pick a replacement token according to the predicted probabilities. For example, the sentence above might become "Hours is enough for evaluation". That is not a bad way to augment data, but it is not what the authors did.
For simplicity, we only discuss augmenting a single token $\tilde{w}$ in a given sentence $S$ (in practice, every token in $S$ is treated the same way). Using the MLM task, we can predict a probability for every word in the Vocabulary at the position of $\tilde{w}$:

$$p(\hat{w}_i\mid S),\quad i=1,\dots,V \tag{1}$$

where $V$ is the number of tokens in the Vocabulary.
To increase the diversity of the augmented data, or rather to inject some noise that improves robustness, we sample a random vector from a Gaussian distribution:

$$\epsilon \sim \mathcal{N}(0, \sigma^2)$$
Mixing this vector with the probability distribution of formula (1) gives a new probability distribution:

$$p'(\hat{w}_i\mid S) = \text{Softmax}\bigl(p(\hat{w}_i\mid S) + \epsilon_i\bigr)$$
Then, for the token $\tilde{w}$, we fuse the embedding vectors of all tokens $\hat{w}_i$ according to the probabilities $p'(\hat{w}_i\mid S)$, i.e. the augmented embedding is $\mathbf{p}_{\tilde{w}}\,\mathbf{M}_E$, where $\mathbf{p}_{\tilde{w}} = \{p'(\hat{w}_i\mid S)\}_{i=1}^{V}$ and $\mathbf{M}_E\in\mathbb{R}^{V\times d}$ is the word-embedding matrix of the MLM model.
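The whole embedding-augmentation step fits in a few lines. Below is a minimal PyTorch sketch of the idea as just described; the function and variable names (`virtual_embedding`, `sigma`, …) are my own, not taken from the authors' repository.

```python
# Minimal PyTorch sketch of the embedding-augmentation step described above:
# MLM probabilities + Gaussian noise -> new distribution -> weighted sum over
# the embedding matrix M_E. All names here are my own assumptions, not the
# authors' code.
import torch

def virtual_embedding(logits: torch.Tensor,
                      embedding_matrix: torch.Tensor,
                      sigma: float = 1.0) -> torch.Tensor:
    """
    logits:           [V]    MLM logits at the masked position of w~
    embedding_matrix: [V, d] word-embedding matrix M_E of the MLM model
    returns:          [d]    the fused ("virtual") embedding
    """
    p = torch.softmax(logits, dim=-1)          # p(w_i | S), formula (1)
    eps = torch.randn_like(p) * sigma          # noise drawn from N(0, sigma^2)
    p_new = torch.softmax(p + eps, dim=-1)     # mixed distribution p'
    return p_new @ embedding_matrix            # p_w~ · M_E
```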
Let's go through a simple example. For convenience we again consider replacing a single token, with a Vocabulary of only 4 tokens and a word-vector dimension of 2. We start with the sentence "She is a good student", mask "good", and let the MLM model predict a probability distribution over the Vocabulary, whose entries are, from left to right, the probabilities of good, perfect, excellent, and smart. We then draw a random vector from the Gaussian distribution $\mathcal{N}(0, \sigma^2)$.
I won't specify the variance $\sigma^2$ here; its exact value doesn't matter for the illustration.
Mixing $p(\hat{w}_i\mid S)$ with $\epsilon$ through a Softmax gives a new probability distribution $p'(\hat{w}_i\mid S)$.
Assume the $4\times 2$ embedding matrix $\mathbf{M}_E$ is also given.
Weighting the rows of $\mathbf{M}_E$ by this new distribution then yields the final embedding for "good".
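Since the concrete numbers of the example are not reproduced here, the following NumPy sketch simply re-runs the same steps with made-up values; the probabilities, noise vector, and 4×2 embedding matrix below are purely illustrative.

```python
# Re-running the toy example with made-up numbers: the probability
# distribution, the noise vector, and the 4x2 embedding matrix below are
# purely illustrative, not the values from the original example.
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

p   = np.array([0.5, 0.3, 0.15, 0.05])   # P(good / perfect / excellent / smart | S)
eps = np.array([0.2, -0.1, 0.05, 0.0])   # a sample from N(0, sigma^2)
M_E = np.array([[0.1, 0.9],              # good
                [0.3, 0.7],              # perfect
                [0.5, 0.5],              # excellent
                [0.8, 0.2]])             # smart

p_new = softmax(p + eps)                 # new distribution after mixing
virtual_emb = p_new @ M_E                # final "virtual" embedding for "good"
print(p_new, virtual_emb)
```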
This is also where the name Virtual Data Augmentation comes from: we never replace a token with another real token, but with a synthesized embedding. If you searched the rows of $\mathbf{M}_E$ for this embedding, you would not find a matching index, which means the generated embedding does not correspond to any actual token.
Regularized Training
The name sounds grand, but essentially it just introduces an extra loss. Specifically, the optimization objective now sums the task loss over the original samples and all of their augmented versions.
Here $f$ denotes the pre-trained model with parameters $\theta$, $N$ is the number of samples, and $K$ is the number of augmented sentences generated from each original sentence. Concretely, for a classification task the task loss is the cross-entropy between the model's predictions and the label, computed on the original sample and on each of its $K$ augmented versions.
$\text{CE}(\cdot,\cdot)$ denotes the cross-entropy loss, and $\mathbf{E}_i$ is the word-embedding sequence of the $i$-th sample, with shape [seq_len, emb_dim].
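As a sketch, the classification part of the objective could look like the following; the model is assumed to take an embedding sequence directly and return class logits, and the names and exact normalization are my assumptions rather than the paper's code.

```python
# Hedged sketch of the task loss: cross-entropy on the original sample plus
# cross-entropy on each of its K augmented (virtual) versions. The model is
# assumed to accept an embedding sequence and return class logits.
import torch
import torch.nn.functional as F

def task_loss(model, emb, aug_embs, label):
    """
    emb:      [seq_len, emb_dim] original embedding sequence E_i
    aug_embs: list of K tensors shaped like emb (augmented versions)
    label:    scalar long tensor holding the class index y_i
    """
    loss = F.cross_entropy(model(emb.unsqueeze(0)), label.unsqueeze(0))
    for aug in aug_embs:
        loss = loss + F.cross_entropy(model(aug.unsqueeze(0)), label.unsqueeze(0))
    return loss
```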
To prevent a large semantic gap between the augmented samples and the original sample, in other words, because we want the model's output distributions on the augmented samples and on the original sample to stay close, the paper introduces a KL-divergence term as the second loss.
Here $K$ again refers to the $K$ samples augmented from the original sample, and $D_{sKL}$ is the symmetric KL divergence, specifically

$$D_{sKL}(p, q) = D_{KL}(p\,\|\,q) + D_{KL}(q\,\|\,p)$$
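A minimal sketch of the symmetric KL regularizer, and of how it would be combined with the task loss from the previous sketch; again this is my own rendering of the idea, and the relative weighting of the two terms is an assumption.

```python
# Sketch of the symmetric KL regularizer between the model's predictions on
# the original sample and on each augmented sample. My own rendering of the
# idea; the weighting against the task loss is an assumption.
import torch
import torch.nn.functional as F

def symmetric_kl(logits_p: torch.Tensor, logits_q: torch.Tensor) -> torch.Tensor:
    """D_sKL between two predicted distributions, given raw logits."""
    log_p, log_q = F.log_softmax(logits_p, dim=-1), F.log_softmax(logits_q, dim=-1)
    p, q = log_p.exp(), log_q.exp()
    # KL(p || q) + KL(q || p), averaged over the batch dimension
    return (p * (log_p - log_q)).sum(-1).mean() + (q * (log_q - log_p)).sum(-1).mean()

def regularizer(model, emb, aug_embs):
    """Sum of symmetric KL between the original and each augmented prediction."""
    logits = model(emb.unsqueeze(0))
    return sum(symmetric_kl(logits, model(aug.unsqueeze(0))) for aug in aug_embs)

# Full objective, reusing task_loss from the previous sketch (the relative
# weight `lam` is my assumption):
# loss = task_loss(model, emb, aug_embs, label) + lam * regularizer(model, emb, aug_embs)
```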
In fact, the whole method can be viewed as multi-task training: we want the parameters to reach a state where the model both performs well on the downstream task, whether it sees the original sample or an augmented one, and keeps the gap between its predictions on the original and augmented samples small. Finally, a figure from the paper summarizes this part (in the figure, 3 samples are augmented from each original sample).
Results
Looking at the plain accuracy comparison alone, the improvement does not seem very strong; it feels like one could match or even beat Virtual Data Augmentation by adding a few tricks. The key is the second column, "Att Acc", which reports the results when the model is attacked. The improvement there is particularly large, indicating that VDA indeed provides strong robustness against perturbations.
Personal summary
In fact, everything has already been explained above, so there is not much left to summarize. I would just like to share one personal thought: when the method runs the MLM task, it uses the entire Vocabulary as the candidate set, which is unfriendly to both computation speed and GPU memory. I think it could be changed to keep only the top-$k$ tokens with the highest predicted probability, with $k$ set fairly large (say 200 or 300) so that tokens further down the list, which are not semantically similar, can still be drawn, while avoiding any operation over the whole Vocabulary. At the very least we would no longer be handling a probability distribution over tens of thousands of tokens.
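A rough sketch of that top-$k$ variant (my own suggestion, not something from the paper or its repository):

```python
# Rough sketch of the top-k variant suggested above: keep only the k most
# probable tokens, add noise and renormalize over that subset, then fuse only
# those k embedding rows. Purely my own idea, not part of the paper.
import torch

def virtual_embedding_topk(logits: torch.Tensor,
                           embedding_matrix: torch.Tensor,
                           k: int = 300,
                           sigma: float = 1.0) -> torch.Tensor:
    topk_logits, topk_idx = torch.topk(logits, k)               # [k], [k]
    p = torch.softmax(topk_logits, dim=-1)                      # renormalize over top-k
    p_new = torch.softmax(p + torch.randn_like(p) * sigma, dim=-1)
    return p_new @ embedding_matrix[topk_idx]                   # fuse only k rows
```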