0x00 Abstract

GPipe is a library developed by the Google Brain team for training very large-scale neural network models with pipeline parallelism. This article describes its recomputation (checkpointing) capability and cross-references other implementations for comparison.

Other articles in this series are as follows:

Deep learning pipeline parallelism GPipe (1) – basic pipeline implementation

Deep learning pipeline parallelism GPipe (2) – gradient accumulation

0x01 Overview

1.1 Review

As mentioned previously, several parallelization techniques are currently indispensable for distributed model training:

  • Pipeline parallelism, in particular how to partition the pipeline automatically;
  • Gradient accumulation;
  • Backward recomputation;
  • The 1F1B strategy (we will analyze it when we cover PipeDream).

In the previous articles, we described how GPipe implements pipeline parallelism and gradient accumulation.

One problem with pipeline parallelism is that it uses too much GPU memory. Because the intermediate activations produced by each micro-batch's forward pass are consumed later by its backward pass, n complete sets of forward activations (n being the number of accumulated gradients) have to be cached in GPU memory. This calls for another important technique: checkpointing.

This article analyzes the PyTorch and GPipe source code in light of the paper "Training Deep Nets with Sublinear Memory Cost", with the goal of building a concrete understanding of gradient checkpointing.

1.2 Gradient checkpointing

In 2016, Tianqi Chen's team proposed gradient/activation checkpointing and related sublinear memory optimization techniques, aiming to reduce the GPU memory occupied by intermediate activations during deep learning training. Checkpointing is one sublinear memory optimization technique; another is CPU offload (used extensively in Microsoft's DeepSpeed framework).

Gradient checkpointing is a systematic method for reducing memory consumption while training deep neural networks; concretely, during backpropagation it re-runs the forward pass of every segment that was designated as a checkpointed segment:

  • The gradient checkpointing approach focuses on reducing the memory overhead of storing intermediate results (feature maps) and gradients, because in many common deep networks the intermediate results are much larger than the model parameters.
  • Gradient checkpointing trades time (compute) for space (GPU memory): it compresses the model's memory footprint by storing fewer activation values, but any activation that was not stored must be recomputed when its gradient is needed, which roughly doubles the forward computation time.
  • Concretely, a number of gradient checkpoints are selected, and intermediate results other than the checkpoints are released early. Later, during backpropagation, if a needed forward result is no longer in GPU memory, the computation restarts from the nearest gradient checkpoint to recover the released tensors.

0x02 Background

2.1 How differentiation works

This section borrows from the article "GPU memory optimization techniques during training — op merging and gradient checkpointing" (see the references).

A DNN model consists of layers of different types (e.g., convolutional layers, fully connected layers, pooling layers).

The key to backpropagation is automatic differentiation via the chain rule, on top of which BP adds a bit of dynamic programming. In general, BP consists of the following two steps:

  • Forward pass. Taking image classification as an example, the model first makes predictions on a small batch of training samples (a minibatch). This process is called the forward pass.

    • To make predictions, the minibatch of input data is fed into the first layer of the model.
    • Each layer then computes a function of its input to produce the output for the next layer. The forward pass records two values: the output of each intermediate node, and the gradient of that output with respect to its input.
    • The output of the final layer is the class prediction. Based on the model's predicted label and the actual label of each image, the output layer computes the loss (error).
  • Backward gradient computation. Backpropagation computes the gradient of the network's final output with respect to the output of each layer. That is, starting from the output, the gradient is propagated backward; the gradient of the output with respect to each intermediate variable is computed and saved. Each layer computes the error of the layer before it and updates the weights of all related layers (using the loss gradients), which moves the model's predictions toward the desired output.

The outputs of intermediate nodes are needed while the gradient is propagated backward, but BP does not recompute them during the backward pass, because these intermediate variables (the output of every intermediate node) were stored during the forward pass. BP keeps propagating the gradient backward and saving the intermediate gradients until the gradients of all intermediate values and inputs of the computation graph have been obtained.

Let’s see how back propagation works.

A so-called automatic differentiation framework is actually only "semi-automatic": it does not directly derive the analytic form of a complex function's derivative. Instead, it builds a computation graph and combines pre-defined derivative rules for basic functions with the chain rule to differentiate automatically.

Let’s take a function as an example, whose expression is as follows:

f(x) = x * (x + 1)

Through simple mathematical derivation, the analytic form of its gradient is f'(x) = x + 1 + x. Setting that aside for a moment, let's see how the automatic differentiation framework works out the result step by step; the computation graph is drawn as follows:

              +---------+
      +------>|  x + 1  +------+
      |   2   +---------+   3  |
      |                        v
   +--+--+          1      +---+---+
   |  x  +---------------->|   *   +------> f(x)
   +-----+                 +-------+

In this computation graph, backpropagation first passes through the multiplication node. According to the derivative rule for multiplication:

  • The gradient along path 1 is x + 1;
  • The gradient along path 3 is x;
  • Path 3 then continues backpropagating through path 2: the gradient it carries is x, which is multiplied by the local gradient of path 2 (the +1 operation), namely 1;
  • Path 2 and path 1 converge at x, so the final gradient is the (x + 1) from path 1 plus the 1 · x from path 2, i.e., x + 1 + x, which is exactly the result we derived mathematically.

It is precisely by relying on these basic derivative rules together with the chain rule that an automatic differentiation framework operates efficiently and accurately. The whole procedure can be verified in a few lines of PyTorch, as shown below.
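
A quick sanity check in PyTorch (a minimal sketch) reproduces the result numerically:

import torch

x = torch.tensor(3.0, requires_grad=True)
y = x * (x + 1)    # forward pass: builds the graph x -> (x + 1) -> multiply
y.backward()       # backward pass: applies the chain rule along both paths

print(x.grad)      # tensor(7.) == x + 1 + x evaluated at x = 3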

During the training of most neural networks, some of the intermediate variables produced in the forward pass are very useful in the backward pass (they make differentiation convenient). In practice, it is best to cache these intermediate variables in code so they can be reused during backpropagation. So the bulk of GPU memory is occupied by these intermediate results, known as "feature maps". For this article, x is the intermediate result (feature map) output by the previous layer.

When we apply the derivative rule for multiplication, we have to keep both intermediate values x and x + 1. Note that the framework defines multiplication and its derivative rule as a general rule; the two operands of a multiplication may well be two unrelated values, so both must be kept. In other words, the x + 1 here could, in some other function, be x + y + z..., possibly involving other input variables, so it cannot be recovered from a single input x by a simple formula like "+ 1".

Without considering the framework's own optimizations, the GPU memory footprint therefore consists of one x and one x + 1. Note that x is not a single scalar, but a feature map with a shape like 32x32x128.

2.2 Gradient checkpointing

As described in the previous section, in the vanilla mode of neural network training:

  • In the forward pass, the activation value of every layer must be saved after it is computed, because it will be consumed by the backward computation.
  • In the backward pass, the gradient is computed from the loss value and the corresponding layer's activation value.
  • Therefore, we need to cache n copies of the full forward activations in GPU memory (n being the number of accumulated gradients). In other words, GPU memory usage is proportional to the number of layers.

Hence the problem with pipeline parallelism as described so far: the GPU memory footprint is too large.

Is it possible not to store the activation values at all? For example, in the backward pass, could we simply re-run the forward computation whenever an activation value is needed?

What if we store none of them and recompute everything forward? For a large model that takes far too much time. So we choose a compromise, for example storing the activation values of only some layers; when the backward pass needs an activation value that was not stored, it recomputes it from the nearest stored one. This is exactly the important technique introduced here: checkpointing. A toy sketch of this compromise follows.
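
The sketch below is illustrative only (it is neither GPipe's code nor PyTorch's checkpoint implementation, which is analyzed later): it keeps one activation every k layers during the forward pass and recomputes any missing activation from the nearest stored one.

import torch
import torch.nn as nn

layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(8)])
k = 3  # keep one activation every k layers (the "checkpoints")

def forward_keep_some(x):
    saved = {0: x}                      # index -> activation after that many layers
    for i, layer in enumerate(layers):
        x = layer(x)
        if (i + 1) % k == 0:
            saved[i + 1] = x            # checkpointed activation
    return x, saved

def recover_activation(i, saved):
    """Recompute the activation after the first i layers from the nearest checkpoint."""
    start = max(j for j in saved if j <= i)
    x = saved[start]
    for j in range(start, i):           # re-run the forward segment
        x = layers[j](x)
    return x

x = torch.randn(4, 16)
out, saved = forward_keep_some(x)
act5 = recover_activation(5, saved)     # recomputed from the checkpoint after layer 3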

2.3 Paper content

2.3.1 Main papers

GPipe's checkpointing mainly comes from the following two papers:

  • Andreas Griewank and Andrea Walther. Algorithm 799: revolve: an implementation of checkpointing for the reverse or adjoint mode of computational differentiation. ACM Transactions on Mathematical Software (TOMS), 26(1): 19-45, 2000.
  • Tianqi Chen, Bing Xu, Chiyuan Zhang, and Carlos Guestrin. Training deep nets with sublinear memory cost. arXiv preprint arXiv:1604.06174, 2016.

The main ideas are to trade compute for memory (recomputing intermediate results from the checkpoints during reverse differentiation) and to trade bandwidth for memory.

2.3.2 Training Deep Nets with Sublinear Memory Cost

2.3.2.1 Main Ideas

Let’s focus on this paper.

Checkpointing, also known as sublinear memory optimization, was proposed by Tianqi Chen in the 2016 paper Training Deep Nets with Sublinear Memory Cost. Sublinear memory optimization comes in two flavors, checkpointing and CPU offload:

  • Checkpointing: the core idea is to mark a small number of tensors in the network (the checkpointed tensors). The forward pass keeps only these marked tensors; all other forward activations are recovered during backpropagation by temporarily recomputing forward from the checkpointed tensors. As a result, a large fraction of activations no longer has to be kept alive until the backward computation reaches them, which shortens the lifetime of many tensors and greatly improves memory reuse.
  • CPU offload: the idea is similar to the "virtual memory" technique in operating systems (temporarily swapping unused memory out to disk to enlarge the total available memory). In deep learning, GPU device memory is expensive, fast, and small, while CPU host memory is cheap, relatively slow, and large. GPU memory can be saved by temporarily swapping activations produced in the forward pass out to CPU main memory and swapping them back into GPU memory when the backward computation needs them.

The two flavors of sublinear memory optimization save GPU memory in different ways: checkpointing trades extra computation for memory, while CPU offload trades extra transfer overhead for memory. A minimal sketch of the offload idea follows.
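
The snippet below is a minimal sketch of the CPU-offload idea (illustrative only; real frameworks such as DeepSpeed do this asynchronously and in bulk). A custom autograd Function moves the activation it saves to host memory in the forward pass and brings it back to the GPU in the backward pass:

import torch

class SquareOffloaded(torch.autograd.Function):
    """y = x * x, but the saved activation x lives on the CPU between the two passes."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x.to("cpu"))   # swap the activation out to host memory
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x_cpu,) = ctx.saved_tensors
        x = x_cpu.to(grad_output.device)     # swap it back in for the gradient computation
        return 2 * x * grad_output

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device, requires_grad=True)
SquareOffloaded.apply(x).sum().backward()    # x is copied back to the GPU only here

Newer PyTorch versions also expose this pattern as a saved-tensor hook (torch.autograd.graph.save_on_cpu), if it is available in your version.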

2.3.2.2 Checkpointing optimization

The figure in the paper compares the computation graphs before and after checkpointing.

On the left, in gray, is the network configuration.

In the middle, the normal gradient graph is the forward and backward propagation process of an ordinary network.

On the right, the memory-optimized gradient graph is the result of applying gradient checkpointing: to further reduce memory, some intermediate results are dropped and recovered through extra forward computation when needed.

  • First, the neural network is divided into several segments (three in the figure on the right); the algorithm remembers only the output of each segment and drops all intermediate results inside each segment.
  • Second, during the backpropagation phase, the dropped intermediate results are recomputed by running forward from the nearest recorded result.
  • Consequently, we only pay the memory cost of storing each segment's output, plus the maximum memory cost of backpropagating through a single segment.

So gradient checkpointing does not mean the intermediate results are unnecessary; it means there is a way to recompute the dropped intermediate results on the fly during differentiation.

Recomputation was not designed solely for pipeline parallelism; it was previously used mostly in single-GPU or data-parallel scenarios. But it is critical for pipelining, because instead of caching all forward activations, only a very small number of specific tensors at the checkpoints are cached (for example, only one per Transformer layer), which greatly reduces the GPU memory overhead under pipeline parallelism.

0x03 OpenAI

The gradient-checkpointing package released by OpenAI is an implementation of the ideas in Training Deep Nets with Sublinear Memory Cost. Since its documentation is relatively complete (github.com/openai/grad…), we can learn from it.

The general idea is as follows: set several checkpoints across the neural network, roughly one every sqrt(n) layers, and keep the intermediate feature maps only at those checkpoints. All other intermediate results are discarded; during backpropagation, when a dropped intermediate result is needed, it is recomputed starting from the nearest checkpoint. This both saves GPU memory and avoids the cost of recomputing everything from scratch. The same sqrt(n) idea is sketched below in PyTorch.
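
OpenAI's implementation targets TensorFlow, but a rough PyTorch equivalent of the every-sqrt(n) idea (a sketch only, not OpenAI's code) is torch.utils.checkpoint_sequential, which splits a Sequential model into segments and stores only the segment boundaries:

import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint_sequential

n = 16
model = nn.Sequential(*[nn.Sequential(nn.Linear(64, 64), nn.ReLU()) for _ in range(n)])

x = torch.randn(8, 64, requires_grad=True)
segments = int(math.sqrt(n))                      # roughly sqrt(n) checkpoints

out = checkpoint_sequential(model, segments, x)   # only segment boundaries are kept
out.sum().backward()                              # dropped activations are recomputed here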

3.1 Computation graph

For a simple n-layer feedforward neural network, the computation graph for obtaining the gradients is as follows:

Details are as follows:

  • The layer activations of the neural network correspond to the nodes labeled f; all of these nodes must be computed in order during the forward pass.
  • The gradients of the loss with respect to the activations and the layer parameters are denoted by the nodes labeled b; all of these nodes must be computed in reverse order during the backward pass.
  • The activation of an f node is a prerequisite for computing the gradient of the corresponding b node, so f nodes are kept in memory after the forward pass.
  • These activations can only be cleared from memory once backpropagation has progressed far enough that computing the remaining gradients no longer requires those activations or the children of the corresponding f nodes. This means that plain backpropagation requires memory that grows linearly with the number of layers of the neural network.

3.2 Recomputation

Plain backpropagation is computationally optimal, since every node is evaluated only once. However, if we are willing to recompute nodes, we can save a lot of memory: whenever a node's activation is needed, we simply recompute it by running the forward pass again, in order, until we reach the node whose activation backpropagation requires.

With this strategy, the memory needed to compute gradients stays constant in the number of layers n, which is optimal in terms of memory. Note, however, that the number of node evaluations now scales as n², compared with n before: each of the n nodes is recomputed on the order of n times. As a result the computation becomes very slow for deep networks, making this method unsuitable for deep learning as-is.
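
A quick back-of-the-envelope count (illustrative only) makes the quadratic blow-up concrete: with no stored activations, recovering the input of layer i means re-running layers 1..i.

n = 100                                            # number of layers
evals_recompute_everything = sum(range(1, n + 1))  # 5050, about n*n/2 layer evaluations
evals_plain_backprop = n                           # 100: every layer is evaluated exactly once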

3.3 Strategy

To strike a balance between memory and computation, we need a strategy that allows nodes to be recomputed, but not too often. The strategy used here is to mark a subset of the neural network's activations as checkpoint nodes. In the illustration, the purple nodes are the ones that must be kept in memory at a given time.

These checkpoint nodes stay in memory after the forward pass, while the remaining nodes are recomputed at most once. After being recomputed, non-checkpoint nodes stay in memory only until they are no longer needed by backpropagation. For a simple feedforward neural network, the checkpointed activation nodes act as articulation points (graph separators) of the forward graph. This means that during backpropagation we only need to recompute the nodes between a b node and the last checkpoint before it, and once backpropagation reaches a saved checkpoint node, all the nodes recomputed from it can be dropped from memory.

3.4 Process

First, there are two checkpoints: the two purple nodes toward the left of the first row of the graph. Note that the first purple node from the right is the input.

Second, the forward pass has completed and backpropagation begins, starting from purple node 1 on the bottom row.

Third, we reach purple node 2 on the bottom row, which depends on purple node 3 above it (recall that backpropagation needs the outputs of the forward computation). Node 3 is a checkpoint and is already in memory, so backpropagation proceeds normally.

Fourth, we reach white node 4 on the bottom row, which depends on node 5 above it. Node 5 is not a checkpoint and is not in memory, so we have to restart the computation from the checkpoint before it, purple node 7. A new checkpoint value is computed, and the original node 5 in the top row is removed because it is no longer needed.

Fifth, the new purple node 4 below is computed, and backpropagation continues backward from there.

Because it involves automatic checkpoint selection, the OpenAI code is fairly obscure, so it is not analyzed here. Interested readers can study it on their own.

0x04 PyTorch implementation

Next, let's look at how PyTorch implements checkpointing.

4.1 Basic Knowledge

4.1.1 Variable and Function

In PyTorch, autograd is at the heart of all neural networks: it provides automatic differentiation for all operations on tensors. It is a define-by-run framework, meaning that backprop is defined by how the code is run.

autograd.Variable is the core class of autograd. It wraps a Tensor and supports almost all operations defined on it. Once the computation is finished, .backward() can be called to compute all gradients automatically.

Another class that is essential to the autograd implementation is Function. A Function represents an operation on Variables, such as addition, subtraction, multiplication, division, relu, pooling, and so on. But it is more than simple arithmetic: unlike ordinary Python or NumPy operations, a Function participates in the computation graph and in computing backpropagated gradients. It therefore not only performs the operation (the forward pass), but also uses a cache to keep the forward inputs (needed for computing gradients) and supports the backward pass that computes them.

PyTorch builds computation graphs out of Variables and Functions. To recap, a Variable is like a node in the graph, holding computed results (the activations of the forward pass and the gradients of the backward pass), while a Function is like an edge in the graph, implementing the computation on Variables and producing new Variables.

In short, Function and Variable together form PyTorch's automatic differentiation mechanism; they define the computational relationships among variables.

Note: in recent PyTorch code, Function has been renamed to the Node class, presumably to better express the notion of a node in the computation graph.

4.1.2 Function

We can use the autograd.Function class to customize a model, a layer, an activation function, or a loss function; these are all essentially Functions, differing only in how simple or complex the computation is. A small example is sketched below.
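
For instance, a custom ReLU written as an autograd.Function (a minimal sketch) shows the pattern: forward saves whatever backward will need via save_for_backward, and backward returns one gradient per input:

import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)            # the backward pass needs the input
        return x.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * (x > 0).type_as(grad_output)

x = torch.randn(5, requires_grad=True)
MyReLU.apply(x).sum().backward()            # x.grad is 1 where x > 0, else 0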

4.2 Normal mode

The code lives in torch/utils/checkpoint.py. PyTorch requires the user to specify the checkpoints explicitly, so the implementation is relatively simple.

4.2.1 Encapsulation

torch/utils/checkpoint.py contains a checkpoint wrapper function whose docstring is well worth reading; let's study it.

  • Checkpointing is essentially trading computation for memory.

  • Checkpointing does not store all the intermediate activations of the whole computation graph that the backward computation needs; instead, it recomputes them during backpropagation.

  • In the forward pass, the checkpointed function runs in torch.no_grad mode, so the intermediate activations are not stored. Instead, the forward pass saves the input tuple and the function parameter.

  • In the backward pass, the saved inputs and function are retrieved, and the function is evaluated again, this time tracking the intermediate activations, which are then used to compute the gradients.

def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model

    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.

    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.

    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and
            restoring the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)
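
In user code, the wrapper is typically applied to a sub-module of the model; a minimal usage sketch (the module names are made up for illustration):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
head = nn.Linear(64, 10)

x = torch.randn(8, 64, requires_grad=True)
h = checkpoint(block, x)       # block's intermediate activations are not stored
loss = head(h).sum()
loss.backward()                # block runs forward again here, then its gradients flow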

4.2.2 Handling devices

Since PyTorch has no way of knowing whether the forward function will move some arguments to a different device, some logic is needed to save the RNG state of those devices. Saving and restoring the RNG state of every visible device would work, but it is wasteful in most cases; as a compromise, PyTorch only saves the RNG state of the devices on which the tensor arguments live.

def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))
​
    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())
​
    return fwd_gpu_devices, fwd_gpu_states
​
​
def set_device_states(devices, states) -> None:
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
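
The reason the RNG state is stashed at all is stochastic layers such as dropout: the recomputed forward pass must draw exactly the same random numbers as the original one, otherwise the recomputed activations (and hence the gradients) would not match. The requirement can be sketched in a few lines:

import torch

state = torch.get_rng_state()      # stash the CPU RNG state (as in the forward pass)
mask1 = torch.rand(3)              # e.g. a dropout mask drawn during forward

torch.set_rng_state(state)         # restore it before recomputation (as in backward)
mask2 = torch.rand(3)              # the recomputed mask is identical

assert torch.equal(mask1, mask2)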

4.2.3 Core Logic

CheckpointFunction inherits from torch.autograd.Function.

We can extend Function to meet our own needs; the extension must define the Function's forward computation and the corresponding backward computation, and the forward pass must save whatever the backward pass will need.

  • The forward function takes the input tensors and computes the output tensors.

    • In the forward pass, the checkpointed function runs in torch.no_grad mode, so the intermediate activations are not stored.
    • The forward pass saves the input tuple and the function parameter.
    • For CheckpointFunction, some additional information (the RNG state) also has to be stored in the forward pass, for use during backpropagation.
    • The forward pass returns the activation values.
  • The backward function receives the gradient of the output tensors with respect to some scalar value and computes the gradient of the input tensors with respect to that same scalar value.

    • In the backward pass, the saved inputs and function are retrieved.
    • function is computed again, this time tracking the intermediate activations, which are then used to compute the gradients.
"" We can implement our custom autograd function by subclassing torch. Autograd and propagating the tensor forward and back. """ class CheckpointFunction(torch.autograd.Function): @staticmethod def forward(CTX, run_function, preserve_rng_state, *args): """ In a forward function, take the Tensor with the input and return the Tensor with the output. CTX is the environment variable used to provide the information needed for backpropagation. We can use context objects to cache objects for use in back propagation. Save_for_backward can be used to cache data. Save_for_backward can only be passed variables of Variable or Tensor. Check_backward_validity (args) # Save the forward propagation function ctx.run_function = run_function ctx.preserve_rng_state = preserve_rng_state  ctx.had_autocast_in_fwd = torch.is_autocast_enabled() if preserve_rng_state: ctx.fwd_cpu_state = torch.get_rng_state() # Don't eagerly initialize the cuda context by accident. # (If the user intends that the context is initialized later, within their # run_function, we SHOULD actually stash the cuda state here. Unfortunately, # we have no way to anticipate this will happen before we run the function.  if torch.cuda._initialized: ctx.had_cuda_in_fwd = True ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args) # Save non-tensor inputs in ctx, keep a placeholder None for tensors # to be filled out during the backward. ctx.inputs = [] ctx.tensor_indices = [] Tensor_inputs = [] for I, arg in enumerate(args): if torch. Is_tensor (ARg): tensor_inputs.append(arg) ctx.tensor_indices.append(i) ctx.inputs.append(None) else: Ctx.inputs. Append (arg) # 'saved_for_BACKWARD' is to retain all information about this input and avoid the case where the input is modified on a backward as a result of an in-place operation. It saves the input parameters of the function for later use in the derivation and plays a coordinating role in the forward and back propagation. ctx.save_for_backward(*tensor_inputs) with torch.no_grad(): Outputs = run_function(*args) # in back propagation, we receive the context object and a tensor containing a gradient relative to the loss of the output generated during forward propagation. We can retrieve cached data from the context object, and must calculate and return the gradient of the loss associated with the forward-propagated input. # staticmethod def backward(CTX, *args) # staticmethod def backward(CTX, *args) if not torch.autograd._is_checkpoint_valid(): raise RuntimeError( "Checkpointing is not compatible with .grad() or when an `inputs` parameter" " is passed to .backward(). Please use .backward() and do not pass its `inputs`" " argument.") # Copy the list to avoid modifying original list. inputs = list(ctx.inputs) tensor_indices = ctx.tensor_indices tensors = ctx.saved_tensors # Saved_variables # Fill in inputs with appropriate saved tensors. For I, idx in enumerate(tensor_indices): • Inputs [IDX] = tensors[I] # Stash the surrounding RNG state and mimic the state that was # present at this time during forward. Restore the surrounding state # when we're done. # Rng_devices = [] if ctx.preSERVE_rNG_state and ctx.had_CUDA_in_FWD: if ctx.preserVE_Rng_state and ctx.had_CUDA_in_FWD: rng_devices = ctx.fwd_gpu_devices with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state): If ctx.preserVE_rng_state: torch. 
Set_rng_state (ctx.fwd_CPU_state) # Restore the device state when propagating forward if ctx.had_CUDA_in_FWD: set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states) detached_inputs = detach_variable(tuple(inputs)) with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd): Outputs = ctx.run_function(*detached_inputs) if isinstance(outputs, torch.Tensor): Outputs = (outputs,) # run backward() with only tensor that requires grad outputs_with_grad = [] # activation args_with_grad = [] # gradient # Filter the tensor to propagate from the result of forward propagation calculation for I in range(len(outputs)): if torch.is_tensor(outputs[i]) and outputs[i].requires_grad: outputs_with_grad.append(outputs[i]) args_with_grad.append(args[i]) if len(outputs_with_grad) == 0: Raise RuntimeError(" None of output has requires_grad=True," "This checkpoint() is not necessary") # torch.autograd.backward(outputs_with_grad, args_with_grad) grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None for inp in detached_inputs) return (None, None) + gradsCopy the code

4.3 Pipeline mode

Next, let's look at how checkpointing is done in pipeline mode.

PyTorch's pipeline-parallel implementation is inspired by GPipe, as mentioned in its comments.

With CheckpointFunction, PyTorch can combine recomputation and recursive backpropagation into a single autograd function, so that recomputation starts when the gradient arrives. However, in pipeline mode, in order to reduce GPU idle time, the recomputation needs to happen before the gradient arrives (recomputation is in fact independent of the gradient, so it can be performed ahead of time to obtain the activations; once the backpropagated gradient arrives, it is combined with those activations to compute the gradients).

To solve this problem, PyTorch introduces two autograd functions, class Recompute and class Checkpoint, which split the normal-mode CheckpointFunction into two phases, so that the autograd engine and CUDA can be controlled through these two functions. A CUDA synchronization is inserted between Recompute and Checkpoint, delaying Checkpoint until the gradient copy has completed.

By splitting in this way, multiple pipeline stages can run in parallel.

4.3.1 Example

The code is in test/distributed/pipeline/sync/test_checkpoint.py.

Through clever log printing, we can see how checkpoint is used in the forward and backward passes at run time.

The final timeline is ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"].

The three groups correspond to the forward pass, Checkpoint(Log[b]), and Checkpoint(Log[a]) respectively.

@pytest.mark.parametrize("device", devices)
def test_serial_checkpoints(device):
    # Copied from https://github.com/pytorch/pytorch/pull/18568.
    timeline = []
​
    class Log(torch.autograd.Function):
        @staticmethod
        def forward(ctx, name, x):
            ctx.name = name
            timeline.append(f"{name}:forward")
            return x.detach()
​
        @staticmethod
        def backward(ctx, grad_output):
            name = ctx.name
            timeline.append(f"{name}:backward")
            return None, grad_output
​
    a = torch.rand(1, device=device, requires_grad=True)
    b = torch.rand(1, device=device, requires_grad=True)
​
    # Increase the next function sequence number.
    _ = a + 1 + 2 + 3 + 4 + 5
​
    # This means that the final backward() will actually run "a:forward", "a:backward"
    a = checkpoint(partial(Log.apply, "a"), a)
​
    a, phony = fork(a)
    b = join(b, phony)
​
    # This means that the final backward() will actually run "b:forward", "b:backward"
    b = checkpoint(partial(Log.apply, "b"), b)
​
    c = torch.cat((a, b))
​
    out = c.sum()
​
    #                        +--> {a} --Checkpoint(Log)--> {a}
    # {out} --Sum--> {c} --Cat     ^-----------------------------+
    #                        +--> {b} --Checkpoint(Log)--> {b} --First--> {b}
    out.backward()
​
    assert timeline == ["a:forward", "b:forward", "b:forward", "b:backward", "a:forward", "a:backward"]
    #    |----------------------|  |-----------------------|  |-----------------------|
    #          forward pass            Checkpoint(Log[b])         Checkpoint(Log[a])

4.3.2 Sharing variables

Class Checkpoint and class Recompute share variables with each other through a Context object.

# Types for shared memory between Checkpoint and Recompute.
Recomputed = Tuple[TensorOrTensors, Tensors]  # (output, input_leaf)
RNGStates = Tuple[Tensor, Optional[Tensor]]   # (cpu_rng_state, gpu_rng_state)


class Context:
    """The common interface between the :class:`Checkpoint` and
    :class:`Recompute` context.
    """

    recomputed: Deque[Recomputed]
    rng_states: Deque[RNGStates]
    function: Function
    input_atomic: bool

    saved_tensors: Tuple[Tensor, ...]

    def save_for_backward(self, *tensors: Tensor) -> None:  # pragma: no cover
        pass

4.3.3 RNG state

Different RNG states at run time can produce different results (for example with dropout), so the RNG state of the current device is stashed during each checkpoint and restored before the recomputation.

The save_rng_states and restore_rng_states methods stash and restore the RNG state, respectively.

def save_rng_states(
    device: torch.device,
    rng_states: Deque[RNGStates],
) -> None:
    """:meth:`Checkpoint.forward` captures the current PyTorch's random number
    generator states at CPU and GPU to reuse in :meth:`Recompute.backward`.

    .. seealso:: :ref:`Referential Transparency`
    """
    cpu_rng_state = torch.get_rng_state()

    gpu_rng_state: Optional[Tensor]
    if device.type == "cuda":
        gpu_rng_state = torch.cuda.get_rng_state(device)
    else:
        gpu_rng_state = None

    rng_states.append((cpu_rng_state, gpu_rng_state))


@contextmanager
def restore_rng_states(
    device: torch.device,
    rng_states: Deque[RNGStates],
) -> Generator[None, None, None]:
    """:meth:`Recompute.backward` restores the random number generator states
    captured by :func:`save_rng_states` within its context.

    .. seealso:: :ref:`Referential Transparency`
    """
    cpu_rng_state, gpu_rng_state = rng_states.pop()

    gpu_devices: List[torch.device] = []
    if device.type == "cuda":
        gpu_devices.append(device)

    with torch.random.fork_rng(gpu_devices):
        torch.set_rng_state(cpu_rng_state)
        if gpu_rng_state is not None:
            torch.cuda.set_rng_state(gpu_rng_state, device)
        yield

4.3.4 Checkpoint

Checkpoint, together with the Recompute class below, splits the normal-mode checkpointing code into two phases (the forward and backward computations are each split into two parts), so that the pipeline can be used more effectively.

class Checkpoint(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> TensorOrTensors:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        # Stash the RNG state of the input's device
        save_rng_states(input[0].device, ctx.rng_states)

        ctx.function = function
        ctx.input_atomic = input_atomic
        # Save the inputs for the backward pass
        ctx.save_for_backward(*input)

        # Run the forward computation without building the autograd graph
        with torch.no_grad(), enable_checkpointing():
            output = function(input[0] if input_atomic else input)
        return output

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor,) -> Tuple[Optional[Tensor], ...]:
        # Pop the activations that Recompute stashed earlier
        output, input_leaf = ctx.recomputed.pop()

        if isinstance(output, tuple):
            tensors = output
        else:
            tensors = (output,)
        if any(y.requires_grad for y in tensors):
            tensors = tuple([x for x in tensors if x.requires_grad])
            # Run the backward pass on the recomputed sub-graph
            torch.autograd.backward(tensors, grad_output)

        grad_input: List[Optional[Tensor]] = [None, None, None, None, None]
        grad_input.extend(x.grad for x in input_leaf)
        return tuple(grad_input)

4.3.5 Recompute

Recompute recomputes the intermediate variables from the stashed information.

class Recompute(torch.autograd.Function):
    @staticmethod
    # type: ignore[override]
    def forward(
        ctx: Context,
        phony: Tensor,
        recomputed: Deque[Recomputed],
        rng_states: Deque[RNGStates],
        function: Function,
        input_atomic: bool,
        *input: Tensor,
    ) -> Tensor:
        ctx.recomputed = recomputed
        ctx.rng_states = rng_states

        ctx.function = function
        ctx.input_atomic = input_atomic
        ctx.save_for_backward(*input)

        return phony

    @staticmethod
    def backward(ctx: Context, *grad_output: Tensor) -> Tuple[None, ...]:
        input = ctx.saved_tensors
        input_leaf = tuple(x.detach().requires_grad_(x.requires_grad) for x in input)

        # Restore the RNG state stashed during the forward pass, then re-run the
        # forward computation, this time building the autograd graph
        with restore_rng_states(input[0].device, ctx.rng_states):
            with torch.enable_grad(), enable_recomputing():
                output = ctx.function(input_leaf[0] if ctx.input_atomic else input_leaf)

        # Hand the recomputed activations over to Checkpoint.backward via the shared deque
        ctx.recomputed.append((output, input_leaf))

        grad_input: List[None] = [None, None, None, None, None]
        grad_input.extend(None for _ in ctx.saved_tensors)
        return tuple(grad_input)

4.3.6 Pipeline

4.3.6.1 Task

Let's first look at the Task class. The code is in torch/distributed/pipeline/sync/worker.py.

As the comments explain, a Task describes how to compute a micro-batch on a partition.

compute can be executed in parallel by worker threads.

finalize should be executed after compute has finished.

class Task: """A task represents how to compute a micro-batch on a partition. It consists of two parts: :meth:`compute` and :meth:`finalize`. :meth:`compute` should be executed in worker threads concurrently. :meth:`finalize` should be executed after when worker threads complete to execute :meth:`compute`. :meth:`compute` might  be boosted by worker threads. Because it produces several CUDA API calls by user code. In PyTorch, parallel CUDA API calls are not serialized through GIL. So more than one CUDA API call can be produced at the same time.  """ def __init__( self, stream: AbstractStream, *, compute: Callable[[], Batch], finalize: Optional[Callable[[Batch], None]], ) -> None: self.stream = stream self._compute = compute self._finalize = finalize self._grad_enabled = torch.is_grad_enabled() def compute(self) -> Batch: with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): return self._compute() def finalize(self, batch: Batch) -> None: if self._finalize is None: return with use_stream(self.stream), torch.set_grad_enabled(self._grad_enabled): self._finalize(batch)Copy the code
4.3.6.2 compute

Next, the compute function of the Pipeline class.

The logic of Pipeline is exactly as described in its comments (PyTorch's comments are genuinely informative). The key statement is: task = Task(streams[j], compute=chk.checkpoint, finalize=chk.recompute).

As you can see, the recompute method is installed as the Task's finalize method, and the recomputation is scheduled later through it.

class Pipeline: """The pipeline parallelism for Pipe.""" def compute( self, batches: List[Batch], schedule: List[Tuple[int, int]], skip_trackers: List[SkipTrackerThroughPotals], ) -> None: """Runs tasks with synchronization to copy streams.""" partitions = self.partitions devices = self.devices copy_streams = self.copy_streams checkpoint_stop = self.checkpoint_stop # Disable checkpointing if in eval mode. if not self.partitions[0].training: checkpoint_stop = 0 n = len(partitions) streams = [current_stream(d) for d in devices] exc_info: Optional[ExcInfo] = None # With checkpointing, the autograd graph looks like this diagram: # ┌ ─ ─ ─ ─ ─ ┸ ─ ─ ─ ─ ─ ─ ┐ # │ Copy │ # └ ─ ─ ─ ─ ─ ┰ ─ ─ ─ ─ ─ ─ ┘ (fence) # ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ # ┃ (compute) # ┌ ─ ─ ─ ─ ─ ┸ ─ ─ ─ ─ ─ ─ ┐ # │ Wait │ [1] the Synchronize the current stream with the copy stream. # └ ─ ─ ─ ─ ─ ┰ ─ ─ ─ ─ ─ ─ ┘ # ┌ ─ ─ ─ ─ ─ ┸ ─ ─ ─ ─ ─ ─ ┐ # │ Checkpoint │ [2] Compute a partition within checkpointing. # └ ─ ─ ─ ─ ─ ┰ ─ ─ ─ ─ ─ ─ ┘ # ┌ ─ ─ ─ ─ ─ ┸ ─ ─ ─ ─ ─ ─ ┐ # │ Wait │ [3] the Synchronize the copy stream With the current stream. # └ ─ ─ ─ ─ ─ ┰ ─ ─ ─ ─ ─ ─ ┘ # ┠ ─ ─ ─ ┐ # ┃ ┌ ─ ─ ─ ─ ─ ┴ ─ ─ ─ ─ ─ ┐ # ┃ │ Recompute │ [4] the Schedule the recomputation The at backpropagation. # ┃ └ ─ ─ ─ ─ ─ ┬ ─ ─ ─ ─ ─ ┘ # ┠ ─ ─ ─ ┘ # # ┃ ─ ─ ─ ╂ ─ ─ ─ ─ ─ ─ ─ ─ ─ # ┌ ─ ─ ─ ─ ─ ┸ ─ ─ ─ ─ ─ ─ ┐ (fence) # │ Copy │ # ├ ──┰─── our r company for I, J in schedule: batch = batches[i] partition = partitions[j] # Synchronize with the copied input. ([1] in the diagram) if j ! = 0: _wait(batch, copy_streams[j][i], streams[j]) # Determine whether checkpointing or not. checkpoint = i < checkpoint_stop if checkpoint: def function( input: TensorOrTensors, partition: nn.Sequential = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j, ) -> TensorOrTensors: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): CHK = Checkpointing(function) Batch) # batch =chk.checkpoint =chk. recompute task = task (streams[j], chk.recompute task =chk.checkpoint, finalize=chk.recompute) del function, chk else: def compute( batch: Batch = batch, partition: nn.Sequential = partition, skip_tracker: SkipTrackerThroughPotals = skip_trackers[i], chunk_id: int = i, part_id: int = j, ) -> Batch: with use_skip_tracker(skip_tracker), record_function("chunk%d-part%d" % (chunk_id, part_id)): return batch.call(partition) task = Task(streams[j], compute=compute, ([2] in the diagram) self. In_queues [j]. Put (task) # Pipeline queue, which can be parallel. for i, j in schedule: ok, payload = self.out_queues[j].get() # Hold the first exception. if exc_info is not None: continue elif not ok: Batch = cast(Tuple[task, batch], payload) # The copy stream synchronizes to copy the output. ([3] in the # diagram) if j ! = n - 1: _wait(batch, streams[j], copy_streams[j][i]) # Finalize tasks. If checkpointing is enabled, here the # recomputation is scheduled at backpropagation. ([4] in the # diagram) with use_device(devices[j]): Task.finalize (Batch) # Batches [I] = Batch # Fail at the first exception. If exc_info is not None: raise exc_info[0].with_traceback(exc_info[1], exc_info[2])Copy the code

0x05 GPipe implementation

When backpropagating, GPipe recomputes the forward function F_k on the k-th accelerator.

5.1 API function _Rematerialize

First, let’s look at the API methods.

In Builder.py there is the _Rematerialize function, which can be used to wrap a layer that needs to be recomputed.

  def _Rematerialize(self, name, body):
    """Forces rematerialization on FProp of the body layer."""
    return builder_layers.RematerializationLayer.Params().Set(
        name=name, body=body)

5.2 Wrapper layer RematerializationLayer

RematerializationLayer is the wrapper layer. Its FProp wraps the wrapped layer's FProp into a function Fn, then calls py_utils.RematerializeFn with Fn and the input variables.

class RematerializationLayer(base_layer.BaseLayer):
  """A wrapper layer with rematerialization."""

  @classmethod
  def Params(cls):
    p = super().Params()
    p.Define('body', None,
             'The main layer whose FProp will be wrapped by RematerializeFn.')
    return p

  def __init__(self, params):
    super().__init__(params)
    self.CreateChild('body', self.params.body)

  def FProp(self, theta, *xs):
    input_list = theta.body.Flatten()   # the theta of the wrapped layer
    theta_len = len(input_list)
    input_list += list(xs)              # plus the input arguments
    input_len = len(input_list)

    def Fn(*args):
      # The wrapped function: unpack theta, then run the body's FProp
      body_theta = theta.body.Pack(args[:theta_len])
      return self.body.FProp(body_theta, *args[theta_len:input_len])

    return py_utils.RematerializeFn(Fn, *input_list)

  @classmethod
  def FPropMeta(cls, p, *args):
    py_utils.CheckShapes(args)
    return p.body.cls.FPropMeta(p.body, *args)

5.3 TensorFlow gradients function

RematerializeFn calls TensorFlow's gradients function to compute gradients, so a brief explanation is needed.

In TensorFlow, the gradients function computes gradients automatically; we just need to build our function and call tf.gradients.

The parameters of tf.gradients() are as follows, where:

  • tf.gradients() computes the derivative of ys with respect to xs;
  • grad_ys is also a list, with length equal to len(ys); it supplies the initial gradients (weights) applied to each element of ys when accumulating the derivative.
tf.gradients(ys, xs, 
             grad_ys=None, 
             name='gradients',
             colocate_gradients_with_ops=False,
             gate_gradients=False,
             aggregation_method=None,
             stop_gradients=None)
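
A tiny example of tf.gradients (a sketch assuming TensorFlow 2.x with the v1 compatibility API, since tf.gradients is a graph-mode API), reusing the earlier f(x) = x * (x + 1):

import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.float32, shape=())
y = x * (x + 1)

(dy_dx,) = tf.gradients(ys=[y], xs=[x])          # symbolic gradient: 2*x + 1

with tf.compat.v1.Session() as sess:
    print(sess.run(dy_dx, feed_dict={x: 3.0}))   # 7.0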

5.4 Function RematerializeFn

RematerializeFn is the final piece: it calls fn and rematerializes fn in the backward pass.

def RematerializeFn(fn, *xs):
  """Calls fn and rematerializes fn in the backward pass.

  `fn(*xs) -> ys`, where xs and ys can be a single tensor or a tuple of tensors.

  Args:
    fn: A python function to be rematerialized in the backprop pass.
    *xs: A single tensor or a list/tuple of tensors. `xs` are input args to the
      fn function.

  Returns:
    `fn(*xs)`
  """
  initial_step_seed = GetStepSeed()
  final_step_seed = MaybeGenerateSeedFromScope()

  def Backward(fwd_xs, fwd_ys, d_fwd_ys):
    """The backward function that rematerializes the forward pass."""
    del fwd_ys  # drop the passed-in forward outputs; they are recomputed below
    always_true = tf.random.uniform([]) < 2.0
    # Alternatively, can do this:
    # tf.where(tf.math.is_nan(x),
    #          tf.constant(float('nan'), dtype=x.dtype) * tf.ones_like(x),
    #          x)
    bak_xs = [tf.where(always_true, x, tf.zeros_like(x)) for x in fwd_xs.xs]
    for dst, src in zip(bak_xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(initial_step_seed)
    ys = fn(*bak_xs)  # rematerialize: re-run the forward function
    MaybeResetStepSeed(final_step_seed)
    dxs = tf.gradients(ys, bak_xs, grad_ys=d_fwd_ys)
    dxs_final = []
    for dx, x in zip(dxs, bak_xs):
      if dx is None:
        dxs_final.append(tf.zeros_like(x))
      else:
        dxs_final.append(dx)
    assert len(dxs_final) == len(bak_xs)
    return NestedMap(
        initial_step_seed=tf.zeros_like(initial_step_seed), xs=dxs_final)

  ys_shapes = []

  # TODO(huangyp, yonghui): Check Forward doesn't use any stateful random ops.
  def Forward(fwd_xs):
    """Forward function plus sanity checks."""
    for dst, src in zip(fwd_xs.xs, xs):
      dst.set_shape(src.shape)
    ResetStepSeed(fwd_xs.initial_step_seed)
    ys = fn(*fwd_xs.xs)
    # Some sanity check.
    assert not GetExtraInputs()
    assert not GetExtraArgs()
    assert not GetExtraVars()
    if isinstance(ys, tuple):
      for y in ys:
        assert isinstance(y, tf.Tensor)
        ys_shapes.append(y.shape)
    else:
      assert isinstance(ys, tf.Tensor)
      ys_shapes.append(ys.shape)
    return ys

  ys = CallDefun(
      Forward,
      NestedMap(initial_step_seed=initial_step_seed, xs=xs),
      bak=Backward)
  if isinstance(ys, tuple):
    for y, s in zip(ys, ys_shapes):
      y.set_shape(s)
  else:
    ys.set_shape(ys_shapes[0])
  # TODO(b/129159299): The ResetStepSeed below is needed to work around this
  # bug, which is a problem with global tensors being shared by different
  # inference graphs. It should be replaced with the new step seed value
  # returned from the Forward function when the bug is fixed.
  MaybeResetStepSeed(final_step_seed)
  return ys

CallDefun encapsulates the forward and backward callables; internally, Function builds a TensorFlow graph function from a callable.

def CallDefun(fwd, args=None, bak=None, bak_as_function=False, device=None):
  """Wraps fwd in a defun with custom gradient bak and calls it with args.

  Args:
    fwd: A callable xs: Nested Structure -> ys: Nested Structure.
    args: A Nested Structure of tf.Tensor or None.
    bak: A callable xs, ys, dys: Nested Structure -> dxs[, dcapture]: Nested
      Structure. The custom backprop function for fwd. bak needs to return
      dcapture if fwd uses any implicitly captured tensors, whose gradients
      are dcapture.
    bak_as_function: Whether to create a TF graph function for bak.
    device: the device on which to run fwd and bak.

  Returns:
    A Nested Structure equivalent to what fwd(args) computes.
  """
  if args is not None:
    args = Transform(tf.convert_to_tensor, args)
  sigs = Function(
      fwd_sig=TensorSpecs(args),
      bak=bak,
      bak_as_function=bak_as_function,
      device=device)(
          fwd=fwd)
  if args is None:
    return sigs()
  else:
    return sigs(args)

This concludes the analysis of GPipe. The next article will start analyzing PipeDream, so stay tuned.

In addition, there will be a dedicated series analyzing the PyTorch pipeline implementation.

0xEE Personal information

★★★★ Thoughts on life and technology ★★★★★

Wechat official account: Rosie’s Thoughts

0xFF Reference

Lingvo framework reading notes

TensorFlow: accumulating gradients over multiple minibatches before backpropagating

Implementing gradient accumulation with TensorFlow 2

Ten times the model with only 20% more compute time: OpenAI open-sources a gradient-checkpointing plugin

PipeDream: Fast and Efficient Pipeline Parallel DNN Training

Paper interpretation series 5: PipeDream from Microsoft, Stanford et al. for fast training of large-scale neural networks

cs231n.github.io/neural-netw…

www.cnblogs.com/geekfx/p/14…

GPU memory optimization techniques during training — op merging and gradient checkpointing

PyTorch note 04: customizing torch.autograd.Function

PyTorch tutorial: Autograd

A simple definition and example of torch.autograd.Function

PyTorch custom extensions (2): implementing a custom layer with torch.autograd.Function

torch.autograd: gradient computation in detail

Back Propagation

CS231n course notes translation: backpropagation notes