🐛 Bug
Currently, checkpointing (`torch_xla.utils.checkpoint`) does not work with `torch.autograd.grad`:

```python
if not torch.autograd._is_checkpoint_valid():
    raise RuntimeError(
        "Checkpointing is not compatible with .grad() or when an `inputs` parameter"
        " is passed to .backward(). Please use .backward() and do not pass its `inputs`"
        " argument.")
```
The PR was added by Jack some time ago (16498c3458b#diff-29c25bb9605e07d3044ee5b3794d34a8a7f8d848caf2fc1376df32782e97d6fc); it is mostly adapted from https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py, but explicitly adds an optimization barrier before saving the inputs for the backward pass.
However, there have been upstream efforts to close some of these gaps (e.g. pytorch/pytorch@7a41195), introducing two different checkpointing variants selected via `use_reentrant`.
Particularly (https://github.com/pytorch/pytorch/blob/main/torch/utils/checkpoint.py#L440):

```
use_reentrant(bool):
    specify whether to use the activation checkpoint variant that
    requires reentrant autograd. This parameter should be passed
    explicitly. In version 2.5 we will raise an exception if
    ``use_reentrant`` is not passed. If ``use_reentrant=False``,
    ``checkpoint`` will use an implementation that does not require
    reentrant autograd. This allows ``checkpoint`` to support additional
    functionality, such as working as expected with
    ``torch.autograd.grad`` and support for keyword arguments input into
    the checkpointed function.
```
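For contrast, the upstream non-reentrant variant described in that docstring does work with `torch.autograd.grad` on CPU/GPU. A minimal sketch (again with an illustrative module, not from the issue):

```python
import torch
from torch.utils.checkpoint import checkpoint

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)

out = checkpoint(model, x, use_reentrant=False)
loss = out.sum()

# Supported: the non-reentrant implementation is built on saved-tensor hooks
# rather than a custom autograd.Function, so partial graphs via .grad() work.
grads = torch.autograd.grad(loss, [x] + list(model.parameters()))
```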
However, when disabling reentrant checkpointing to take advantage of this additional support, PyTorch/XLA rejects it outright (https://github.com/pytorch/xla/blob/master/torch_xla/utils/checkpoint.py#L304), and there is no concrete reasoning behind the restriction, nor known plans to fix it:

```
ValueError: XLA currently does not support use_reentrant==False
```
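So on XLA both paths are currently blocked. Continuing the illustrative sketch from above:

```python
from torch_xla.utils.checkpoint import checkpoint

# use_reentrant=True: torch.autograd.grad() raises the RuntimeError quoted earlier.
# use_reentrant=False: rejected up front by PyTorch/XLA:
out = checkpoint(model, x, use_reentrant=False)
# ValueError: XLA currently does not support use_reentrant==False
```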
To Reproduce
Add `--use_gradient_checkpointing` to `COMMON_GRAD_ACC_ARGS` and re-run the `TestSPMDLinearModelGradientAccumulation` test with `--use_gradient_accumulation_loop`; the test fails with `use_reentrant` set to either `True` or `False`. A hypothetical sketch of the change follows below.
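The existing contents of `COMMON_GRAD_ACC_ARGS` in the test file are not reproduced here; the change amounts to appending the flag to the shared argument list:

```python
# Hypothetical sketch: append the checkpointing flag to the shared argument list
# used by TestSPMDLinearModelGradientAccumulation before re-running the test.
COMMON_GRAD_ACC_ARGS = COMMON_GRAD_ACC_ARGS + ["--use_gradient_checkpointing"]
```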