Skip to content

composibility of assume_pure and call_jax#8989

Merged
qihqi merged 3 commits intomasterfrom
hanq_assume_pure2
Apr 21, 2025
Merged

composibility of assume_pure and call_jax#8989
qihqi merged 3 commits intomasterfrom
hanq_assume_pure2

Conversation

@qihqi
Copy link
Copy Markdown
Collaborator

@qihqi qihqi commented Apr 16, 2025

No description provided.

@qihqi qihqi force-pushed the hanq_assume_pure2 branch 2 times, most recently from 234ed02 to 31c0b37 Compare April 16, 2025 21:44
@qihqi qihqi requested a review from tengyifei April 16, 2025 21:51
Copy link
Copy Markdown
Collaborator

@tengyifei tengyifei left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks!

Comment thread torch_xla/core/xla_builder.py
Comment thread torch_xla/core/xla_builder.py Outdated
Comment thread torch_xla/distributed/spmd/xla_sharding.py
@qihqi qihqi requested a review from tengyifei April 16, 2025 22:49
@qihqi qihqi force-pushed the hanq_assume_pure2 branch from 2c041b3 to 901c465 Compare April 16, 2025 22:54
Comment thread torch_xla/distributed/spmd/xla_sharding.py Outdated
@tengyifei tengyifei self-requested a review April 17, 2025 00:47
@qihqi qihqi force-pushed the hanq_assume_pure2 branch 2 times, most recently from 577f1d1 to 7c74bd4 Compare April 18, 2025 03:03
@qihqi qihqi force-pushed the hanq_assume_pure2 branch from 7c74bd4 to 575b7f0 Compare April 18, 2025 03:21
@qihqi qihqi enabled auto-merge (squash) April 18, 2025 03:32
@qihqi qihqi force-pushed the hanq_assume_pure2 branch from 91a8464 to 8fa4fe3 Compare April 21, 2025 15:26
@qihqi qihqi merged commit 8e6a5e5 into master Apr 21, 2025
23 of 24 checks passed
tengyifei added a commit that referenced this pull request Apr 25, 2025
Now we can run a JAX SPMD function that accesses the ambient SPMD mesh
from xb.call_jax.

Fixes #8972.

Also I beefed up the assume_pure tests and updated the docs to mention
that mark_sharding is supported thanks to qihqi@'
#8989.
tengyifei added a commit that referenced this pull request Apr 25, 2025
Now we can run a JAX SPMD function that accesses the ambient SPMD mesh
from xb.call_jax.

Fixes #8972.

Also I beefed up the assume_pure tests and updated the docs to mention
that mark_sharding is supported thanks to qihqi@'
#8989.
tengyifei added a commit that referenced this pull request Apr 25, 2025
Now we can run a JAX SPMD function that accesses the ambient SPMD mesh
from xb.call_jax.

Fixes #8972.

Also I beefed up the assume_pure tests and updated the docs to mention
that mark_sharding is supported thanks to qihqi@'
#8989.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants