This article was first published on: Walker AI

In 2018, the arrival of BERT brought a major breakthrough to natural language processing. BERT and its derived models achieved state-of-the-art (SOTA) results on several downstream text-processing tasks. But such improvements come at a price, one of which is a huge increase in the required computation.

The BERT-Base model consists of 12 Transformer layers with about 110 million parameters involved in computation, while the better-performing BERT-Large model consists of 24 Transformer layers with over 300 million parameters. This huge number of parameters puts high demands on GPU performance and GPU memory. In enterprise applications in particular, more powerful GPUs are needed to complete model training. The same is true for inference: in real online deployments, text analysis has to respond at the millisecond level, and renting servers with high computing power is expensive. So is there a way to have it both ways, reducing model complexity without sacrificing classification accuracy?

The answer is yes.

1. BERT as a service (reducing training computation)

Normally, BERT is applied to text classification by fine-tuning: BERT is a pre-trained model whose parameters Google learned on a large-scale corpus. For text classification we only need to take the pre-trained parameters as initial parameters and fine-tune the model on our own training set to achieve good results. However, this approach still updates hundreds of millions of parameters and therefore consumes considerable computing resources during the training phase. This motivated the idea of using BERT as a service for generating word vectors: all BERT parameters are frozen and no longer take part in training, so there is no backward update. In this approach BERT serves purely as a word-vector generator; computation happens only when the service is called, not during training, which greatly reduces training cost.
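As a minimal sketch of this idea (using the Hugging Face `transformers` library rather than the original bert-as-service toolkit; the model name and [CLS] pooling choice are assumptions for illustration), a frozen BERT can serve sentence vectors like this:

```python
import torch
from transformers import BertTokenizer, BertModel

# Load a pre-trained BERT; its parameters stay frozen, so no gradients are needed.
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
bert = BertModel.from_pretrained("bert-base-chinese")
bert.eval()

def encode(sentences):
    """Return one fixed-size vector per sentence from the [CLS] hidden state."""
    batch = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():                      # BERT is only called, never trained
        outputs = bert(**batch)
    return outputs.last_hidden_state[:, 0, :]  # [CLS] embedding, shape (batch, 768)

vectors = encode(["这家餐厅的服务很好", "物流太慢了"])
print(vectors.shape)  # torch.Size([2, 768])
```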

Skipping the training process creates a problem: BERT's pre-trained parameters come from a large general-purpose corpus, whereas the text we want to classify often belongs to a specific domain, such as medical text. Because the model is not fine-tuned, it cannot learn the special expressions of that domain, and the results of BERT-as-a-service can deviate significantly. One common remedy is to append simple models such as a fully connected layer, a CNN, or an LSTM after BERT; these modules are trained to learn representations specific to the current data set. But such models are shallow and perform worse than a fine-tuned BERT.
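One such head, sketched below as a small fully connected classifier trained on the frozen vectors from the previous example (layer sizes, dropout, and learning rate are arbitrary assumptions):

```python
import torch
import torch.nn as nn

# Small trainable head on top of the frozen 768-dim BERT vectors; only these
# parameters receive gradients, BERT itself is never updated.
classifier = nn.Sequential(
    nn.Linear(768, 256),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(256, 2),   # e.g. binary text classification
)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=1e-3)

def train_step(sentences, labels):
    logits = classifier(encode(sentences))   # encode() is the frozen-BERT function above
    loss = loss_fn(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```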

To sum up, using BERT as a service saves resources during training at the cost of some accuracy, but it does not reduce computation at inference time: calling the service still requires relatively high computing resources.

2. Distilling BERT (reducing inference computation)

Those who have studied chemistry know that distillation extracts the essence from a large amount of raw material. BERT distillation follows the same idea.

The BERT-Base model mentioned above consists of 12 Transformer layers and about 110 million parameters, but not all of them are necessary for a given task, especially for simple, basic tasks such as text classification; perhaps 30 million parameters are enough to achieve good results. Guided by this idea, many BERT distillation methods have been proposed.

The idea of distillation was proposed by Hinton at NIPS 2014. Its core idea is to train a complex teacher network on a large amount of data and then use the teacher network to train a student network. This is an important way in which distillation differs from methods such as pruning. Moreover, what the student network learns in distillation is the teacher network's generalization ability, not its ability to fit the data; it can be understood as learning how the teacher solves problems rather than memorizing the standard answer to each question.

Take text sentiment classification as an example. For the student network to learn the teacher network's knowledge, the teacher should not tell the student the hard sentiment label of the current sentence (0 or 1) but the classification probability (e.g. 0.73), so that the student can learn from the teacher's output distribution. In practice, a well-trained teacher model usually produces probabilities very close to 0 or 1, in which case the probabilities carry little more information than the hard labels. To better extract the teacher model's knowledge, Hinton added a temperature (smoothing) parameter T to the softmax, as follows:
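The temperature-scaled softmax has the standard form (with $z_i$ the logits and $T$ the temperature; a larger $T$ gives a softer, more informative distribution):

$$
q_i = \frac{\exp(z_i / T)}{\sum_j \exp(z_j / T)}
$$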

The purpose of distillation is to obtain a student network with fewer parameters than the teacher network while keeping the student's performance as close as possible to the teacher's. To achieve this, a special loss function is needed: it must measure both the difference between the probabilities output by the teacher and student networks and the difference between the student's predictions and the ground-truth labels. Different researchers use different loss functions, but their general form is as follows:
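A common general form, consistent with the description below (the weight $\alpha$ balances the two terms; the exact weighting varies across papers), is:

$$
\mathcal{L} = \alpha \cdot \mathrm{CE}(y, p) + (1 - \alpha) \cdot \mathrm{CE}(q, p)
$$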

Here CE denotes the cross-entropy loss (which can be replaced by MSE, KL divergence, or other measures of difference), y is the ground-truth label, q is the teacher network's softened output from the formula above, and p is the student network's output.

2.1 Distilling BERT into a BiLSTM

The proposed method uses BERT-Large as the teacher network and a bidirectional LSTM as the student network. BERT-Large is first fine-tuned on the task; after the teacher is trained, the student network is trained on the original data set plus an augmented data set. The loss follows the general idea above but differs in the details: it combines the cross-entropy on hard labels with the MSE between the logits of the teacher and student networks.
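A sketch of this loss in PyTorch (the weighting factor `alpha` and the tensor names are assumptions; the combination of hard-label cross-entropy and logit MSE is as described above):

```python
import torch
import torch.nn.functional as F

def bilstm_distill_loss(student_logits, teacher_logits, labels, alpha=0.5):
    """Hard-label cross-entropy plus MSE between teacher and student logits."""
    ce = F.cross_entropy(student_logits, labels)       # fit the ground-truth labels
    mse = F.mse_loss(student_logits, teacher_logits)   # imitate the teacher's logits
    return alpha * ce + (1 - alpha) * mse
```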

Because the original data set alone is enough to fine-tune the teacher but may be too small for the student network to learn effective features from, the authors augment the original data set. The specific methods are as follows (a sketch is given after the list):

  • Randomly replace a word in the original sentence with [MASK]

  • Replace a word with another word that has the same POS tag

  • Construct a new example by randomly sampling an n-gram from the original sentence
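A rough sketch of these three augmentation rules (word-level; the sampling probabilities, the `(word, POS)` input format, and the `pos_vocab` lookup are assumptions for illustration):

```python
import random

def augment(tokens, p_mask=0.1, p_pos=0.1, pos_vocab=None):
    """Apply masking and POS-guided replacement word by word."""
    out = []
    for word, pos in tokens:                         # tokens: list of (word, POS-tag) pairs
        r = random.random()
        if r < p_mask:
            out.append("[MASK]")                     # rule 1: random masking
        elif r < p_mask + p_pos and pos_vocab:
            out.append(random.choice(pos_vocab[pos]))  # rule 2: same-POS word replacement
        else:
            out.append(word)
    return out

def ngram_sample(tokens, n_min=1, n_max=5):
    """Rule 3: build a new example from a random n-gram of the sentence."""
    words = [w for w, _ in tokens]
    n = random.randint(n_min, min(n_max, len(words)))
    start = random.randint(0, len(words) - n)
    return words[start:start + n]
```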

The experimental results show that the distilled model is comparable to the original model on simple classification tasks but falls short on complex tasks. Compared with the original teacher model, the number of parameters is reduced by about 100 times and inference is about 15 times faster.

2.2 Distilling BERT into a Transformer

When BERT is distilled into an LSTM, the gains are limited, mainly for the following reasons:

  • An LSTM with so few parameters cannot accurately represent semantic features in complex tasks

  • Distilling only the fine-tuned model cannot fully transfer the teacher model's generalization ability

  • Distilling only the teacher model's last layer cannot extract all of the teacher model's knowledge

In view of these three points, many researchers have improved the distillation approach, and later student models are mostly Transformers. For example, BERT-PKD distills BERT's intermediate layers, DistilBERT starts distilling at BERT's pre-training stage, and TinyBERT goes further by also using the attention matrices of the teacher model's intermediate layers, achieving good results.

Taking TinyBERT as an example:

The distillation of a single layer works as follows: the hidden states and attention matrices of each Transformer layer in the teacher network are matched against the corresponding layer of the student network, giving an attention loss (Attn-loss) and a hidden-state loss (Hidn-loss). These are combined with the prediction-layer (label) loss and the embedding-layer loss to form the overall loss.
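A simplified sketch of the per-layer matching in TinyBERT-style distillation (the learnable projection `W_h` maps the student's smaller hidden size to the teacher's; the loss weights and the example hidden sizes are assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_s, hidden_t = 312, 768                      # example student / teacher hidden sizes
W_h = nn.Linear(hidden_s, hidden_t, bias=False)    # learnable projection for hidden states

def layer_distill_loss(attn_s, attn_t, hidn_s, hidn_t):
    """Attn-loss + Hidn-loss for one matched (student layer, teacher layer) pair."""
    attn_loss = F.mse_loss(attn_s, attn_t)          # match attention matrices
    hidn_loss = F.mse_loss(W_h(hidn_s), hidn_t)     # match (projected) hidden states
    return attn_loss + hidn_loss
```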

Distilling BERT into a Transformer in this way gives a significant improvement over distilling into an LSTM as in Section 2.1.

3. Adaptive early-exit mechanism (reducing inference computation)

When using BERT's multi-layer Transformer, we find that a prediction can be made from the output of every layer: the lower layers have low accuracy, while the upper layers perform better because they extract more semantic information. However, for simple short-text classification tasks with obvious features, inference does not need the last layer's output; an intermediate layer's prediction can already be good enough. This differs from the distillation in Section 2, which transfers the teacher's knowledge to a student model, whereas the early-exit mechanism adaptively chooses to end inference at some intermediate layer, which greatly accelerates inference. In addition, the confidence threshold for exiting can be adjusted dynamically according to the business scenario.

This mechanism was first proposed in FastBERT (ACL 2020), an extension of dynamic inference in computer vision, where different samples take different computation paths. The authors add a fully connected classifier after each Transformer layer; these classifiers form the branches, and the original BERT is the backbone. During training, the backbone is fine-tuned first. After that, the branches are trained by self-distillation: the fully connected classifier on top of the last backbone layer acts as the teacher, and the branch classifiers are the students. The loss is the KL divergence between the backbone's and the branches' output distributions.
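A sketch of this self-distillation step (the backbone classifier's output serves as the soft target for every branch; the tensor names are assumptions):

```python
import torch.nn.functional as F

def self_distill_loss(branch_logits_list, backbone_logits):
    """Sum of KL divergences between each branch classifier and the backbone output."""
    teacher_prob = F.softmax(backbone_logits, dim=-1).detach()   # backbone acts as teacher
    loss = 0.0
    for branch_logits in branch_logits_list:                     # one classifier per Transformer layer
        student_log_prob = F.log_softmax(branch_logits, dim=-1)
        loss = loss + F.kl_div(student_log_prob, teacher_prob, reduction="batchmean")
    return loss
```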

During inference, adaptive inference is used: samples are filtered layer by layer according to the branch classifiers' results, so easy samples get their result immediately while hard samples continue to the next layer. The authors define an uncertainty measure based on the entropy of the predicted distribution; the greater the entropy, the greater the uncertainty:
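In FastBERT this is the entropy of the branch classifier's predicted distribution, normalized so that it lies in [0, 1] (here N is the number of classes and $p_i$ the predicted probability of class i):

$$
\text{Uncertainty} = \frac{\sum_{i=1}^{N} p_i \log p_i}{\log \frac{1}{N}}
$$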

For the classification result at each layer, the authors apply a threshold on this uncertainty called Speed, which is roughly proportional to inference speed: a smaller threshold requires a sample to be more certain before it can exit, so fewer samples exit at lower layers and inference is slower; a larger threshold lets more samples exit early and speeds up inference.
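Putting the two together, the inference loop looks roughly like the sketch below (the layer/classifier interfaces and a batch size of 1 are assumptions; `speed` is the uncertainty threshold):

```python
import torch

def adaptive_inference(hidden, transformer_layers, branch_classifiers, speed=0.1):
    """Run the layers one by one and exit as soon as the branch prediction is certain enough."""
    for layer, classifier in zip(transformer_layers, branch_classifiers):
        hidden = layer(hidden)
        probs = torch.softmax(classifier(hidden[:, 0, :]), dim=-1)   # classify from [CLS]
        num_classes = probs.size(-1)
        # normalized entropy in [0, 1]; lower means more certain
        uncertainty = (probs * probs.log()).sum(-1) / torch.log(torch.tensor(1.0 / num_classes))
        if uncertainty.item() < speed:      # certain enough: stop here (batch size 1 assumed)
            return probs
    return probs                            # fall through to the last layer
```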

According to the authors, setting speed = 0.1 can make inference 1-10 times faster and halve the amount of computation. The method performs well on multiple classification data sets, but it applies only to classification tasks and needs to be adapted to the specific business.

4. Experiments with TextBrewer

TextBrewer is a PyTorch-based toolkit designed for knowledge distillation tasks in NLP.

The experiment uses LCQMC, a binary sentence-pair classification task released by the Intelligent Computing Research Center of Harbin Institute of Technology, Shenzhen Graduate School; the goal is to judge whether two sentences have the same meaning.

The teacher network is RoBERTa-wwm, and the comparison results are as follows:

| Model | LCQMC (Acc) | Layers | Hidden size | Feed-forward size | Params | Relative size |
| --- | --- | --- | --- | --- | --- | --- |
| RoBERTa-wwm | 89.4 | 12 | 768 | 3072 | 108M | 100% |
| BERT | 86.68 | 12 | 768 | 3072 | 108M | 100% |
| T3 | 89.0 (30) | 3 | 768 | 3072 | 44M | 41% |
| T3-small | 88.1 (30) | 3 | 384 | 1536 | 17M | 16% |
| T4-tiny | 88.4 (30) | 4 | 312 | 1200 | 14M | 13% |

Using RoBERTa-wwm as the teacher network yields better results than BERT-Base.
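For reference, a TextBrewer distillation run looks roughly like the sketch below, following the toolkit's documented `GeneralDistiller` workflow. The adaptor, models, data loader, optimizer, and configuration values are placeholders, not the exact setup behind the LCQMC numbers above:

```python
from textbrewer import GeneralDistiller, TrainingConfig, DistillationConfig

# teacher_model (e.g. a fine-tuned RoBERTa-wwm) and student_model (e.g. a 3-layer T3)
# are ordinary PyTorch modules prepared beforehand; dataloader and optimizer likewise.

def simple_adaptor(batch, model_outputs):
    # Tell TextBrewer which model outputs are the logits to distill.
    return {"logits": model_outputs[0]}

train_config = TrainingConfig(device="cuda")
distill_config = DistillationConfig(temperature=8, hard_label_weight=0, kd_loss_type="ce")

distiller = GeneralDistiller(
    train_config=train_config,
    distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor,
)

with distiller:
    distiller.train(optimizer, dataloader, num_epochs=30, callback=None)
```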

References:

[1] Distilling the Knowledge in a Neural Network: arxiv.org/abs/1503.02…

[2] Distilling Task-Specific Knowledge from BERT into Simple Neural Networks: arxiv.org/abs/1903.12…

[3] Patient Knowledge Distillation for BERT Model Compression: arxiv.org/abs/1908.09…

[4] DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter: arxiv.org/abs/1910.01…

[5] TinyBERT: Distilling BERT for Natural Language Understanding: arxiv.org/abs/1909.10…

[6] FastBERT: a Self-distilling BERT with Adaptive Inference Time: arxiv.org/abs/2004.02…

[7] TextBrewer: github.com/airaria/Tex…
