Conversation
This is GREAT! Looking forward to this feature!
Another issue to think about: the naming. Some decorators change the function's own behavior, like `@staticmethod`. Others change the function's behavior from the caller's perspective, like torch's
It may be too hard to think through all the ways we expose lazy tensors and harmonize them within the timeframe of this PR, so consider this optional; we may have to refactor all the ways we discuss lazy tensors and compilation at some point in the future (e.g. `mark_step`, `xm.optimizer_step`, `torch_xla.compile`, etc.).
Thanks for the great work! Just confirming that when using
Depends on how those two are combined.
Ack. I don't have a great immediate thought (maybe
This is ready for another look
yaoshiang
left a comment
Tests appear to confirm the expected behavior of the decorator.
Fixes #8805.
We introduce a decorator, `@assume_pure`, that can be placed on PyTorch/XLA functions to easily eliminate lazy tensor tracing overhead. If you have a pure function that only uses upstream torch ops, that function can be decorated with `@assume_pure` and will only be traced once for each unique combination of input tensor shapes.

## Design
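In usage terms, a pure function is decorated like this. This is a self-contained sketch: the stand-in no-op `assume_pure` below replaces the real PyTorch/XLA decorator so the snippet runs anywhere, and the function names are hypothetical.

```python
# Stand-in no-op decorator (the real one in PyTorch/XLA additionally
# caches the traced computation per input shape/dtype combination).
def assume_pure(fn):
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

@assume_pure
def scaled_dot(w, x):
    # A "pure" function: the output depends only on the inputs and there
    # are no side effects, so tracing it once per shape is safe.
    return sum(wi * xi for wi, xi in zip(w, x))

print(scaled_dot([1.0, 2.0], [3.0, 4.0]))  # 11.0
```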
`@assume_pure` brings together three pieces of existing technology:

- `jax.vjp`, which takes a JAX function and gives you the autograd forward and backward pass
- torchax, which converts a pure PyTorch function to a JAX function
- `xb.call_jax`, which can call any JAX function from PyTorch/XLA and integrate it into the HLO graph

It works by:
1. Using `torchax.interop.jax_view` to obtain a JAX function from the input PyTorch function
2. Using `jax.vjp` to get the forward and backward pass
3. Wrapping them in a `torch.autograd.Function` instance, where the forward implementation is `xb.call_jax(forward_pass)` and the backward implementation is `xb.call_jax(backward_pass)`, respectively

The core logic is actually just a single line:
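The actual one-liner is not reproduced here, but the wiring pattern can be illustrated with a dependency-free sketch (all names are hypothetical stand-ins): a `vjp`-style helper runs the forward pass once, captures the residuals in a backward closure, and the two halves are packaged the way a `torch.autograd.Function` pairs `forward` with `backward`.

```python
forward_runs = 0

def vjp(fn, x):
    # Toy vjp for fn(x) = x * x: returns the primal output and a closure
    # mapping an output cotangent g to the input cotangent 2 * x * g.
    global forward_runs
    forward_runs += 1
    out = fn(x)
    def vjp_fun(g):
        return 2.0 * x * g  # reuses the residual `x` saved in the closure
    return out, vjp_fun

class PureFunction:
    """Minimal stand-in for a torch.autograd.Function wrapper."""
    def __init__(self, fn):
        self.fn = fn
        self._vjp_fun = None
    def forward(self, x):
        out, self._vjp_fun = vjp(self.fn, x)  # save the backward closure
        return out
    def backward(self, grad_out):
        # The forward pass is NOT rerun here; the residuals live in the
        # closure that `forward` saved.
        return self._vjp_fun(grad_out)

f = PureFunction(lambda x: x * x)
print(f.forward(3.0))   # 9.0
print(f.backward(1.0))  # 6.0
print(forward_runs)     # 1
```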
## How is the HLO cached

`xb.call_jax` caches the HLO if all the input shapes/dtypes and non-tensor arguments are the same. Therefore, subsequent `xb.call_jax` calls will just reuse the cached HLO instead of retracing. The same kind of caching happens in both the forward and the backward pass.
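The caching idea can be sketched as follows (hypothetical names; a stdlib-only stand-in for the real HLO cache): the "traced" computation is keyed on the input shapes/dtypes plus any non-tensor arguments, so a matching call reuses the cache instead of retracing.

```python
trace_count = 0

def trace(fn):
    # Stand-in for lowering a function to HLO; we just count invocations.
    global trace_count
    trace_count += 1
    return fn

_cache = {}

def call_cached(fn, array, scale):
    # `array` stands in for a tensor, modeled as (shape, dtype, values).
    shape, dtype, values = array
    key = (fn.__name__, shape, dtype, scale)  # non-tensor args join the key
    if key not in _cache:
        _cache[key] = trace(fn)  # cache miss: trace once for this signature
    return _cache[key](values, scale)

def mul(values, scale):
    return [v * scale for v in values]

call_cached(mul, ((2,), "f32", [1.0, 2.0]), 10)       # traces
call_cached(mul, ((2,), "f32", [5.0, 6.0]), 10)       # cache hit: same shapes
call_cached(mul, ((3,), "f32", [1.0, 2.0, 3.0]), 10)  # new shape: retraces
print(trace_count)  # 2
```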
Different from the JAX wrapper we used in `splash_attention`, the `j2t_autograd` function saves the residuals (intermediate activations) during the forward pass and reuses them during the backward pass by plugging them into the `vjp_fun` again. This means it won't force a rematerialization (rerunning the forward pass) during the backward pass.

## Alternatives
Instead of `jax.vjp`, we could also use AOTAutograd to get the forward and backward pass. However, AOTAutograd has a number of downsides, including issues around mapping `xp.Trace(...)` to `jax.named_scope(...)`.

Instead of
`assume_pure`, we could also use `torch.compile` to cache the XLA executable of the compiled function and skip the lazy tensor tracing. However, `torch.compile` has its own downsides:

- `torch.compile` itself uses AOTAutograd and will suffer from the decomposition and custom-operation issues, etc.
- `torch.compile` has a general perception of "either it works, or debugging will be complicated", which has been corroborated by experiments by people on the PyTorch/XLA team. See PyTorch team members' own recommendation [1]. In contrast, `@assume_pure` has a very simple rule for determining whether it will work: if your function is pure, then it works.
- `torch.compile` will graph-break when entering and leaving the compiled region. In contrast, `@assume_pure` can avoid tracing overhead without even breaking the graph; the cached HLO is inlined into the overall HLO.

## Benchmarks
I tested tracing an example 100-layer decoder-only model:
Importantly, the tracing cost of `@assume_pure` does not scale with increasing complexity inside the model. That's because we only trace the model once, paying a fixed up-front cost, and later runs reuse the cached XLA computation.

Anecdotally, @bhavya01 reported saving >200 ms of tracing time in an SDXL experiment. That's very significant since each training step is sub-1 second.