
Add Tuple input and token support to all-gather and reduce-scatter.#58377

Closed
hjm-aws wants to merge 2 commits into tensorflow:master from hjm-aws:ag_rs_coalesce

Conversation

@hjm-aws

@hjm-aws hjm-aws commented Oct 31, 2022

No description provided.

@google-ml-butler google-ml-butler Bot added the size:L CL Change Size: Large label Oct 31, 2022
@google-cla

google-cla Bot commented Oct 31, 2022

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@google-ml-butler google-ml-butler Bot requested a review from r4nt October 31, 2022 07:15
@google-ml-butler google-ml-butler Bot added the awaiting review Pull request awaiting review label Oct 31, 2022
@gbaned
Contributor

gbaned commented Oct 31, 2022

@hjm-aws Can you please resolve conflicts? Thank you!

@gbaned
Contributor

gbaned commented Nov 2, 2022

@hjm-aws Can you please sign the CLA? Thank you!

@hjm-aws
Author

hjm-aws commented Nov 10, 2022

> @hjm-aws Can you please sign CLA. Thank you!

Done. Thanks!

@gbaned
Contributor

gbaned commented Nov 24, 2022

@hjm-aws It still shows the CLA as pending; can you please sign it? Thank you!

@gbaned gbaned removed the awaiting review Pull request awaiting review label Nov 24, 2022
@gbaned gbaned requested a review from cheshire November 24, 2022 22:39
@google-ml-butler google-ml-butler Bot added the awaiting review Pull request awaiting review label Nov 24, 2022
Contributor

@cheshire cheshire left a comment

Thanks a lot! Overall this makes sense, let me also check internally.

QQ: this is not decomposable, right? E.g. changes inside XLA could not be split from builder changes?

Shape inferred_shape,
ShapeInference::InferAllGatherShape({operand_shape},
all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;

Could you also update documentation of semantics in operation_semantics.md?

@cheshire
Contributor

Thanks, overall this looks like a very good change! Added @Kariddi and @blakehechtman for clarifications.

@jurahul jurahul self-requested a review November 29, 2022 15:38
HasSubstr("Replica groups expected to be of uniform size"));
}

TEST_F(HloVerifierTest, ReduceScatterTwoTokens) {

I think there is some confusion here between tokens and tuples. I think the intent is to add tuple support, so we should remove any mention of tokens in test names or in the change description.

all_gather_dimension, shard_count));
std::vector<const Shape*> operand_shapes;
std::vector<XlaOp> operands;
if (operand_shape->IsTuple()) {

This looks like copy/pasted from all-reduce. Is it possible to factor it out into a helper function?
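The factoring suggested above could look something like the sketch below. This is illustrative only, using a simplified stand-in `Shape` struct rather than XLA's actual `Shape` class: a single helper flattens a possibly-tuple operand into its element shapes, so all-reduce, all-gather, and reduce-scatter could share the logic instead of each duplicating the `IsTuple()` branch.

```cpp
#include <cassert>
#include <vector>

// Simplified stand-in for XLA's Shape; illustrative only.
struct Shape {
  bool is_tuple = false;
  std::vector<Shape> elements;  // populated when is_tuple is true
  int rank = 0;                 // populated for array shapes
};

// Hypothetical helper: collect the operand's element shapes whether the
// operand is a tuple or a single array. Each collective op would call this
// instead of open-coding the tuple check.
std::vector<const Shape*> FlattenOperandShapes(const Shape& operand_shape) {
  std::vector<const Shape*> shapes;
  if (operand_shape.is_tuple) {
    for (const Shape& e : operand_shape.elements) shapes.push_back(&e);
  } else {
    shapes.push_back(&operand_shape);
  }
  return shapes;
}
```

A tuple operand yields one entry per element; a plain array yields a single entry, so downstream shape-inference code can treat both cases uniformly.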

ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 64})));
}

TEST_F(XlaBuilderTest, AllGatherWithToken) {

AllGatherWithTuple?

ShapeUtil::Equal(root->shape(), ShapeUtil::MakeShape(F32, {4, 8})));
}

TEST_F(XlaBuilderTest, ReduceScatterWithToken) {

ReduceScatterWithTuple?

TF_RET_CHECK(ag->operand_count() >= 1);

int64_t shard_count;
// There can be one token in the input Tuple. The token is a scalar or

This is confusing IMO. Can we get clarification on what the intent of the token input to all-gather is? Also, treating a scalar as a token is confusing.

@jurahul
Contributor

jurahul commented Nov 29, 2022

Overall it seems this is attempting to add tuple support for all-gather and reduce-scatter, as well as an optional dummy token input to all-gather, the purpose of which is unclear.

I think we should split this into 2 PRs: one for tuple support, and a separate discussion of token support in all-gather before adding it.

@gbaned gbaned removed the awaiting review Pull request awaiting review label Dec 5, 2022
@gbaned
Contributor

gbaned commented Dec 5, 2022

@hjm-aws Can you please check @jurahul's comments and keep us posted? Thank you!

@gbaned gbaned added the stat:awaiting response Status - Awaiting response from author label Dec 5, 2022
@gbaned
Contributor

gbaned commented Dec 29, 2022

@hjm-aws Any update on this PR? Please. Thank you!

1 similar comment
@gbaned
Contributor

gbaned commented Mar 21, 2023

@hjm-aws Any update on this PR? Please. Thank you!

@github-actions

github-actions Bot commented Apr 5, 2023

This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you.

@github-actions github-actions Bot added the stale This label marks the issue/pr stale - to be closed automatically if no activity label Apr 5, 2023
@github-actions

This PR was closed because it has been inactive for 14 days since being marked as stale. Please reopen if you'd like to work on this further.

@jeffhataws
Contributor

jeffhataws commented Sep 22, 2023

> Overall it seems this is attempting to add tuple support for all-gather and reduce-scatter as well as add a optional dummy token input to the all-gather, the purpose of which is unclear.
>
> I think we should split this into 2 PRs, one for tuple support and discuss support for token types in all-gather separately before adding it.

@gbaned @jurahul, the revived PR on openxla is openxla/xla#5740. The new PR has a description that hopefully answers your questions.

copybara-service Bot pushed a commit to openxla/xla that referenced this pull request Oct 17, 2023
Imported from GitHub PR #5740

This PR adds tuple input support to all-gather and reduce-scatter. It is a revival of part of tensorflow/tensorflow#58377 and is to be used in conjunction with pytorch/xla#5624.

In FSDP, different layers' weights need to be all-gathered/reduce-scattered during training. If some layers are small, multiple layers' weights can be aggregated for more efficient data transfer (the same concept as bucket_cap_mb in DDP). With the existing all-gather and reduce-scatter in PyTorch-XLA, the bucketing and decomposing have to be done outside of the operation. This PR enables multiple different tensors to be all-gathered/reduce-scattered while keeping the original tensor shapes, enabling bucketing and decomposing optimizations inside the operation.

The original PR had token support, like the token used for all-reduce to ensure ordering between collective-communication ops. That will be a separate PR if needed.

Copybara import of the project:

--
7ea1159 by Junmin Hao <junminh@amazon.com>:

Add Tuple input and token support to all-gather and reduce-scatter.

Committer: Junmin Hao <junminh@amazon.com>

--
cdb873e by Junmin Hao <junminh@amazon.com>:

lint fix

--
aad3521 by Jeffrey Huynh <jthuynh@amazon.com>:

Fix hlo_verifier_test failure due to changed expectation

--
32e8145 by Jeffrey Huynh <jthuynh@amazon.com>:

Separate the token change out into a separate PR with RFC.

--
b301c2a by Jeffrey Huynh <jthuynh@amazon.com>:

Change *WithToken tests to *WithTuple

--
5890278 by Jeffrey Huynh <jthuynh@amazon.com>:

Fix missing parenthesis

Merging this change closes #5740

COPYBARA_INTEGRATE_REVIEW=#5740 from jeffhataws:ag_rs_coalesce_revived 14e09f0
PiperOrigin-RevId: 573976449
copybara-service Bot pushed a commit that referenced this pull request Oct 17, 2023
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Nov 19, 2023
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 10, 2023
jeffhataws added a commit to jeffhataws/openxla that referenced this pull request Dec 11, 2023