Input visualization methods for Vision Transformers are widely used in Baidu's content understanding and content risk control businesses, where they help researchers build better models. Baidu's content strategy team used a Transformer input visualization method to analyze the false detections of a pornographic image classification model used for risk control, and then designed targeted data processing strategies. With pornographic image recall kept unchanged, false detections fell by 76.2% relative to the base model, which greatly improved the model's precision.
I. The significance of visualization
Since AlexNet appeared in 2012, convolutional neural networks (CNNs) have become one of the most effective methods for common vision tasks such as image classification, object detection and semantic segmentation, and they have attracted wide attention and adoption. Researchers later brought the Transformer, which is popular in NLP and other sequence tasks, into computer vision. There it matches or even surpasses CNNs on multiple tasks. However, both CNNs and Transformers are trained and run end to end, so to some extent they are black-box models. Researchers have little visibility into the parameter distributions and feature changes inside the model, which limits their ability to analyze and control its behavior.
Neural network visualization turns the parameters or activations inside a model into feature maps or saliency maps that convey information directly. This helps researchers judge the model's current fitting state intuitively and accurately, or locate and analyze problems when an inference result is wrong.
For example, researchers once trained an image classifier that distinguished polar bears from brown bears living in forests with 100% accuracy. When someone photoshopped a polar bear into a forest scene and fed the picture to the classifier, it was misclassified as a brown bear. It turned out that the classifier had not learned to recognize polar bears and brown bears at all. Instead, it had learned to decide based on their surroundings. Visualizing which parts of the input a model relies on can expose problems like this. Concretely, given an image classification model and a test image, we want to find the input evidence that leads the model to assign the test image to a category C, that is, which pixels or regions make the model decide that the test image belongs to C.
This article uses image classification as an example to discuss input visualization.
II. Differences in input visualization between CNNs and Transformers for vision tasks
The data in a neural network consist mainly of neuron activations during forward propagation and the gradients of neurons or parameters during backpropagation. Almost all visualization methods use these two kinds of data to build feature maps or saliency maps.
2.1 Input visualization in CNNs
In a typical CNN for image classification, the input is a 1- or 3-channel image of a given width and height. It passes through a series of convolution layers, pooling layers, a global average pooling (GAP) layer and a fully connected layer, and is finally mapped to a probability for each category. A feature map inside the network can be regarded as a “shorter, thinner and thicker” version of the input: it is smaller in width and height and has more channels. Because of how convolution and pooling are computed, positions along the width and height of a feature map correspond to positions along the width and height of the input image. For example, the upper-left region of the feature map corresponds to the upper-left region of the input image. This is why CNN methods commonly build the input visualization from the activations or gradients of feature maps.
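The spatial correspondence can be illustrated roughly with the sketch below, which maps a feature-map cell back to the input tile it covers. It assumes one uniform total stride, computed as input size divided by feature-map size (for example 224 / 7 = 32 for many ImageNet CNNs). It also ignores the fact that each cell's true receptive field is larger than its stride-sized tile. The function names are hypothetical.

```python
def total_stride(input_size, feature_size):
    """Uniform downsampling factor between the input and a feature map (assumed exact)."""
    if input_size % feature_size != 0:
        raise ValueError("input size must be a multiple of the feature-map size")
    return input_size // feature_size


def cell_to_input_region(row, col, stride):
    """Approximate input tile (top, left, bottom, right) covered by feature-map cell
    (row, col); bottom and right are exclusive. Receptive-field overlap is ignored."""
    return (row * stride, col * stride, (row + 1) * stride, (col + 1) * stride)


# Example: a 224x224 input reduced to a 7x7 feature map.
stride = total_stride(224, 7)                      # 32
top_left = cell_to_input_region(0, 0, stride)      # (0, 0, 32, 32)
bottom_right = cell_to_input_region(6, 6, stride)  # (192, 192, 224, 224)
```

The upper-left cell maps to the upper-left tile and the lower-right cell maps to the lower-right tile. This ordering is the property that CNN visualization methods rely on.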
One example is the Class Activation Map (CAM) method. CAM upsamples the feature maps that feed the GAP layer to the input resolution and weights each channel's feature map by the corresponding weight learned in the fully connected layer. It then sums the weighted maps and normalizes the result into a saliency map.
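Below is a minimal pure-Python sketch of CAM as described above. It assumes the feature maps entering the GAP layer and the weights of the final fully connected layer have already been extracted as nested lists. Nearest-neighbour upsampling replaces the bilinear interpolation that implementations typically use, and all names are illustrative.

```python
def upsample_nearest(grid, factor):
    """Nearest-neighbour upsampling of a 2-D list by an integer factor."""
    return [[grid[y // factor][x // factor] for x in range(len(grid[0]) * factor)]
            for y in range(len(grid) * factor)]


def min_max_normalize(grid):
    """Rescale a 2-D list to [0, 1]; a constant map becomes all zeros."""
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        return [[0.0 for _ in row] for row in grid]
    return [[(v - lo) / (hi - lo) for v in row] for row in grid]


def class_activation_map(feature_maps, fc_weights, class_idx, upsample_factor):
    """CAM for one class.

    feature_maps: K channels, each an h x w list (the tensor entering GAP).
    fc_weights:   C x K weights of the fully connected layer after GAP.
    The weighted sum is taken at feature-map resolution and then upsampled;
    with nearest-neighbour upsampling this equals upsampling each channel first.
    """
    w_c = fc_weights[class_idx]
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[sum(w_c[k] * feature_maps[k][y][x] for k in range(len(feature_maps)))
            for x in range(w)]
           for y in range(h)]
    return min_max_normalize(upsample_nearest(cam, upsample_factor))


# Toy example: 2 channels of 2x2 activations, 2 classes.
feature_maps = [[[1.0, 0.0], [0.0, 0.0]],
                [[0.0, 0.0], [0.0, 1.0]]]
fc_weights = [[1.0, 0.0],   # class 0 relies on channel 0
              [0.0, 2.0]]   # class 1 relies on channel 1
saliency_0 = class_activation_map(feature_maps, fc_weights, 0, upsample_factor=2)
```

For class 0 the 4x4 map lights up its upper-left quadrant, because only channel 0 carries weight for that class.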
Gradient-weighted Class Activation Mapping (Grad-CAM) works similarly. It replaces the weights learned by the fully connected layer with the mean gradient of the target-class output with respect to each channel's feature map. All feature maps are weighted by these values, and the result is normalized into a saliency map. As a result, the method no longer requires the model to have a fully connected layer.
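Relative to CAM, only the source of the channel weights changes. The sketch below assumes the gradients of the target-class score with respect to each feature map have already been obtained by backpropagation, so no autograd is shown. The final ReLU comes from the original Grad-CAM paper rather than from the text above. Upsampling is omitted because it works exactly as in CAM.

```python
def channel_weights(gradients):
    """Grad-CAM weight per channel: spatial mean of d(score_c) / d(feature map)."""
    return [sum(sum(row) for row in g) / (len(g) * len(g[0])) for g in gradients]


def min_max_normalize(grid):
    """Rescale a 2-D list to [0, 1]; a constant map becomes all zeros."""
    flat = [v for row in grid for v in row]
    lo, hi = min(flat), max(flat)
    if hi == lo:
        return [[0.0 for _ in row] for row in grid]
    return [[(v - lo) / (hi - lo) for v in row] for row in grid]


def grad_cam(feature_maps, gradients):
    """Grad-CAM at feature-map resolution.

    feature_maps, gradients: K x h x w lists; gradients[k][y][x] is the gradient of
    the target-class score with respect to feature_maps[k][y][x].
    The ReLU keeps only positive evidence, as in the original Grad-CAM paper.
    """
    alphas = channel_weights(gradients)
    h, w = len(feature_maps[0]), len(feature_maps[0][0])
    cam = [[max(0.0, sum(a * fm[y][x] for a, fm in zip(alphas, feature_maps)))
            for x in range(w)]
           for y in range(h)]
    return min_max_normalize(cam)


feature_maps = [[[1.0, 0.0], [0.0, 0.0]],
                [[0.0, 0.0], [0.0, 1.0]]]
gradients = [[[1.0, 1.0], [1.0, 1.0]],      # channel 0 supports the class
             [[-1.0, -1.0], [-1.0, -1.0]]]  # channel 1 argues against it
saliency = grad_cam(feature_maps, gradients)  # [[1.0, 0.0], [0.0, 0.0]]
```

The fully connected weights never appear, so the same code applies to any model whose feature maps and gradients can be read out.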
2.2 Input visualization in Transformers
Transformers process sequence data. Vision Transformer (ViT) cuts the input image into N small patches, and each patch is embedded into a one-dimensional token of fixed length. These N tokens, together with a class token, form N+1 tokens, and each token is projected into Q, K and V. Self-attention then computes the attention weights, and each output is the weighted sum of all the V vectors. ViT therefore has no CNN-style “shorter, thinner, thicker” feature map to serve as the basis for a saliency map of the input.
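A toy single-head version in pure Python makes this data flow concrete. It simplifies the real model considerably. The learned patch embedding and the Q/K/V projection matrices are replaced by the raw flattened pixels, so Q = K = V. The class token is a hypothetical all-zero vector, and there is only one head and one layer.

```python
import math


def patchify(image, patch):
    """Cut an H x W single-channel image (2-D list) into non-overlapping
    patch x patch tiles, row by row, and flatten each tile into a 1-D token."""
    h, w = len(image), len(image[0])
    return [[image[y][x] for y in range(top, top + patch) for x in range(left, left + patch)]
            for top in range(0, h, patch)
            for left in range(0, w, patch)]


def softmax(xs):
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def self_attention(q, k, v):
    """Single-head scaled dot-product attention.
    Returns (weights, outputs): weights[i][j] is how much token i attends to token j,
    and outputs[i] is the weighted sum of the V vectors."""
    d = len(q[0])
    weights = [softmax([sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d) for kj in k])
               for qi in q]
    outputs = [[sum(wij * vj[c] for wij, vj in zip(wi, v)) for c in range(len(v[0]))]
               for wi in weights]
    return weights, outputs


# 4x4 toy image -> N = 4 patches of 2x2 pixels -> 4 tokens of length 4.
image = [[(r * 4 + c) / 15 for c in range(4)] for r in range(4)]
tokens = patchify(image, 2)
class_token = [0.0] * 4            # hypothetical; learned in a real ViT
sequence = [class_token] + tokens  # N + 1 = 5 tokens
weights, outputs = self_attention(sequence, sequence, sequence)
```

Each row of `weights` sums to 1, and row 0 holds the class token's attention over all tokens. No spatial feature map remains, only these token-to-token weights.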
So what data should be used in ViT to build the input saliency map?
Consider what an input saliency map means: it identifies the pixel regions in the input image that strongly influence the final classification decision. ViT's output is essentially a weighted sum of a series of V values, and each weight expresses importance. These weights come from self-attention between tokens, and each token corresponds one-to-one to an image patch. So when certain patches receive more weight, those patches drive the final decision, which is exactly what input visualization aims to capture. The key to ViT input visualization is therefore to find which patches' tokens receive more weight from the class token in self-attention. The pixel regions of those patches have the most influence on the final classification result.
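The simplest version of this idea takes one attention map and reads the class token's row. The sketch below does only that, for a single head of a single layer, using hypothetical weights. It is class-agnostic and ignores gradients, which are the parts the method in Section 3.2 adds.

```python
def class_token_patch_scores(attn, grid_h, grid_w):
    """Lay out the class token's attention to each patch on the patch grid.

    attn: (N+1) x (N+1) attention weights from one head of one layer, where
    index 0 is the class token and indices 1..N are patches in row-major order.
    """
    row = attn[0][1:]  # class-token row, dropping its attention to itself
    if len(row) != grid_h * grid_w:
        raise ValueError("grid shape does not match the number of patch tokens")
    return [row[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)]


def most_attended_patch(grid):
    """(row, col) of the patch with the largest score."""
    return max(((r, c) for r in range(len(grid)) for c in range(len(grid[0]))),
               key=lambda rc: grid[rc[0]][rc[1]])


# Hypothetical attention weights for 1 class token + a 2x2 grid of patches.
attn = [[0.2, 0.1, 0.4, 0.2, 0.1]] + [[0.2] * 5 for _ in range(4)]
grid = class_token_patch_scores(attn, 2, 2)  # [[0.1, 0.4], [0.2, 0.1]]
```

Here `most_attended_patch(grid)` points to the upper-right patch, the region this head considers most relevant to the class token.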
III. Transformer input visualization methods
3.1 Common Transformer input visualization methods
Research on input visualization for Transformers is still limited, and existing methods are built mostly around the self-attention module. Rollout [1] builds the input saliency map from the self-attention activations of the forward pass. However, its result is class-agnostic and cannot reflect how different categories influence the model's decision. LRP [2] models the interactions between neurons through Deep Taylor Decomposition and, following the idea of backpropagation, constructs the influence of the input on the output logit. Partial LRP [3] takes into account that the heads in multi-head attention differ in importance. It removes heads that contribute little to the result by ranking them by weight norm, and it builds the input saliency map on top of LRP. However, the LRP-based methods rely on a complex computation, and these methods produce poor visualizations for small targets.
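The sketch below shows attention rollout in pure Python, based on our reading of the Rollout paper [1]. For each layer it averages the heads, mixes in the identity with equal 0.5 weights to account for the residual connection, re-normalizes the rows, and multiplies the per-layer matrices together. Averaging the heads is one common choice among several. The target class never appears in the computation, which is why the result is class-agnostic.

```python
def mat_mul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def attention_rollout(attn_layers):
    """Attention rollout in the style of Abnar & Zuidema [1].

    attn_layers: list over layers (input side first); each entry is a list over
    heads of (N+1) x (N+1) attention matrices. Returns the rolled-out matrix;
    row 0 (the class token) gives a relevance score for every token.
    No target class is used anywhere, so the result is class-agnostic.
    """
    n = len(attn_layers[0][0])
    rollout = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]
    for heads in attn_layers:
        # Average the heads, then mix in the identity for the residual path.
        avg = [[sum(h[i][j] for h in heads) / len(heads) for j in range(n)] for i in range(n)]
        mixed = [[0.5 * avg[i][j] + (0.5 if i == j else 0.0) for j in range(n)]
                 for i in range(n)]
        mixed = [[v / sum(row) for v in row] for row in mixed]  # re-normalize rows
        rollout = mat_mul(mixed, rollout)                       # compose with earlier layers
    return rollout


layer = [[[0.5, 0.5], [0.5, 0.5]]]              # one head, 2 tokens (toy values)
one_layer = attention_rollout([layer])          # [[0.75, 0.25], [0.25, 0.75]]
two_layers = attention_rollout([layer, layer])  # [[0.625, 0.375], [0.375, 0.625]]
```

Adding more layers spreads attention mass further across tokens.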
3.2 A robust input visualization method for Transformers
Combining the analysis in Section 2.2 with the Grad-CAM approach, we propose a robust input visualization method for Transformers that uses both the activations and the gradients of self-attention to build a saliency map of the input. Taking ViT as an example, assume the model has B Transformer encoder blocks and N+1 tokens including the class token, and that the self-attention in each block b contains H heads. Let $A_b^h$ be the self-attention values between the class token and the other N tokens in head h of block b. Their gradients $\nabla A_b^h$ are obtained by backpropagating from a one-hot output signal: 1 for the target category being visualized and 0 for all other categories. The saliency map corresponding to the class token in block b is:
$$S_b = \frac{1}{H}\sum_{h=1}^{H} A_b^h \odot \nabla A_b^h$$
where $\odot$ is the element-wise product. In other words, the saliency map of block b is the mean, over its H heads, of the element-wise product of the self-attention activations and their gradients. The input saliency map over all tokens/patches is:
$$S = \prod_{b=1}^{B} \left( S_b - \min(S_b) + \varepsilon \right)$$
That is, the input saliency value of each token/patch is the product, patch by patch, of the saliency maps of all B blocks. Each block's values are first shifted so that their minimum is zero, which prevents an even number of negative values from multiplying into a positive value. The small constant ε is added so that a zero in one block does not wipe out the saliency values from the other blocks. Once the saliency value of each patch is obtained, the values are arranged into a 2-D grid following the order in which the patches were cut. The grid is then upsampled to the input size and normalized to give the input visualization.
The figure below shows visualization results for a ViT model trained on ImageNet. Odd-numbered columns are input images, and even-numbered columns are the input saliency maps for the model's top-1 predicted class. The model's decision evidence falls precisely on the corresponding pixel regions, which indicates that the model has indeed learned accurate classification criteria from the input.
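The sketch below combines the two equations and the post-processing in pure Python. It assumes the class-token attention rows and their gradients have already been captured for every block and head, for example with the deep learning framework's forward and backward hooks (not shown). The value of ε and the use of nearest-neighbour upsampling are our own illustrative choices, not details given in the text.

```python
def block_saliency(attn_heads, grad_heads):
    """S_b: mean over the H heads of attention * gradient (element-wise).

    attn_heads[h][i]: attention between the class token and patch token i in head h.
    grad_heads[h][i]: gradient of the one-hot target-class output w.r.t. attn_heads[h][i].
    """
    num_heads = len(attn_heads)
    return [sum(a[i] * g[i] for a, g in zip(attn_heads, grad_heads)) / num_heads
            for i in range(len(attn_heads[0]))]


def combine_blocks(block_maps, eps=1e-6):
    """S = prod_b (S_b - min(S_b) + eps), patch by patch. eps=1e-6 is a hypothetical default."""
    combined = [1.0] * len(block_maps[0])
    for s_b in block_maps:
        lo = min(s_b)  # shift so the minimum is 0: two negatives cannot multiply to a positive
        combined = [c * (v - lo + eps) for c, v in zip(combined, s_b)]
    return combined


def to_saliency_image(scores, grid_h, grid_w, patch):
    """Arrange patch scores in patch order, upsample (nearest) by the patch size,
    then min-max normalize to [0, 1]."""
    grid = [scores[r * grid_w:(r + 1) * grid_w] for r in range(grid_h)]
    image = [[grid[y // patch][x // patch] for x in range(grid_w * patch)]
             for y in range(grid_h * patch)]
    lo = min(min(row) for row in image)
    hi = max(max(row) for row in image)
    if hi == lo:
        return [[0.0] * len(row) for row in image]
    return [[(v - lo) / (hi - lo) for v in row] for row in image]


# Toy captured values: B = 2 blocks, H = 2 heads, N = 4 patches on a 2x2 grid.
attn = [[[0.1, 0.2, 0.3, 0.4], [0.4, 0.3, 0.2, 0.1]],
        [[0.25, 0.25, 0.25, 0.25], [0.1, 0.1, 0.7, 0.1]]]
grads = [[[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 2.0, 0.0]],
         [[-1.0, 0.0, 1.0, 0.0], [0.0, 0.0, 1.0, 0.0]]]
maps = [block_saliency(a, g) for a, g in zip(attn, grads)]  # maps[0] ≈ [0.05, 0.1, 0.35, 0.2]
scores = combine_blocks(maps)
saliency = to_saliency_image(scores, grid_h=2, grid_w=2, patch=16)  # 32x32 map
```

With these toy values the third patch (bottom-left) dominates in both blocks, so it receives the highest saliency. For the effect of the shift, consider `combine_blocks([[0.0, 1.0], [-1.0, 1.0]], eps=0.0)`. It returns `[0.0, 2.0]`: the block minima are moved to zero before multiplying, so no sign flips occur.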
The figure below shows visualizations for different output categories, produced by the same ViT model on the same input. Odd-numbered columns are input images, and even-numbered columns are the input saliency maps for the classes ranked by the model's predictions. The visualizations for different categories are also accurate, which is strong evidence that the classification model has converged well.
The figure below shows input visualizations produced by the method in this article for ViT, Swin Transformer and VOLO models trained on ImageNet, given the same inputs. Transformers with different architectures all yield accurate input visualizations, demonstrating the robustness and generality of the method.
3.3 Business application case
In a content risk control task, we trained a VOLO-based image classification model to identify pornographic images in Baidu image search. Baidu image search holds a huge volume of images that covers almost every image category on the web. As a result, the model produced many false detections whenever recall met the requirement, and further optimization had stalled for a long time. After thorough testing, we used the method in this article to visualize the false detections and compile statistics on them. This identified 19 objects and scenes prone to false detection, such as dark shorts, crossed arms, weddings, combat and crowds. We then added targeted negative samples and fine-tuned the model. With recall held constant, false detections fell by 76.2% relative to the base model, which greatly improved the model's precision.
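The article does not spell out its evaluation protocol. One common way to compare false detections "at the same recall" is sketched below under that assumption. For each model, a threshold is chosen on the positive set so that the target recall is reached, and then the negatives scoring at or above that threshold are counted. The scores are invented toy numbers, not the data behind the 76.2% figure, and the function names are hypothetical.

```python
import math


def threshold_at_recall(pos_scores, target_recall):
    """Highest score threshold that still keeps at least target_recall of the positives."""
    ranked = sorted(pos_scores, reverse=True)
    k = math.ceil(target_recall * len(ranked))  # positives that must be detected
    return ranked[max(k, 1) - 1]


def false_detections(neg_scores, threshold):
    """Negatives (non-pornographic images) scored at or above the threshold."""
    return sum(1 for s in neg_scores if s >= threshold)


def relative_reduction(base_fp, new_fp):
    return (base_fp - new_fp) / base_fp


# Invented toy scores, NOT the article's data.
base_pos, base_neg = [0.9, 0.8, 0.7, 0.6], [0.85, 0.75, 0.65, 0.5, 0.3]
tuned_pos, tuned_neg = [0.95, 0.9, 0.8, 0.4], [0.85, 0.3, 0.2, 0.1, 0.05]
recall = 0.75
base_fp = false_detections(base_neg, threshold_at_recall(base_pos, recall))     # 2
tuned_fp = false_detections(tuned_neg, threshold_at_recall(tuned_pos, recall))  # 1
reduction = relative_reduction(base_fp, tuned_fp)                               # 0.5
```

Each model is thresholded separately, so the comparison holds recall fixed instead of reusing one threshold for both models.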
IV. Conclusion
This article introduced a robust input visualization method for vision Transformers. Experiments show that it produces good visualizations for models with different Transformer architectures. The technique is currently used widely in Baidu's content understanding and risk control businesses, where it helps researchers build better models. The method mainly considers how the self-attention values between the Q and K of the tokens affect the classification decision. In the future, the V values of the tokens could also be taken into account to build a more accurate and robust Transformer visualization method.
References:
[1] Samira Abnar and Willem Zuidema. Quantifying attention flow in transformers. arXiv preprint arXiv:2005.00928, 2020.
[2] Sebastian Bach, Alexander Binder, Grégoire Montavon, Frederick Klauschen, Klaus-Robert Müller, and Wojciech Samek. On pixel-wise explanations for non-linear classifier decisions by layer-wise relevance propagation. PLoS ONE, 10(7):e0130140, 2015.
[3] Elena Voita, David Talbot, Fedor Moiseev, Rico Sennrich, and Ivan Titov. Analyzing multi-head self-attention: Specialized heads do the heavy lifting, the rest can be pruned. In Proceedings of the 57th Annual Meeting of the Association for Computational Linguistics, pages 5797–5808, 2019.
Baidu Geek Talk
Baidu's official technology account.