Conversation
126ceee to
4d568ef
Compare
yeounoh
commented
Mar 12, 2024
yeounoh
commented
Mar 12, 2024
6ca8f97 to
d6dc442
Compare
yeounoh
commented
Mar 12, 2024
yeounoh
commented
Mar 12, 2024
yeounoh
commented
Mar 12, 2024
yeounoh
commented
Mar 12, 2024
yeounoh
commented
Mar 12, 2024
303b239 to
d3c1d70
Compare
* Assume REPLICATED for UNKNOWN during paramter resharding
…patch * Ungroup resharding ops * Replace device data after resharding
Delete quantization openxla patch Debugging probes
* Disable parameter wrapping with auto-sharding
…t fully support it, yet
* Linter fix
JackCaoG
reviewed
Mar 14, 2024
| run_test "$CDIR/spmd/test_xla_distributed_checkpoint.py" | ||
| run_test "$CDIR/spmd/test_xla_spmd_python_api_interaction.py" | ||
| run_test "$CDIR/spmd/test_dtensor_integration.py" | ||
| run_test "$CDIR/spmd/test_dtensor_integration2.py" |
Collaborator
There was a problem hiding this comment.
do we need this on TPU CI as well or it is ok to leave out?
Contributor
Author
There was a problem hiding this comment.
Ohhh i think it's ok to leave out. Want to run this sanity check on TPU!
JackCaoG
approved these changes
Mar 14, 2024
Collaborator
JackCaoG
left a comment
There was a problem hiding this comment.
Feel free to adjust remaining comments in a follow up [r
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This implemented a PoC prototype on XLA:TPU, as described in #6322
PyTorch/XLA auto-sharding can be enabled by one of the following:
XLA_SPMD_AUTO=1pytorch.distributed._tensor.distribute_modulewithauto-policyandxla:Some notable limitations that we will address in follow-ups:
cc @baoleai