Analysis of Checkpoint mechanism of PyTorch


This article is published with authorization on the Jishi platform and its official account. It may not be reprinted again without permission.

Original link: https://www.yuque.com/lart/ugkv9f/azvnyg

PyTorch provides a very convenient way to save GPU memory: the checkpoint mechanism. The purpose of this article is to understand its internals more thoroughly.

Checkpoint mechanism

At its core, this technique trades time (extra computation) for space (memory). It is widely used in existing methods such as DenseNet and Swin Transformer.

To understand how it works, we first need to answer a question: during training, what does the GPU memory used by a PyTorch model actually hold?

Connolly's article "PyTorch GPU memory mechanism analysis" covers this in great detail:

Put simply, during deep-learning training most of PyTorch's GPU memory overhead falls into four categories: model parameters, gradients of the model parameters, optimizer states, and intermediate activations (intermediate results).
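To make the first three categories concrete, here is a rough back-of-the-envelope sketch. It assumes fp32 storage (4 bytes per value) and an Adam-style optimizer that keeps two state values per parameter; the helper name and the 25M-parameter model size are hypothetical. Activations are deliberately left out, because they scale with batch size and depth, and they are exactly the part checkpointing targets.

```python
def static_training_memory_bytes(num_params: int, bytes_per_value: int = 4,
                                 optimizer_states_per_param: int = 2) -> dict:
    """Rough estimate of memory that stays allocated for the whole training run.

    Assumptions (hypothetical helper): fp32 values, and an Adam-like optimizer
    holding two state tensors (e.g. exp_avg, exp_avg_sq) per parameter.
    Intermediate activations are NOT included.
    """
    params = num_params * bytes_per_value
    grads = num_params * bytes_per_value          # one gradient value per parameter
    optimizer = num_params * bytes_per_value * optimizer_states_per_param
    return {"params": params, "grads": grads, "optimizer": optimizer,
            "total": params + grads + optimizer}


est = static_training_memory_bytes(25_000_000)   # hypothetical 25M-parameter model
print({name: f"{size / 2**20:.1f} MiB" for name, size in est.items()})
# total = 400,000,000 bytes: 4x the parameter storage alone
```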

Checkpointing cleverly uses PyTorch's "no-grad" mode (torch.no_grad()) so that autograd does not record these operations in the backward graph, which avoids having to store their intermediate activations.
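A minimal illustration of the no-grad behavior that checkpointing builds on (plain PyTorch, nothing checkpoint-specific yet): inside torch.no_grad() the operation is not recorded, so the output has no grad_fn and nothing is saved for backward, while the leaf input keeps its requires_grad flag.

```python
import torch

x = torch.randn(4, requires_grad=True)   # leaf tensor that records gradients

y_tracked = torch.tanh(x)                # normal mode: the op is recorded in the backward graph
with torch.no_grad():
    y_untracked = torch.tanh(x)          # no-grad mode: the op is not recorded

print(y_tracked.requires_grad, y_tracked.grad_fn is not None)  # True True
print(y_untracked.requires_grad, y_untracked.grad_fn)          # False None
print(x.requires_grad)                                          # True: the leaf itself is untouched
```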

My personal understanding (corrections are welcome):

During forward propagation, autograd records, for each operation, the information and intermediate tensors needed for its backward pass. After backpropagation, the intermediate results used to compute the gradients are released. In other words, the model parameters, optimizer states and parameter gradients occupy memory throughout training, while the intermediate activations are freed automatically after backpropagation. For the concrete changes in memory usage, see "PyTorch GPU memory occupancy analysis"; I did a small verification by slightly modifying the example given in "PyTorch GPU memory mechanism analysis".
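A small sketch of the "released after backward" behavior, assuming default settings (retain_graph=False): exp saves its result for backward, and once the first backward pass has consumed and freed it, a second pass over the same graph raises a RuntimeError.

```python
import torch

x = torch.randn(3, requires_grad=True)
loss = torch.exp(x).sum()    # exp saves its output exp(x) for the backward pass

loss.backward()              # retain_graph=False: saved intermediates are freed afterwards
try:
    loss.backward()          # the saved exp(x) is gone, so this second pass fails
    second_pass_ok = True
except RuntimeError as err:
    second_pass_ok = False
    print("second backward failed:", str(err).splitlines()[0])

print(x.grad)                # still the result of the first pass: exp(x)
```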

This raises another question: why do custom autograd Functions generally reduce GPU memory usage? (The effect is clearly visible when comparing the various implementations in Vision Longformer.)

My guess is that, when writing a custom Function, we decide what to store in ctx from the perspective of the whole module, whereas the automatic differentiation engine works op by op and may end up keeping many intermediate tensors that are not strictly necessary. I have not yet found a way to verify this.
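As a hypothetical illustration of this guess (not a verification, and not the Vision Longformer code): composing `x * torch.sigmoid(x)` from built-in ops keeps both `x` and `sigmoid(x)` alive for backward, whereas a hand-written Function can store only `x` and recompute the sigmoid. The class name `MemoryLeanSiLU` is made up for this sketch.

```python
import torch


class MemoryLeanSiLU(torch.autograd.Function):
    """x * sigmoid(x) that saves only the input and recomputes sigmoid in backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)          # keep a single tensor for the backward pass
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)              # recomputed instead of stored
        # d/dx [x * s(x)] = s + x * s * (1 - s)
        return grad_out * (s + x * s * (1 - s))


x = torch.randn(5, dtype=torch.double, requires_grad=True)
print(torch.allclose(MemoryLeanSiLU.apply(x), torch.nn.functional.silu(x)))  # True
print(torch.autograd.gradcheck(MemoryLeanSiLU.apply, (x,)))                  # True
```

The trade-off is the same one checkpointing makes, on a much smaller scale: one extra sigmoid in backward in exchange for one fewer saved tensor.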

Checkpointing avoids storing the intermediate results of the wrapped layers of the model, which effectively reduces memory usage during forward propagation. These intermediate results are recomputed on demand during backpropagation. Note that the layers wrapped by checkpoint still allocate memory for their gradients during the first backward pass.

Because checkpoint runs the target operation's forward function under torch.no_grad(), the state of the original leaf nodes is not modified and they keep their gradient settings. Only the intermediate tensors newly created from these leaf nodes are marked as not requiring gradients, so the gradient chain through the wrapped region is broken.

As a result, backpropagation takes longer, but the memory otherwise consumed by storing a large number of intermediate tensors is reduced to a certain extent.
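The effect can be observed without a GPU by counting how many tensors autograd packs for backward. This sketch assumes PyTorch >= 1.11 (for the use_reentrant argument) and uses torch.autograd.graph.saved_tensors_hooks purely as a counter; the helper names `block` and `count_saved` are hypothetical. Exact counts may vary between PyTorch versions, but the checkpointed run should pack fewer tensors while producing the same input gradient.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # three chained tanh ops; each normally saves its output for backward
    return torch.tanh(torch.tanh(torch.tanh(x)))


def count_saved(fn, x):
    """Run fn(x) and count the tensors autograd packs for the backward pass."""
    count = 0

    def pack(t):
        nonlocal count
        count += 1
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, lambda t: t):
        out = fn(x)
    return out, count


x = torch.randn(16, requires_grad=True)
out_plain, n_plain = count_saved(block, x)
out_ckpt, n_ckpt = count_saved(lambda t: checkpoint(block, t, use_reentrant=True), x)
print("saved tensors:", n_plain, "plain vs", n_ckpt, "checkpointed")

out_plain.sum().backward()
g_plain = x.grad.clone()
x.grad = None
out_ckpt.sum().backward()    # recomputes block(x) with gradient tracking
g_ckpt = x.grad.clone()
print(torch.allclose(g_plain, g_ckpt))  # True
```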

Source code analysis

The following code is from PyTorch v1.10.1: https://github.com/pytorch/pytorch/blob/v1.10.1/torch/utils/checkpoint.py . Newer versions add some features; let's wait until they are finalized. The code below already covers the core of checkpoint.

Auxiliary functions

This part of the code first defines several auxiliary functions, which are mainly used to check and process the inputs and to handle the random number generator state.

def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
    if isinstance(inputs, tuple):
        out = []
        for inp in inputs:
            if not isinstance(inp, torch.Tensor):
                out.append(inp)
                continue
            
            # detach() cuts inp out of the computation graph it belongs to; the result has requires_grad=False by default
            x = inp.detach()
            # Here, however, we need to keep the original requires_grad flag; the new tensor's grad starts as None
            x.requires_grad = inp.requires_grad
            # Only tensors that require grad can start a gradient propagation path
            out.append(x)
        return tuple(out)
    else:
        raise RuntimeError(
            "Only tuple of tensors is supported. Got Unsupported input type: ", type(inputs).__name__)

def check_backward_validity(inputs: Iterable[Any]) -> None:
    """Warn if no input tensor requires grad, since the outputs would then have no gradients either."""
    if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
        warnings.warn("None of the inputs have requires_grad=True. Gradients will be None")
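What detach_variable does to a single tensor can be checked directly. In this sketch, `inp` plays the role of a non-leaf input coming from earlier layers:

```python
import torch

a = torch.randn(3, requires_grad=True)
inp = a * 2                            # non-leaf: its grad_fn links it back to `a`

x = inp.detach()                       # cut from the graph; requires_grad becomes False
x.requires_grad = inp.requires_grad    # restore the flag, as detach_variable does

print(x.is_leaf, x.grad_fn, x.requires_grad)  # True None True

(x * x).sum().backward()               # backward stops at x ...
print(x.grad)                          # ... so x.grad is populated (equals 2 * x)
print(a.grad)                          # None: the gradient did not flow past the cut
```

Because the detached copy is a leaf, its .grad is populated by the inner backward, which is what CheckpointFunction.backward later reads and returns itself.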

Because part of the forward pass is computed twice, the random state must be kept consistent. The forward of the wrapped part is run again during backward; if the original random state were not restored, the recomputation (e.g. dropout) would draw different random numbers than the original forward pass, which would change the model's behavior.
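The principle can be sketched on the CPU generator alone, with dropout standing in for any random operation: stash the state, run once, rewind, and run again. This uses only torch.get_rng_state and torch.set_rng_state; the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

x = torch.ones(100)

state = torch.get_rng_state()              # stash the CPU RNG state before the "forward"
first = F.dropout(x, p=0.5, training=True)

second_without_restore = F.dropout(x, p=0.5, training=True)  # the RNG has advanced

torch.set_rng_state(state)                 # rewind, as checkpoint does before recomputing
second_with_restore = F.dropout(x, p=0.5, training=True)

print(torch.equal(first, second_with_restore))     # True: identical dropout mask
print(torch.equal(first, second_without_restore))  # almost certainly False
```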

In addition, the comments in this code raise an interesting point:

Since it is impossible to know whether the checkpointed operation will move some arguments to other devices while running, the RNG states of those devices might also need to be saved. Saving and restoring the RNG states of all visible devices would be safe, but wasteful in most cases, so the current implementation compromises: it saves only the RNG states of the devices that hold the tensor arguments.

In other words, only the devices of the tensor arguments are covered. If run_fn moves data to another device and draws random numbers there, that device's RNG state is not restored, and the recomputation may differ from the original forward pass. This is admittedly somewhat confusing.

# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoically stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases.  Compromise:  Stash the RNG state for
# the device of all Tensor args.
#
# To consider:  maybe get_device_states and set_device_states should reside in torch/random.py?
def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
    """Get the RNG states of the GPU devices that hold the tensor inputs"""
    # This will not error out if "arg" is a CPU tensor or a non-tensor type because
    # the conditionals short-circuit.
    fwd_gpu_devices = list(set(arg.get_device() for arg in args
                               if isinstance(arg, torch.Tensor) and arg.is_cuda))

    fwd_gpu_states = []
    for device in fwd_gpu_devices:
        with torch.cuda.device(device):
            fwd_gpu_states.append(torch.cuda.get_rng_state())

    return fwd_gpu_devices, fwd_gpu_states

def set_device_states(devices, states) -> None:
    """Restore the RNG states of the given GPU devices"""
    for device, state in zip(devices, states):
        with torch.cuda.device(device):
            torch.cuda.set_rng_state(state)
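get_device_states and set_device_states only cover CUDA devices. The save-and-restore discipline used later in backward comes from torch.random.fork_rng, which stashes the CPU state (plus the listed CUDA devices) on entry and restores it on exit. A CPU-only sketch, where `devices=[]` means no CUDA state is forked:

```python
import torch

torch.manual_seed(0)
before = torch.get_rng_state()

with torch.random.fork_rng(devices=[]):   # devices=[]: only the CPU state is forked
    torch.manual_seed(123)                # stands in for set_rng_state(ctx.fwd_cpu_state)
    _ = torch.rand(5)                     # consumes random numbers inside the fork

after = torch.get_rng_state()
print(torch.equal(before, after))         # True: the surrounding state was restored
```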

Core Function

As you can see, checkpoint is implemented as an extension operator built on PyTorch's torch.autograd.Function (the mechanism for custom operators), so this part of the code touches many features of Function.

Reading it is a good way to review that knowledge and to see how more complex processing logic can be built on top of it.
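Two Function conventions that CheckpointFunction relies on can be shown with a tiny hypothetical operator (`ScaleBy` is made up for this sketch): non-tensor values can be stored directly on ctx, and backward must return one value per forward input, with None for inputs that have no gradient.

```python
import torch


class ScaleBy(torch.autograd.Function):
    """Hypothetical example: multiply a tensor by a Python float."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor           # non-tensor state stored directly on ctx
        return x * factor

    @staticmethod
    def backward(ctx, grad_out):
        # one return value per forward input: a gradient for x, None for `factor`
        return grad_out * ctx.factor, None


x = torch.randn(4, requires_grad=True)
ScaleBy.apply(x, 3.0).sum().backward()
print(x.grad)   # tensor([3., 3., 3., 3.])
```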

class CheckpointFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        check_backward_validity(args)
        # The function to run (the wrapped forward computation)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Record whether autocast (mixed precision) is enabled, so backward can reproduce it
        ctx.had_autocast_in_fwd = torch.is_autocast_enabled()
        if preserve_rng_state:  # Save the CPU and GPU RNG states before running the target module's forward
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here.  Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_cuda_in_fwd = False
            if torch.cuda._initialized:  
                # torch.cuda._initialized is an internal flag indicating whether CUDA has been initialized
                # (torch.cuda.is_initialized() reads this flag)
                ctx.had_cuda_in_fwd = True
                # Save the random state of each GPU device involved in the input variable
                ctx.fwd_gpu_devices, ctx.fwd_gpu_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        # save_for_backward() stores the tensors that backward will need (here, only the inputs).
        # The outputs are not saved, because backward recomputes them with gradient tracking,
        # so the forward run below does not need to record gradients.
        ctx.save_for_backward(*tensor_inputs)

        with torch.no_grad():  
            # Run the forward without recording gradients: no intermediate tensors are saved, and the outputs cannot be differentiated directly.
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
                " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
                " argument.")
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors # Gets the input tensor saved in the forward propagation

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward.  Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_cuda_in_fwd:
            rng_devices = ctx.fwd_gpu_devices
        
        # Restore the RNG states saved before the original forward, so the recomputed forward is identical to it.
        with torch.random.fork_rng(devices=rng_devices, enabled=ctx.preserve_rng_state):
            # fork_rng protects the surrounding RNG state and restores it when the block exits
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_cuda_in_fwd:
                    set_device_states(ctx.fwd_gpu_devices, ctx.fwd_gpu_states)
            # Detach the inputs from the outer graph while keeping their requires_grad flags; this cuts the backward path here.
            # Since we need to recompute the outputs and send gradients back to the inputs, the inputs themselves must be able to record gradients.
            # However, the backward run here must not propagate beyond the checkpoint:
            # a backward through the outer graph would release its saved tensors, and we only want this small sub-graph, so the gradient flow is truncated at the inputs.
            detached_inputs = detach_variable(tuple(inputs))  # become leaf nodes: grad and grad_fn are None
            # With the RNG state handled, run the forward again,
            # this time with gradient tracking (torch.enable_grad()), so intermediate tensors are saved.
            with torch.enable_grad(), torch.cuda.amp.autocast(ctx.had_autocast_in_fwd):
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            # Collect outputs[i] that require grad, together with the matching incoming gradient args[i]
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                args_with_grad.append(args[i])
        # If no output requires grad, this module does not take part in gradient computation,
        # i.e. wrapping it in checkpoint is unnecessary.
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary")
        # Backpropagate through the recomputed sub-graph, i.e. compute the gradients w.r.t. detached_inputs.
        # Since the inputs were detached, they are leaf nodes: their .grad is populated, and the sub-graph's intermediate tensors are freed afterwards.
        # This backward does not release the tensors that the surrounding graph saved for its own gradient computation.
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        # Without detach(), inp would not be a leaf, so inp.grad would not be populated (it would stay None), which is not what we want
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else None
                      for inp in detached_inputs)

        # The returned gradients correspond one-to-one to the inputs of forward.
        # forward has two extra arguments that need no gradient, run_function and preserve_rng_state, so None is returned for them.
        return (None, None) + grads

In effect, checkpoint inserts an intermediate layer between the original operation and the overall computation graph to relay information between them:

  1. When data from the model reaches the wrapped target layer, it first enters the checkpoint's forward(), which checks and records it before passing it to the target layer;
  2. The target layer runs its forward pass in no-grad mode, so newly created tensors do not record gradient information;
  3. The target layer's results are passed on to the subsequent parts of the model as the output of the checkpoint's forward;
  4. Backpropagation starts: the loss is differentiated and gradients flow back through the chain rule;
  5. The gradient w.r.t. the checkpoint's output is fed into the corresponding backward function, i.e. the checkpoint's backward();
  6. Inside the checkpoint, the gradient must be propagated further back to the inputs of the target layer. Because the target layer's forward ran in no-grad mode inside the checkpoint's forward, there is no backward sub-graph for its operations on the return path. To recover it, the target layer is run forward again with gradient tracking, and torch.autograd.backward(outputs_with_grad, args_with_grad) is called on the recomputed outputs with the incoming gradients, which yields the gradients w.r.t. the inputs;
  7. As required by the backward interface of the checkpoint's Function, the gradients w.r.t. the target operation's inputs are returned, with None filling the slots of the other auxiliary arguments;
  8. The returned gradients continue along the backward path into the backward functions of the preceding operations, layer by layer, and are accumulated on each leaf node.

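The eight steps can be condensed into a simplified sketch. `MiniCheckpoint` is a hypothetical name; it assumes tensor-only positional inputs and a single tensor output, and it deliberately omits RNG and autocast handling, so it illustrates the control flow rather than replacing torch.utils.checkpoint.

```python
import torch


class MiniCheckpoint(torch.autograd.Function):
    """Simplified CheckpointFunction: tensor-only inputs, one tensor output, no RNG/autocast handling."""

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function
        ctx.save_for_backward(*args)              # steps 1-2: keep only the inputs
        with torch.no_grad():
            return run_function(*args)            # no intermediate activations recorded

    @staticmethod
    def backward(ctx, grad_out):
        # step 6: detach the inputs (keeping requires_grad) and rebuild the local graph
        inputs = [t.detach().requires_grad_(t.requires_grad) for t in ctx.saved_tensors]
        with torch.enable_grad():
            out = ctx.run_function(*inputs)
        torch.autograd.backward(out, grad_out)    # gradients land on the detached inputs
        # step 7: None for run_function, then one gradient per tensor input
        return (None,) + tuple(t.grad for t in inputs)


lin = torch.nn.Linear(8, 8)


def f(x):
    return torch.tanh(lin(x))


x = torch.randn(2, 8, requires_grad=True)
MiniCheckpoint.apply(f, x).sum().backward()
g_ckpt, w_ckpt = x.grad.clone(), lin.weight.grad.clone()

x.grad, lin.weight.grad = None, None
f(x).sum().backward()                             # plain, non-checkpointed reference
print(torch.allclose(g_ckpt, x.grad), torch.allclose(w_ckpt, lin.weight.grad))  # True True
```

The parameter gradient (lin.weight.grad) is produced by the inner torch.autograd.backward call rather than by the returned tuple: parameters used inside run_function are leaves of the recomputed graph, so their gradients are accumulated directly.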
With the operation defined, it is wrapped in a simple function that handles the default arguments and adds more detailed documentation:

def checkpoint(function, *args, **kwargs):
    r"""Checkpoint a model or part of the model
    
    Checkpointing works by trading compute for memory. Rather than storing all
    intermediate activations of the entire computation graph for computing
    backward, the checkpointed part does **not** save intermediate activations,
    and instead recomputes them in backward pass. It can be applied on any part
    of a model.
    
    Specifically, in the forward pass, :attr:`function` will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. Instead, the forward pass saves the inputs tuple and the
    :attr:`function` parameter. In the backwards pass, the saved inputs and
    :attr:`function` is retrieved, and the forward pass is computed on
    :attr:`function` again, now tracking the intermediate activations, and then
    the gradients are calculated using these activation values.
    This paragraph describes the core technique of checkpoint: the target operation is run forward in no-grad mode, keeping only its inputs and the function, so no intermediate activations are saved. During backpropagation these activations are recomputed with gradient tracking to rebuild this part of the backward graph, after which gradients flow back normally.
    
    The output of :attr:`function` can contain non-Tensor values and gradient
    recording is only performed for the Tensor values. Note that if the output
    consists of nested structures (ex: custom objects, lists, dicts etc.)
    consisting of Tensors, these Tensors nested in custom structures will not
    be considered as part of autograd.
    This is because checkpoint's backward simply iterates over the target operation's outputs (converted to a tuple) and picks those that need gradients; tensors nested inside other non-Tensor structures are skipped during this iteration. This reduces flexibility, but keeps the code simple.
    
    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.
    
    .. warning::
        If :attr:`function` invocation during backward does anything different
        than the one during forward, e.g., due to some global variable, the
        checkpointed version won't be equivalent, and unfortunately it can't be
        detected.
        Make sure the target operation behaves identically in the backward recomputation and in the forward pass.
        Because checkpoint recomputes the forward during backward, undetectable sources of nondeterminism may make the result differ from the non-checkpointed version.
        
    .. warning::
        If checkpointed segment contains tensors detached from the computational
        graph by `detach()` or `torch.no_grad()`, the backward pass will raise an
        error. This is because `checkpoint` makes all the outputs require
        gradients which causes issues when a tensor is defined to have no
        gradient in the model. To circumvent this, detach the tensors outside of
        the `checkpoint` function.
        In other words, do not use detach() or no-grad mode inside the target operation.
        **I did not seem to hit this problem in my own tests.** It may be worth checking the test cases provided by PyTorch.
        
    .. warning::
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients. At least one of the outputs needs to have
        :code:`requires_grad=True` as well.
        Make sure at least one input has requires_grad=True so that this part of the computation is recorded,
        and that at least one output requires grad as well.

    Args:
        function: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, :attr:`function` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.
        args: tuple containing inputs to the :attr:`function`

    Returns:
        Output of running :attr:`function` on :attr:`*args`
    """
    # Hack to mix *args with **kwargs in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    return CheckpointFunction.apply(function, preserve, *args)
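A usage sketch of why preserve_rng_state matters, assuming PyTorch >= 1.11 (so that use_reentrant can be passed explicitly) and a CPU-only run. With the RNG state preserved, the dropout mask recomputed in backward matches the one used in forward, so the input gradient equals that of the non-checkpointed model; the helper `grad_of_input` is hypothetical.

```python
import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.Dropout(p=0.5))  # training mode
x = torch.randn(4, 16, requires_grad=True)


def grad_of_input(use_checkpoint, preserve=True):
    """Hypothetical helper: gradient of sum(layer(x)) w.r.t. x under a fixed seed."""
    x.grad = None
    torch.manual_seed(42)                      # same RNG stream for every run
    if use_checkpoint:
        out = checkpoint(layer, x, use_reentrant=True, preserve_rng_state=preserve)
    else:
        out = layer(x)
    out.sum().backward()
    return x.grad.clone()


reference = grad_of_input(use_checkpoint=False)
with_preserve = grad_of_input(use_checkpoint=True, preserve=True)
without_preserve = grad_of_input(use_checkpoint=True, preserve=False)
print(torch.allclose(reference, with_preserve))     # True
print(torch.allclose(reference, without_preserve))  # almost certainly False
```

With preserve_rng_state=False the recomputation in backward draws a fresh dropout mask, so its gradient no longer matches the forward pass that actually produced the output.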

Application case

Checkpoint for Sequential

The PyTorch source code includes a very direct application: applying checkpoint to models built with Sequential, splitting the model into the specified number of segments.
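The index arithmetic of the function below can be reproduced in plain Python to see how modules are grouped. `segment_bounds` is a hypothetical helper that copies the loop from checkpoint_sequential; the ranges are inclusive.

```python
def segment_bounds(num_modules: int, segments: int):
    """Hypothetical helper mirroring checkpoint_sequential's index arithmetic."""
    segment_size = num_modules // segments
    checkpointed = []
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        checkpointed.append((start, end))   # inclusive range run under checkpoint
    tail = (end + 1, num_modules - 1)       # last chunk runs normally
    return checkpointed, tail


print(segment_bounds(10, 3))   # ([(0, 2), (3, 5)], (6, 9))
print(segment_bounds(8, 4))    # ([(0, 1), (2, 3), (4, 5)], (6, 7))
```

Because of the floor division, the last, non-checkpointed chunk absorbs any remainder: with 10 modules and 3 segments it covers four modules (6 to 9).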

def checkpoint_sequential(functions, segments, input, **kwargs):
    r"""A helper function for checkpointing sequential models.

    Sequential models execute a list of modules/functions in order
    (sequentially). Therefore, we can divide such a model in various segments
    and checkpoint each segment. All segments except the last will run in
    :func:`torch.no_grad` manner, i.e., not storing the intermediate
    activations. The inputs of each checkpointed segment will be saved for
    re-running the segment in the backward pass.

    See :func:`~torch.utils.checkpoint.checkpoint` on how checkpointing works.

    .. warning::
        Checkpointing currently only supports :func:`torch.autograd.backward`
        and only if its `inputs` argument is not passed. :func:`torch.autograd.grad`
        is not supported.

    .. warning:
        At least one of the inputs needs to have :code:`requires_grad=True` if
        grads are needed for model inputs, otherwise the checkpointed part of the
        model won't have gradients.

    .. warning:
        Since PyTorch 1.4, it allows only one Tensor as the input and
        intermediate outputs, just like :class:`torch.nn.Sequential`.

    Args:
        functions: A :class:`torch.nn.Sequential` or the list of modules or
            functions (comprising the model) to run sequentially.
        segments: Number of chunks to create in the model
        input: A Tensor that is input to :attr:`functions`
        preserve_rng_state(bool, optional, default=True):  Omit stashing and restoring
            the RNG state during each checkpoint.

    Returns:
        Output of running :attr:`functions` sequentially on :attr:`*inputs`

    Example:
        >>> model = nn.Sequential(...)
        >>> input_var = checkpoint_sequential(model, chunks, input_var)
    """
    # Hack for keyword-only parameter in a python 2.7-compliant way
    preserve = kwargs.pop('preserve_rng_state', True)
    if kwargs:
        raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwargs))

    def run_function(start, end, functions):
        def forward(input):
            for j in range(start, end + 1):
                input = functions[j](input)
            return input
        return forward

    if isinstance(functions, torch.nn.Sequential):
        # Get the submodules of the Sequential; children() yields only the direct (outermost) children
        functions = list(functions.children())

    segment_size = len(functions) // segments
    # The last chunk has to be non volatile
    end = -1
    for start in range(0, segment_size * (segments - 1), segment_size):
        end = start + segment_size - 1
        # Wrap each segment of submodules in checkpoint and run it forward in turn.
        input = checkpoint(run_function(start, end, functions), input,
                           preserve_rng_state=preserve)
    # The last segment (including any remainder) runs without checkpoint
    return run_function(end + 1, len(functions) - 1, functions)(input)
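A usage sketch against the released API, assuming a recent PyTorch 2.x in which checkpoint_sequential accepts use_reentrant (older versions such as 1.10 reject that keyword). The toy model has six Linear+Tanh blocks split into three segments, so blocks 0-1 and 2-3 are checkpointed and blocks 4-5 run normally:

```python
import torch
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
model = torch.nn.Sequential(*[
    torch.nn.Sequential(torch.nn.Linear(32, 32), torch.nn.Tanh()) for _ in range(6)
])
x = torch.randn(8, 32, requires_grad=True)

out_ckpt = checkpoint_sequential(model, 3, x, use_reentrant=True)
out_ckpt.sum().backward()
grad_ckpt = x.grad.clone()

x.grad = None
model(x).sum().backward()                  # plain, non-checkpointed reference
print(torch.allclose(grad_ckpt, x.grad))   # True
```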



Added by munchy on Wed, 19 Jan 2022 00:42:41 +0200