Skip to content

Cache HLO in xb.call_jax and support non-tensor args#8878

Merged
tengyifei merged 2 commits intomasterfrom
yifeit/call-jax-cache
Mar 24, 2025
Merged

Cache HLO in xb.call_jax and support non-tensor args#8878
tengyifei merged 2 commits intomasterfrom
yifeit/call-jax-cache

Conversation

@tengyifei
Copy link
Copy Markdown
Collaborator

The main purpose is to replace the clunky manual XlaComputation object caching at
https://github.com/AI-Hypercomputer/torchprime/blob/b0bd47e3c732c56e75d8d2b315f05e06d485dd22/torchprime/torch_xla_models/experimental/custom_kernel.py#L16, and just write xb.call_jax(some_jax_func) and simply avoid repeated tracing there.

We can't reuse the tracing cache in jax.jit because we jit a wrapper and not jax_func. Also as_serialized_hlo_module_proto has overhead itself and it would be nice to avoid calling that repeatedly.

Also we improve xb.call_jax to support non-tensor arguments. These arguments are passed from xb.call_jax to the JAX function unchanged. They are considered "static arguments" and will be baked into the HLO.

Because they are considered static args, we'll re-trace the jax function whenever their values change.

Fixes #8795.

The main purpose is to replace the clunky manual XlaComputation object
caching at
https://github.com/AI-Hypercomputer/torchprime/blob/b0bd47e3c732c56e75d8d2b315f05e06d485dd22/torchprime/torch_xla_models/experimental/custom_kernel.py#L16,
and just write `xb.call_jax(some_jax_func)` and simply avoid repeated
tracing there.

We can't reuse the tracing cache in `jax.jit` because we jit a wrapper
and not `jax_func`. Also `as_serialized_hlo_module_proto` has overhead
itself and it would be nice to avoid calling that repeatedly.

Also we improve `xb.call_jax` to support non-tensor arguments. These
arguments are passed from `xb.call_jax` to the JAX function unchanged.
They are considered "static arguments" and will be baked into the HLO.

Because they are considered static args, we'll re-trace the jax function
whenever their values change.

Fixes #8795.
@tengyifei tengyifei marked this pull request as ready for review March 24, 2025 20:28
@tengyifei tengyifei requested review from bhavya01, qihqi and zpcore March 24, 2025 20:31
Comment thread torch_xla/core/xla_builder.py
@tengyifei tengyifei merged commit a3ef52e into master Mar 24, 2025
23 checks passed
Comment thread torch_xla/core/xla_builder.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

xb.call_jax only works with torch.Tensor arguments

3 participants