
This article mainly explains the EMNLP 2021 paper Raise a Child in Large Language Model: Towards Effective and Generalizable Fine-tuning. The title is somewhat abstract, but in the authors' own words, the idea of the paper can be summarized in two words: Child Tuning

Although the paper focuses on NLP tasks and NLP-related models, after reading it I feel this is a general approach that could be used in the CV domain as well. Specifically, pre-trained models today have a very large number of parameters, while in downstream tasks we can only fine-tune them with limited training data, which feels like a mantis trying to stop a chariot (hopelessly outmatched). Therefore, the authors propose a new fine-tuning method, Child Tuning. The idea can be summarized in one sentence: during backpropagation, we do not update all parameters, only a subset of them, and the network structure corresponding to these updated parameters is called the Child Network (sub-network).

As shown in the figure above, the top row is the normal backpropagation process, where


$$\Delta \mathbf{w}_0 = -\eta \frac{\partial \mathcal{L}}{\partial \mathbf{w}_0}\tag{1}$$

The subscript 0 does not refer to a particular parameter but to the 0th iteration, and $\eta$ is the learning rate. In the bottom row, part of $\Delta \mathbf{w}_0$ is masked out, so the corresponding gradients become 0:


$$\Delta \mathbf{w}_0 = -\eta \frac{\partial \mathcal{L}}{\partial \mathbf{w}_0} \odot M\tag{2}$$

Here, every element of the matrix $M$ is either 0 or 1, and $\odot$ denotes element-wise multiplication of matrix entries at corresponding positions. We can summarize the process of Child Tuning in two steps:

  1. Identify the Child Network in the pre-trained model and generate a 0-1 mask $M$ corresponding to its weights
  2. After computing the gradients in backpropagation, update only the parameters that belong to the Child Network (see the sketch below)
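
As a concrete illustration, here is a minimal sketch of such a masked update step in PyTorch. The model, the mask dictionary, and the learning rate are hypothetical names used only to show where the mask enters the update; this is not the authors' implementation, just plain SGD with the gradients outside the Child Network zeroed out.

import torch

def masked_sgd_step(model, mask, lr=1e-3):
    # mask maps a parameter name to a 0-1 tensor of the same shape as that parameter
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            grad = param.grad
            if name in mask:                # step 1: look up the precomputed 0-1 mask
                grad = grad * mask[name]    # step 2: only Child Network gradients survive
            param -= lr * grad              # SGD update on the masked gradient

In Child_Tuning_F the surviving entries of the mask would additionally be scaled by 1/p_F, as explained below.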

So now the question is: how do we identify the Child Network?

How to find Child Network?

In fact, we do not really need to find the Child Network itself; we only need to determine the matrix $M$. The paper provides two algorithms for generating $M$: a task-independent algorithm, Child_Tuning_F (F for task-free), and a task-driven algorithm, Child_Tuning_D (D for task-driven).

Child_Tuning_F

A task-independent algorithm means that no matter what specific task you are doing, you can use it; it is a general method. Specifically, in this case $M$ is generated from a Bernoulli distribution:


$$\begin{aligned} \mathbf{w}_{t+1}&=\mathbf{w}_{t}-\eta \frac{\partial \mathcal{L}\left(\mathbf{w}_{t}\right)}{\partial \mathbf{w}_{t}} \odot M_{t}\\ M_{t} &\sim \text{Bernoulli}(p_F) \end{aligned}\tag{3}$$

$p_F\in [0,1]$ is a hyperparameter that controls the size of the Child Network. If $p_F=1$, the Child Network is the whole original network, and Child Tuning degenerates to Fine Tuning; if $p_F=0$, no parameters are updated at all. Here is a simple simulation I wrote to help you understand:

import torch
from torch.distributions.bernoulli import Bernoulli

gradient = torch.randn((3, 4))  # the gradient is represented here by a randomly generated matrix
p_F = 0.2
gradient_mask = Bernoulli(gradient.new_full(size=gradient.size(), fill_value=p_F))
gradient_mask = gradient_mask.sample() / p_F  # dividing by p_F keeps the expectation of the gradient unchanged
print(gradient_mask)

gradient *= gradient_mask
print(gradient)

Bernoulli is a class, and the gradient_mask constructed above is a distribution object; we need to call its sample() method to actually obtain a matrix. The important point is that although we have a 0-1 mask, we need to scale every 1 in this mask up by $1/p_F$ to keep the expectation of the gradient unchanged.
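
To see why the $1/p_F$ rescaling is needed, here is a small sanity check (a minimal sketch; the gradient tensor and p_F are arbitrary illustrative values): averaging the masked-and-rescaled gradient over many sampled masks should recover something close to the original gradient.

import torch
from torch.distributions.bernoulli import Bernoulli

p_F = 0.2
gradient = torch.randn((3, 4))
dist = Bernoulli(gradient.new_full(gradient.size(), p_F))

# Monte Carlo estimate of E[gradient * mask / p_F]
acc = torch.zeros_like(gradient)
n = 10000
for _ in range(n):
    acc += gradient * dist.sample() / p_F
print((acc / n - gradient).abs().max())  # should be close to 0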

All the other gradients have perished; the surviving gradients must carry their will and keep propagating backwards!

Child_Tuning_D

Considering that downstream tasks differ, the authors propose the task-driven algorithm Child_Tuning_D, which detects the sub-network (i.e., the parameters) most important to the target task. Specifically, the authors use a Fisher information estimate to find the parameters most strongly related to the specific downstream task. Formally, the Fisher Information Matrix (FIM) of the model parameters $\mathbf{w}$ is defined as follows:


$$\mathbf{F}(\mathbf{w})=\mathbb{E}\left[\left(\frac{\partial \log p(y \mid \mathbf{x} ; \mathbf{w})}{\partial \mathbf{w}}\right)\left(\frac{\partial \log p(y \mid \mathbf{x} ; \mathbf{w})}{\partial \mathbf{w}}\right)^{\top}\right]\tag{4}$$

where $\mathbf{x}$ and $y$ are the input and output respectively. From this, the Fisher information of the $i$-th parameter can be estimated as follows:


$$\mathbf{F}^{(i)}(\mathbf{w})=\frac{1}{|D|} \sum_{j=1}^{|D|}\left(\frac{\partial \log p\left(y_{j} \mid \mathbf{x}_{j} ; \mathbf{w}\right)}{\partial \mathbf{w}^{(i)}}\right)^{2}\tag{5}$$

where $|D|$ is the number of samples. The authors argue that the more important a parameter is to the target task, the larger its Fisher information. Therefore, the Child Network consists of the parameters with the highest Fisher information, and the proportion of the Child Network is


$$p_D = \frac{|\mathcal{C}|}{|\mathcal{C}| + |\bar{\mathcal{C}}|} \in (0, 1]\tag{6}$$

where $\mathcal{C}$ denotes the Child Network and $\bar{\mathcal{C}}$ its complement; when $p_D = 1$, Child Tuning degenerates to Fine Tuning. Computing Fisher information is actually quite time consuming: if we recomputed the Fisher information of every parameter after each backpropagation and then searched for the largest values, it would be very cumbersome. So the authors propose that, before training actually starts, we first run one complete forward and backward pass over all samples (one epoch) and compute the parameters with the highest Fisher information at that point. The Child Network determined at this moment is fixed and never changes afterwards; the selection made this one time is final.

The code to calculate the Fisher information is given below:

import numpy as np
import torch

def calculate_fisher():
    # assumes model, train_dataset and batch_size are defined elsewhere
    gradient_mask, p_F = {}, 0.2
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle=True)
    N = len(train_loader)  # N = |D| (number of batches here)
    for name, params in model.named_parameters():
        if 'layer' in name:
            gradient_mask[params] = params.new_zeros(params.size())
    for batch in train_loader:
        outputs = model(**batch)
        loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
        loss.backward()

        for name, params in model.named_parameters():
            if 'layer' in name:
                torch.nn.utils.clip_grad_norm_(params, 1)
                gradient_mask[params] += (params.grad ** 2) / N
        model.zero_grad()

    r = None
    for k, v in gradient_mask.items():
        v = v.view(-1).cpu().numpy()  # flatten
        if r is None:
            r = v
        else:
            r = np.append(r, v)

    # np.percentile(a, q) returns the value below which q% of the elements of a fall
    polar = np.percentile(r, (1 - p_F) * 100)
    for k in gradient_mask:
        gradient_mask[k] = gradient_mask[k] >= polar
    print('Polar => {}'.format(polar))

    return gradient_mask
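
Once gradient_mask has been computed, one possible way to apply it during fine-tuning is to zero out the masked-out gradients right before each optimizer step. This is only a minimal sketch, not the paper's official implementation; it assumes the same model and train_loader as above and an already constructed optimizer.

gradient_mask = calculate_fisher()

for batch in train_loader:
    outputs = model(**batch)
    loss = outputs['loss'] if isinstance(outputs, dict) else outputs[0]
    loss.backward()

    # keep only the gradients of the Child Network before updating
    for name, params in model.named_parameters():
        if params in gradient_mask:
            params.grad *= gradient_mask[params]

    optimizer.step()
    optimizer.zero_grad()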

Proof

If the paper had only presented these ideas, it would most likely not have made it into EMNLP. I personally think its acceptance has a lot to do with the large number of proofs in the paper. The authors prove that Child Tuning can help the model escape from local minima.

First of all, assume that $\mathbf{g}^{(i)}$ is the gradient of the parameters $\mathbf{w}$ for a given sample $\mathbf{x}^{(i)}$, and that it follows the normal distribution $\mathbf{g}^{(i)} \sim N(\frac{\partial \mathcal{L}}{\partial \mathbf{w}}, \sigma^2_{\mathbf{g}}\mathbf{I}_k)$. Define $\mathbf{g} = \sum_{i=1}^{|\mathcal{B}|} \frac{\mathbf{g}^{(i)}}{|\mathcal{B}|}$; then we have


$$\Delta \mathbf{w} =-\eta \sum\limits_{i=1}^{|\mathcal{B}|}\frac{\mathbf{g}^{(i)}}{|\mathcal{B}|}\odot M = -\eta\, \mathbf{g}\odot M\tag{7}$$

For $\mathbf{g}$, we have


$$\mathbb{E}[\mathbf{g}]=\frac{\partial \mathcal{L}}{\partial \mathbf{w}}, \quad \Sigma[\mathbf{g}]=\frac{\sigma^2_{\mathbf{g}}\mathbf{I}_k}{|\mathcal{B}|}\tag{8}$$
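
The covariance in (8) follows from the usual variance-of-a-mean argument; a quick sketch, assuming the per-sample gradients $\mathbf{g}^{(i)}$ are independent:

$$\Sigma[\mathbf{g}] = \Sigma\left[\frac{1}{|\mathcal{B}|}\sum_{i=1}^{|\mathcal{B}|}\mathbf{g}^{(i)}\right] = \frac{1}{|\mathcal{B}|^2}\sum_{i=1}^{|\mathcal{B}|}\Sigma\left[\mathbf{g}^{(i)}\right] = \frac{|\mathcal{B}|\,\sigma^2_{\mathbf{g}}\mathbf{I}_k}{|\mathcal{B}|^2} = \frac{\sigma^2_{\mathbf{g}}\mathbf{I}_k}{|\mathcal{B}|}$$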

Let $\hat{\mathbf{g}} = \frac{\mathbf{g}}{p}\odot M$, where $p$ is $p_D$ or $p_F$ (depending on which algorithm you use); then


$$\begin{aligned} \mathbb{E}[\hat{\mathbf{g}}] &= \mathbb{E}\left[\frac{1}{p}{\mathbf{g}}\odot M\right]\\ &= \frac{1}{p}\mathbb{E}[\mathbf{g}\odot M]\\ &=\frac{p}{p}\mathbb{E}[\mathbf{g}]\\ &= \frac{\partial \mathcal{L}}{\partial \mathbf{w}} \end{aligned}\tag{9}$$

The $p$ in the numerator can only come from $\mathbb{E}[M]$. But $M$ is a matrix; how can the expectation of a matrix become a number? You can make sense of it loosely: if you add up all the ones in $M$ and divide by the total number of elements in $M$, the result is roughly $p$.
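
A cleaner way to read that step is element-wise rather than as the expectation of a whole matrix. Assuming each mask entry $M_i \sim \text{Bernoulli}(p)$ (the Child_Tuning_F case) and that $M_i$ is independent of $g_i$:

$$\mathbb{E}[\hat{g}_i] = \frac{1}{p}\mathbb{E}[g_i M_i] = \frac{1}{p}\mathbb{E}[g_i]\,\mathbb{E}[M_i] = \frac{p}{p}\mathbb{E}[g_i] = \mathbb{E}[g_i]$$

which is exactly the per-dimension statement used in the derivation below.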

Let $\hat{g}_i, g_i$ denote the $i$-th dimension values of $\hat{\mathbf{g}}, \mathbf{g}$; then $\hat{g}_i = \frac{g_i}{p}\odot M_i$.


$$\begin{aligned} \mathbf{D}[\hat{g_i}] &= \mathbb{E}[\hat{g_i}^2] - (\mathbb{E}[\hat{g_i}])^2\\ &=p\,\mathbb{E}\left[\left(\frac{g_i}{p}\right)^2\right] - (\mathbb{E}[\hat{g_i}])^2\\ &=\frac{\mathbb{E}[g_i^2]}{p} - (\mathbb{E}[\hat{g_i}])^2\\ &=\frac{(\mathbb{E}[g_i])^2 + \mathbf{D}[g_i]}{p} - (\mathbb{E}[\hat{g_i}])^2\\ &=\frac{(\mathbb{E}[g_i])^2 + \mathbf{D}[g_i]}{p} - \left(\mathbb{E}\left[\frac{g_i}{p}\odot M_i\right]\right)^2\\ &=\frac{(\mathbb{E}[g_i])^2 + \mathbf{D}[g_i]}{p} - (\mathbb{E}[{g_i}])^2\\ &=\frac{\mathbf{D}[g_i]}{p} + \frac{(1-p)(\mathbb{E}[\hat{g_i}])^2}{p} \end{aligned}\tag{10}$$

so


$$\begin{aligned} \Sigma[\hat{\mathbf{g}}] &= \frac{\Sigma[\mathbf{g}]}{p} + \frac{(1-p)\,\text{diag}\{\mathbb{E}[\mathbf{g}]\}^2}{p}\\ &=\frac{\sigma^2_{\mathbf{g}}\mathbf{I}_k}{p|\mathcal{B}|} + \frac{(1-p)\,\text{diag}\{\mathbb{E}[\mathbf{g}]\}^2}{p} \end{aligned}\tag{11}$$

And then we end up with


$$\begin{aligned} \mathbb{E}[\Delta \mathbf{w}] &=-\eta \frac{\partial \mathcal{L}}{\partial \mathbf{w}} \\ \Sigma[\Delta \mathbf{w}] &=\frac{\eta^{2} \sigma_{\mathbf{g}}^{2} \mathbf{I}_{k}}{p|\mathcal{B}|}+\frac{(1-p) \eta^{2} \operatorname{diag}\left\{\frac{\partial \mathcal{L}}{\partial \mathbf{w}}\right\}^{2}}{p} \end{aligned}\tag{12}$$

In particular, when $\frac{\partial \mathcal{L}}{\partial \mathbf{w}}=0$ (i.e., at a local minimum), we have $\mathbb{E}[\Delta \mathbf{w}] = 0$ and $\Sigma[\Delta \mathbf{w}] = \frac{\eta^{2} \sigma_{\mathbf{g}}^{2} \mathbf{I}_{k}}{p|\mathcal{B}|}$. Notice that $\Sigma[\Delta \mathbf{w}]$ is a decreasing function of $p$: the larger $p$ is, the smaller $\Sigma[\Delta \mathbf{w}]$ becomes. In the extreme case $p=1$, Child Tuning degenerates to Fine Tuning and $\Sigma[\Delta \mathbf{w}]$ is smallest, which means each update changes little, so it is hard to escape the local minimum. The smaller $p$ is, the larger $\Sigma[\Delta \mathbf{w}]$ is, the bigger the variation of each update, and the easier it is to jump out of the local minimum.

Personal summary

When I first read this paper I thought it was very impressive, but after digging into it I realized it is essentially a backpropagation-side version of Dropout; the actual innovation is not that large, and the Fisher information it relies on was not proposed by this paper either. That said, the paper contains many experiments, and the results show improvements of 1.5 to 8.6 points over Fine Tuning. Finally, regarding the proofs in the paper, I personally do not find them very rigorous, for example the question of why the expectation of a matrix becomes a number. Overall, this method can be used as a trick in competitions.