Skip to content

[call_jax] support returning PyTree from the JAX function#8957

Merged
tengyifei merged 2 commits intomasterfrom
yifeit/call-jax-out-pytree
Apr 10, 2025
Merged

[call_jax] support returning PyTree from the JAX function#8957
tengyifei merged 2 commits intomasterfrom
yifeit/call-jax-out-pytree

Conversation

@tengyifei
Copy link
Copy Markdown
Collaborator

Also clean up the call_jax logic a bit, and cache the computation object instead of the HLO to further reduce overhead.

Also clean up the call_jax logic a bit, and cache the
computation object instead of the HLO to further reduce overhead.
@tengyifei tengyifei requested a review from qihqi April 10, 2025 06:15
@tengyifei tengyifei marked this pull request as ready for review April 10, 2025 06:15
@tengyifei tengyifei requested a review from bhavya01 April 10, 2025 06:16
Comment thread torch_xla/core/xla_builder.py Outdated
@tengyifei tengyifei force-pushed the yifeit/call-jax-out-pytree branch from 55162aa to 30b07a9 Compare April 10, 2025 20:43
@tengyifei tengyifei enabled auto-merge (squash) April 10, 2025 20:45
@tengyifei tengyifei merged commit 297f5de into master Apr 10, 2025
24 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants