Add Tuple input and token support to all-gather and reduce-scatter. #58377
hjm-aws wants to merge 2 commits into tensorflow:master
Conversation
Committer: Junmin Hao <junminh@amazon.com>
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.

@hjm-aws Can you please resolve conflicts? Thank you!

@hjm-aws Can you please sign the CLA? Thank you!

@hjm-aws It still shows the CLA is pending; can you please sign it? Thank you!
cheshire left a comment
Thanks a lot! Overall this makes sense, let me also check internally.
QQ: this is not decomposable, right? E.g. changes inside XLA could not be split from builder changes?
    Shape inferred_shape,
    ShapeInference::InferAllGatherShape({operand_shape},
                                        all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
Could you also update the semantics documentation in operation_semantics.md?
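For reference, the hunk above is where the builder's shape-inference call gets generalized from a single operand to a list of operand shapes. A minimal sketch of the assumed post-change form (the exact signature merged in this PR may differ):

```cpp
// Sketch: the single-operand inference call becomes variadic over the
// collected tuple-element shapes. `operand_shapes` is filled from either
// the tuple elements or the lone non-tuple operand.
TF_ASSIGN_OR_RETURN(
    Shape inferred_shape,
    ShapeInference::InferAllGatherShape(operand_shapes,
                                        all_gather_dimension, shard_count));
```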
Thanks, overall this looks like a very good change! Added @Kariddi and @blakehechtman for clarifications.
            HasSubstr("Replica groups expected to be of uniform size"));
}

TEST_F(HloVerifierTest, ReduceScatterTwoTokens) {
I think there is some confusion here between tokens and tuples. I think the intent is to add tuple support, so we should remove any mention of tokens in test names or in the change description.
                        all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
std::vector<XlaOp> operands;
if (operand_shape->IsTuple()) {
This looks like it was copy/pasted from all-reduce. Is it possible to factor it out into a helper function?
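One possible factoring, as a hedged sketch (the helper name and exact signature here are assumptions, not necessarily what the PR ends up with): collect the per-element shapes and operands once, so the tuple variants of all-reduce, all-gather, and reduce-scatter can share the logic.

```cpp
// Hypothetical helper: unpack a (possibly tuple-shaped) operand into
// parallel vectors of element shapes and XlaOps.
static void CollectTupleOperands(XlaOp operand, const Shape& operand_shape,
                                 std::vector<const Shape*>* operand_shapes,
                                 std::vector<XlaOp>* operands) {
  if (operand_shape.IsTuple()) {
    for (int64_t i = 0; i < operand_shape.tuple_shapes_size(); ++i) {
      operand_shapes->push_back(&operand_shape.tuple_shapes(i));
      operands->push_back(GetTupleElement(operand, i));
    }
  } else {
    operand_shapes->push_back(&operand_shape);
    operands->push_back(operand);
  }
}
```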
      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithToken) {

      ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithToken) {
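(These tests were later renamed to *WithTuple; see the commit list at the end.) A hedged sketch of what such a tuple-input builder test could look like, with assumed shapes and an assumed BuildHloModule test helper:

```cpp
// Sketch: two differently-shaped operands are gathered by one op;
// the result is a tuple of the gathered shapes.
TEST_F(XlaBuilderTest, AllGatherWithTuple) {
  XlaBuilder b(TestName());
  XlaOp p0 = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4}), "p0");
  XlaOp p1 = Parameter(&b, 1, ShapeUtil::MakeShape(F32, {16, 4}), "p1");
  AllGather(Tuple(&b, {p0, p1}), /*all_gather_dimension=*/0,
            /*shard_count=*/4);
  TF_ASSERT_OK_AND_ASSIGN(auto module, BuildHloModule(&b));
  auto root = module->entry_computation()->root_instruction();
  // With 4 shards, f32[4] gathers to f32[16] and f32[16,4] to f32[64,4].
  EXPECT_TRUE(ShapeUtil::Equal(
      root->shape(),
      ShapeUtil::MakeTupleShape({ShapeUtil::MakeShape(F32, {16}),
                                 ShapeUtil::MakeShape(F32, {64, 4})})));
}
```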
TF_RET_CHECK(ag->operand_count() >= 1);

int64_t shard_count;
// There can be one token in the input Tuple. The token is a scalar or
This is confusing IMO. Can we get clarification on the intent of the token input to all-gather? Also, treating a scalar as a token is confusing.
Overall it seems this is attempting to add tuple support for all-gather and reduce-scatter, as well as an optional dummy token input to all-gather, the purpose of which is unclear. I think we should split this into two PRs: one for tuple support, and a separate discussion of token-type support in all-gather before adding it.
@hjm-aws Any update on this PR? Please. Thank you!
This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further.
@gbaned @jurahul, the revived PR on openxla is openxla/xla#5740. The new PR has a description that hopefully answers your questions.
Imported from GitHub PR openxla/xla#5740

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377, to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch/XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered while keeping their original shapes, so bucketing and decomposing optimizations can happen inside the operation.

The original PR also had token support, like the token used for all-reduce to ensure ordering between collective ops. That will be a separate PR if needed.

Copybara import of the project:

-- 7ea1159a1464efddebe9384e87ed6df504d89b2e by Junmin Hao <junminh@amazon.com>: Add Tuple input and token support to all-gather and reduce-scatter. (Committer: Junmin Hao <junminh@amazon.com>)
-- cdb873e6d97f5f12b3d3c587bb5782d58e3554c5 by Junmin Hao <junminh@amazon.com>: lint fix
-- aad352117ba950ac5ae62330e3980f4b5898a701 by Jeffrey Huynh <jthuynh@amazon.com>: Fix hlo_verifier_test failure due to changed expectation
-- 32e814524b88a474af5e4e904c0dd19841430b86 by Jeffrey Huynh <jthuynh@amazon.com>: Separate the token change out into a separate PR with RFC.
-- b301c2a2a5b52180f9e9626173e6b67a78782960 by Jeffrey Huynh <jthuynh@amazon.com>: Change *WithToken tests to *WithTuple
-- 5890278fc16c9f900782d32a92d40ecf548aea85 by Jeffrey Huynh <jthuynh@amazon.com>: Fix missing parenthesis

Merging this change closes #5740

PiperOrigin-RevId: 573976449
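As a hedged illustration of the bucketing use case described above (the header paths, shapes, and names below are assumptions; AllGather, ReduceScatter, and CreateScalarAddComputation are existing XLA client entry points):

```cpp
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"  // CreateScalarAddComputation
#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: two differently-shaped weight tensors are reduce-scattered in a
// single op, instead of the caller concatenating them into one bucket
// beforehand and splitting the result afterwards.
xla::XlaBuilder b("bucketed_reduce_scatter");
xla::XlaOp w0 =
    xla::Parameter(&b, 0, xla::ShapeUtil::MakeShape(xla::F32, {1024}), "w0");
xla::XlaOp w1 =
    xla::Parameter(&b, 1, xla::ShapeUtil::MakeShape(xla::F32, {256, 8}), "w1");
xla::XlaComputation add = xla::CreateScalarAddComputation(xla::F32, &b);
// With tuple support, one op handles both tensors; the result is a tuple of
// per-shard shapes: (f32[256], f32[64,8]) for shard_count=4.
xla::XlaOp result = xla::ReduceScatter(xla::Tuple(&b, {w0, w1}), add,
                                       /*scatter_dimension=*/0,
                                       /*shard_count=*/4);
```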