Add all-gather coalescing for FSDP/ZeRO1 #5950
@jeffhataws let me know when you are done addressing comments and I will take another look.
sharding_world_size: Optional[int] = None,
shard_param_on_dim_0: bool = False,
pin_layout_in_collective_ops: bool = True,
coalesce_all_gather_ops: bool = False,
Do you mind explaining the change in this file? I think coalesce_all_gather_ops is always False in our tests. Did you run into these issues with your own tests?
When coalesce_all_gather_ops is True, the parameter shards are collected into a list and gathered with a single coalesced all-gather at the end (instead of all-gathering one parameter at a time).
It is off by default to avoid changing existing behavior. The code is the same as what we are using in our local fork.
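To illustrate the difference described above, here is a minimal pure-Python sketch. It is a simulation under stated assumptions, not the torch_xla implementation: `FakeCollective`, `gather_params`, and the nested-list shard layout are hypothetical names invented for this example. Only the flag name `coalesce_all_gather_ops` comes from the PR.

```python
from typing import List

Shards = List[List[float]]  # one shard per rank for a single parameter


class FakeCollective:
    """Hypothetical stand-in for a collective backend (not a torch_xla API).

    It only counts how many collective operations are launched.
    """

    def __init__(self) -> None:
        self.calls = 0

    def all_gather(self, shards: Shards) -> List[float]:
        # Gather one parameter: concatenate its shard from every rank.
        self.calls += 1
        return [x for shard in shards for x in shard]

    def all_gather_coalesced(self, params: List[Shards]) -> List[List[float]]:
        # Gather many parameters in a single collective launch.
        self.calls += 1
        return [[x for shard in shards for x in shard] for shards in params]


def gather_params(backend: FakeCollective,
                  sharded_params: List[Shards],
                  coalesce_all_gather_ops: bool = False) -> List[List[float]]:
    if coalesce_all_gather_ops:
        # Collect every parameter's shards first, then gather once at the end.
        pending = [shards for shards in sharded_params]
        return backend.all_gather_coalesced(pending)
    # Default path: one all-gather per parameter.
    return [backend.all_gather(shards) for shards in sharded_params]


# Two parameters sharded across a world of two ranks.
sharded_params = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0], [6.0]],
]

per_param = FakeCollective()
coalesced = FakeCollective()
full_a = gather_params(per_param, sharded_params)
full_b = gather_params(coalesced, sharded_params, coalesce_all_gather_ops=True)
```

Both paths produce the same full parameters. The only difference in this sketch is the number of collective launches: one per parameter on the default path, and one in total on the coalesced path.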
ReduceContext cc_ctx = GetReduceContext(inputs);
std::vector<xla::XlaOp> result(inputs.size());

for (auto& type_ctx : cc_ctx.contexts) {
If you want to assume there is only one type_ctx, let's not use the for loop and GetReduceContext at all. This way we don't need to handle the token per type.
Let me check with others on this.
JackCaoG left a comment
Mostly LGTM besides the changes in FSDP. If we didn't change the default behavior of all-gather, the tests should pass, right?
I will look into the reduce-scatter one today. Let's try to merge these two PRs soon.
* Add all-gather and reduce-scatter coalescence support for FSDP. Also allow using reduce-scatter's scale param in FSDP. (revived pytorch#4145)
* clang-format-7 and python lint fixes
* Fix "SyntaxError: 'return' outside function" error
* Code/test fixes to get run_tests.sh to run on CPU
* Fix allgather to be compatible with openxla allgather tuple change without token
* Fix reduce-scatter-coalesce to be compatible with openxla reduce-scatter tuple change without token
* Separate out the reduce-scatter-coalesce changes into a separate PR
* Some cleanups
* Add separate BuildAllGatherCoalesced builder and AllGatherCoalesced class
* Use token_handler.GetInput to capture token
* Clean up
* Clean up
* Switch to GetOperandListWithToken naming for func GetOperandList
This PR adds all-gather coalescence support and uses it in FSDP/ZeRO1 (replacing #5624). It is to be used in conjunction with openxla/xla#5740.
A separate, related PR, #5938, adds reduce-scatter coalescence and also enables using reduce-scatter's scale param in FSDP.
This is a revival of #4145. Will need to address the comments.