Cache HLO in xb.call_jax and support non-tensor args#8878
Merged
Conversation
The main purpose is to replace the clunky manual XlaComputation object caching at https://github.com/AI-Hypercomputer/torchprime/blob/b0bd47e3c732c56e75d8d2b315f05e06d485dd22/torchprime/torch_xla_models/experimental/custom_kernel.py#L16, and just write `xb.call_jax(some_jax_func)` and simply avoid repeated tracing there. We can't reuse the tracing cache in `jax.jit` because we jit a wrapper and not `jax_func`. Also `as_serialized_hlo_module_proto` has overhead itself and it would be nice to avoid calling that repeatedly. Also we improve `xb.call_jax` to support non-tensor arguments. These arguments are passed from `xb.call_jax` to the JAX function unchanged. They are considered "static arguments" and will be baked into the HLO. Because they are considered static args, we'll re-trace the jax function whenever their values change. Fixes #8795.
qihqi
reviewed
Mar 24, 2025
qihqi
approved these changes
Mar 24, 2025
zpcore
pushed a commit
that referenced
this pull request
Mar 26, 2025
zpcore
reviewed
Mar 31, 2025
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
The main purpose is to replace the clunky manual XlaComputation object caching at
https://github.com/AI-Hypercomputer/torchprime/blob/b0bd47e3c732c56e75d8d2b315f05e06d485dd22/torchprime/torch_xla_models/experimental/custom_kernel.py#L16, and just write
xb.call_jax(some_jax_func)and simply avoid repeated tracing there.We can't reuse the tracing cache in
jax.jitbecause we jit a wrapper and notjax_func. Alsoas_serialized_hlo_module_protohas overhead itself and it would be nice to avoid calling that repeatedly.Also we improve
xb.call_jaxto support non-tensor arguments. These arguments are passed fromxb.call_jaxto the JAX function unchanged. They are considered "static arguments" and will be baked into the HLO.Because they are considered static args, we'll re-trace the jax function whenever their values change.
Fixes #8795.