🚀 Feature
Instead of using PyTorch autograd and checkpointing, we'll investigate using jax.grad, jax.vjp, jax.remat, etc. to control the rematerialization of a PyTorch model.
Motivation
JAX remat gives finer-grained control over what is saved for the backward pass than PyTorch autograd and checkpointing. For example, we can name individual tensors and selectively save or offload them. PyTorch does not support naming a tensor.
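To make the naming point concrete, here is a minimal sketch in plain JAX, not yet wired to a PyTorch model or torchax. The toy `mlp` function, the `"hidden"` tag, and the shapes are made up for illustration; the JAX calls used are `jax.ad_checkpoint.checkpoint_name`, `jax.checkpoint` (of which `jax.remat` is an alias), and `jax.checkpoint_policies.save_only_these_names`.

```python
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name


def mlp(params, x):
    """Toy two-layer function (hypothetical, for illustration only)."""
    w1, w2 = params
    # Tag the first activation so a remat policy can refer to it by name.
    h = checkpoint_name(jnp.tanh(x @ w1), "hidden")
    # This intermediate is not named, so the policy below does not save it;
    # it is recomputed during the backward pass.
    g = jnp.sin(h @ w2)
    return jnp.sum(g ** 2)


# Save only the residuals tagged "hidden"; rematerialize everything else.
policy = jax.checkpoint_policies.save_only_these_names("hidden")
mlp_remat = jax.checkpoint(mlp, policy=policy)

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
params = (jax.random.normal(k1, (4, 8)), jax.random.normal(k2, (8, 2)))
x = jax.random.normal(k3, (3, 4))

# Gradients with and without the remat policy should agree numerically;
# only the memory/recompute trade-off differs.
grads_remat = jax.grad(mlp_remat)(params, x)
grads_plain = jax.grad(mlp)(params, x)
```

Recent JAX releases also ship `jax.checkpoint_policies.save_and_offload_only_these_names` for moving named residuals to host memory. Whether offloading behaves as expected on a given backend is something this investigation would need to confirm.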
Pitch
Something like https://github.com/tengyifei/playground/blob/master/torch-jax-autograd.ipynb combined with #8781 and torchax.
cc @qihqi