Checkpointing is not compatible with torch.autograd.grad #8729

Description

@rpsilva-aws

🐛 Bug

Currently, checkpointing (torch_xla.utils.checkpoint) does not work with torch.autograd.grad:

    if not torch.autograd._is_checkpoint_valid():
      raise RuntimeError(
          "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
          " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
          " argument.")

This check was added by Jack some time ago in 16498c3458b#diff-29c25bb9605e07d3044ee5b3794d34a8a7f8d848caf2fc1376df32782e97d6fc, mostly taken from https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py, but explicitly adding an optimization barrier before saving the inputs for the backward pass.

However, there have been upstream efforts to close some of these gaps (e.g. pytorch/pytorch@7a41195), adding two different checkpointing variants selected via use_reentrant.

In particular (https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L440):

        use_reentrant(bool):
            specify whether to use the activation checkpoint variant that
            requires reentrant autograd. This parameter should be passed
            explicitly. In version 2.5 we will raise an exception if
            ``use_reentrant`` is not passed. If ``use_reentrant=False``,
            ``checkpoint`` will use an implementation that does not require
            reentrant autograd. This allows ``checkpoint`` to support additional
            functionality, such as working as expected with
            ``torch.autograd.grad`` and support for keyword arguments input into
            the checkpointed function.
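
For example, the non-reentrant variant of plain torch.utils.checkpoint composes with torch.autograd.grad (a minimal CPU sketch, with a made-up function):

    import torch
    from torch.utils.checkpoint import checkpoint

    def fn(x):
        return torch.sin(x).sum()

    x = torch.randn(4, requires_grad=True)
    # Non-reentrant variant: supported with torch.autograd.grad upstream.
    out = checkpoint(fn, x, use_reentrant=False)
    (grad,) = torch.autograd.grad(out, x)
    assert torch.allclose(grad, torch.cos(x))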

However, when reentrant autograd is disabled to get this additional support, PyTorch/XLA rejects it outright (https://github.com/pytorch/xla/blob/master/torch_xla/utils/checkpoint.py#L304), even though there is no concrete reasoning behind the restriction and no known plans to fix it:

ValueError: XLA currently does not support use_reentrant==False

To Reproduce

Add --use_gradient_checkpointing to COMMON_GRAD_ACC_ARGS and re-run the TestSPMDLinearModelGradientAccumulation test with --use_gradient_accumulation_loop. The test fails with use_reentrant set to either True or False; see the standalone sketch below.
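
The use_reentrant=False rejection can also be seen without the test harness (a sketch assuming an XLA device, and assuming torch_xla's checkpoint accepts the use_reentrant keyword as the error above suggests; the use_reentrant=True failure is the RuntimeError sketched earlier):

    import torch
    import torch_xla.core.xla_model as xm
    from torch_xla.utils.checkpoint import checkpoint

    device = xm.xla_device()
    x = torch.randn(4, requires_grad=True, device=device)

    # Raises: ValueError: XLA currently does not support use_reentrant==False
    checkpoint(lambda t: torch.sin(t).sum(), x, use_reentrant=False)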
