
[inductor] do comm compute overlap at aten fx level#163215

Closed
eellison wants to merge 14 commits into gh/eellison/825/base from gh/eellison/825/head

Conversation

@eellison
Contributor

@eellison eellison commented Sep 18, 2025

Stack from ghstack (oldest at bottom):

This is the first part of a stack that does comm/compute reordering and then uses the exposure analysis to do bucketing.

Subsequent PRs will handle:

  • use of exposure analysis to do bucketing
  • make sure inductor respects comm/compute overlapping done at fx level
  • non-profiling mm estimation/rank broadcasting of profile results

Other misc:

  • Validate the accuracy of NCCL estimations (use ruisi's profiling instead?)

For a llama 2D parallelism test, on forward we overlap all but 2 of the potentially hidden collectives. For backward, we overlap 217/269 of the potentially hidden collectives. If you increase `compute_overlap_multipler` (a fudge factor for inaccurate comms estimation), that goes down to all but 16 of the potentially hidden collectives.

fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3

bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
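
One rough way to read the overlap counts above, as a sketch rather than the pass's actual logic: treat a collective as hidden when the compute scheduled between its start and its wait is estimated to take at least as long as the communication. The `Collective` class, the `overlap_multiplier` parameter, and all timings below are hypothetical, and the way the multiplier scales compute time is an assumption, since the description only calls it a fudge factor.

```python
from dataclasses import dataclass


@dataclass
class Collective:
    name: str
    comm_ms: float     # estimated communication time (hypothetical numbers)
    compute_ms: float  # estimated compute scheduled between start and wait


def exposed_ms(c: Collective, overlap_multiplier: float = 1.0) -> float:
    """Estimated time the wait still blocks after the hiding compute has run."""
    # Assumption: the fudge factor scales how much comm a unit of compute hides.
    return max(0.0, c.comm_ms - c.compute_ms * overlap_multiplier)


def count_hidden(collectives, overlap_multiplier: float = 1.0) -> int:
    """Number of collectives whose estimated exposed time is zero."""
    return sum(1 for c in collectives if exposed_ms(c, overlap_multiplier) == 0.0)


colls = [
    Collective("all_gather_0", comm_ms=2.0, compute_ms=3.0),
    Collective("all_gather_1", comm_ms=2.0, compute_ms=1.5),
    Collective("reduce_scatter_0", comm_ms=4.0, compute_ms=1.0),
]
hidden_default = count_hidden(colls)
hidden_scaled = count_hidden(colls, overlap_multiplier=1.5)
```

With these made-up timings, scaling the compute estimate changes which collectives count as fully hidden, which is why the reported counts depend on the multiplier.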

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

…posee analysis to avoid increasing exposed time in bucketing

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Sep 18, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163215

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c9503a8 with merge base 3a7db34:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Sep 18, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: f294750
Pull Request resolved: #163215
@eellison eellison changed the title [WIP] [inductor] do comm compute overlap at aten fx level, and use exposee analysis to avoid increasing exposed time in bucketing [WIP] [inductor] do comm compute overlap at aten fx level, and use exposure analysis to avoid increasing exposed time in bucketing Sep 18, 2025
@ezyang
Contributor

ezyang commented Sep 22, 2025

Don't forget to post the scripts you were using to validate 2d llama!


if torch._inductor.config.test_configs.aten_fx_overlap_scheduling:
    from torch._inductor.fx_passes.overlap_scheduling import (
        schedule_overlap_bucketing,
Contributor

It would be nice to have a test that also exercises this from outside of Inductor, to show that we aren't overly reliant on Inductor configs

Contributor Author

tests are done on aten graphs

@ezyang
Contributor

ezyang commented Sep 22, 2025

oh no merge conflicts

raise ValueError(f"node is not a collective kernel: {node}")

kernel_name = node.python_kernel_name
def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
Contributor

str matching seems pretty sus, any reason we're not matching on OpOverload?

Contributor Author

this is a refactoring with existing logic, but i agree we should improve

… and use exposure analysis to avoid increasing exposed time in bucketing"



TODO:
- finish up the bucketing logic post comm/compute reordering
- ghstack-ify some of the refactorings of related files
- pass all the tests in [test_compute_comm_reordering](https://github.com/pytorch/pytorch/blob/main/test/distributed/test_compute_comm_reordering.py) + add other tests
- Validate the accuracy of NCCL estimations (use ruisi's profiling instead?)
- Write up comments
- make sure inductor respects comm/compute overlapping done at fx level

For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben

[ghstack-poisoned]
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: e85929a
Pull Request resolved: #163215
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: 2ae0fb0
Pull Request resolved: #163215
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: b83e780
Pull Request resolved: #163215
@eellison eellison changed the title [WIP] [inductor] do comm compute overlap at aten fx level, and use exposure analysis to avoid increasing exposed time in bucketing [inductor] do comm compute overlap at aten fx level Sep 22, 2025
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: 11deb14
Pull Request resolved: #163215
@ezyang
Contributor

ezyang commented Sep 22, 2025

@eellison when do you think you'll be able to rebase the PR and resolve the merge conflicts? (Not a request, just trying to plan around this.)

eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: 5a45377
Pull Request resolved: #163215
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: 6311a69
Pull Request resolved: #163215
eellison added a commit that referenced this pull request Sep 22, 2025
…posee analysis to avoid increasing exposed time in bucketing

ghstack-source-id: 9eb0504
Pull Request resolved: #163215
pytorchmergebot pushed a commit that referenced this pull request Sep 29, 2025
Preparatory refactor

Pull Request resolved: #163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
pytorchmergebot pushed a commit that referenced this pull request Sep 29, 2025
In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

This PR adds `AugmentedGraphHelper`, which provides these APIs and allows querying dependencies on the augmented graph.
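
To make the two ideas above concrete (implicit edges that pin the overlap, and merged nodes sharing the union of their deps), here is a toy sketch. `TinyAugmentedGraph` and its methods are hypothetical names for illustration only; this is not the `AugmentedGraphHelper` API, and it works on plain strings rather than fx nodes.

```python
from collections import defaultdict


class TinyAugmentedGraph:
    """Toy dependency graph with extra (implicit) edges and merged node sets."""

    def __init__(self, deps):
        # deps maps node -> set of nodes it directly depends on (explicit edges).
        self.deps = defaultdict(set)
        for node, ds in deps.items():
            self.deps[node] |= set(ds)
        self.extra = defaultdict(set)  # implicit edges added on top of the graph
        self.group = {}                # node -> set of nodes it is merged with

    def add_extra_dep(self, node, dep):
        self.extra[node].add(dep)

    def merge(self, a, b):
        merged = self.group.get(a, {a}) | self.group.get(b, {b})
        for n in merged:
            self.group[n] = merged

    def direct_deps(self, node):
        # A merged node sees the union of the deps of every member of its set.
        members = self.group.get(node, {node})
        out = set()
        for m in members:
            out |= self.deps[m] | self.extra[m]
        return out - members

    def depends_on(self, node, target):
        # Depth-first search over explicit and implicit edges.
        targets = self.group.get(target, {target})
        seen, stack = set(), [node]
        while stack:
            cur = stack.pop()
            for d in self.direct_deps(cur):
                if d in targets:
                    return True
                if d not in seen:
                    seen.add(d)
                    stack.append(d)
        return False


# The all_gather / mm / wait example from above.
g = TinyAugmentedGraph({"ag": set(), "mm": set(), "wait": {"ag"}})
g.add_extra_dep("mm", "ag")    # hiding_compute -> all_gather
g.add_extra_dep("wait", "mm")  # wait -> hiding_compute

# Merging two collective starts: each sees the union of both dep sets.
g2 = TinyAugmentedGraph({"a": set(), "b": set(), "ag1": {"a"}, "ag2": {"b"}})
g2.merge("ag1", "ag2")
```

With the extra edges in place, `g.depends_on("wait", "mm")` is true even though the original graph has no edge between them, so a later reordering that respects these queries cannot move the matmul out from between the all-gather and its wait.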

Pull Request resolved: #163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754
@yangw-dev
Contributor

@pytorchbot revert -m "seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box" -c nosignal

@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Sep 29, 2025
pytorchmergebot added a commit that referenced this pull request Sep 29, 2025
This reverts commit e1bd5b6.

Reverted #163754 on behalf of https://github.com/yangw-dev due to seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box ([comment](#163215 (comment)))
pytorchmergebot added a commit that referenced this pull request Sep 29, 2025
@pytorchmergebot
Collaborator

@eellison your PR has been successfully reverted.

@pytorchmergebot added the Reverted and ci-no-td labels Sep 29, 2025
@eellison eellison requested a review from a team as a code owner September 29, 2025 22:36
@pytorch-bot added the oncall: distributed label Sep 29, 2025
@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).


@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #163960

@pytorchmergebot
Collaborator

The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job waited more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
For more information see the pytorch-bot wiki.

pytorchmergebot pushed a commit that referenced this pull request Sep 30, 2025
Preparatory refactor

Pull Request resolved: #163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
pytorchmergebot pushed a commit that referenced this pull request Sep 30, 2025
Pull Request resolved: #163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754
pytorchmergebot pushed a commit that referenced this pull request Sep 30, 2025
tl;dr performs bucketing while preserving comm-compute overlap.

In comm-compute overlap we will have a graph with:

```
def foo(...):
     ag = all_gather(...)
     hiding_compute = mm(...)
     wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.

Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.

We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding compute relationships are passed in.
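
As an illustration of why bucketing has to consult these relationships, the sketch below checks whether merging two collective starts (and their waits) would create a cycle once the hiding-compute edges are added. It is a toy under stated assumptions, not the search implemented in the PR: the `hiding` mapping (start -> (hiding compute, wait)), the function names, and the string node names are all hypothetical.

```python
def augmented_deps(deps, hiding):
    """Copy deps and add compute -> start and wait -> compute edges."""
    out = {n: set(ds) for n, ds in deps.items()}
    for start, (compute, wait) in hiding.items():
        out.setdefault(compute, set()).add(start)
        out.setdefault(wait, set()).add(compute)
    return out


def contract(deps, group, name):
    """Replace every node in `group` with a single node called `name`."""
    def rename(n):
        return name if n in group else n

    out = {}
    for node, ds in deps.items():
        # Self-edges are kept on purpose: they are reported as cycles below.
        out.setdefault(rename(node), set()).update(rename(d) for d in ds)
    return out


def has_cycle(deps):
    """Three-colour depth-first search; recursion is fine for a toy graph."""
    WHITE, GREY, BLACK = 0, 1, 2
    color = {}

    def visit(n):
        color[n] = GREY
        for d in deps.get(n, ()):
            c = color.get(d, WHITE)
            if c == GREY or (c == WHITE and visit(d)):
                return True
        color[n] = BLACK
        return False

    return any(color.get(n, WHITE) == WHITE and visit(n) for n in list(deps))


def can_bucket(deps, hiding, a, b):
    """True if merging starts a and b (and their waits) keeps the graph acyclic."""
    g = augmented_deps(deps, hiding)
    g = contract(g, {a, b}, "bucket_start")
    g = contract(g, {hiding[a][1], hiding[b][1]}, "bucket_wait")
    return not has_cycle(g)


hiding = {"ag1": ("mm1", "w1"), "ag2": ("mm2", "w2")}
# mm2 is independent of the first all_gather: bucketing is fine.
deps_ok = {"ag1": set(), "mm1": set(), "w1": {"ag1"},
           "ag2": set(), "mm2": set(), "w2": {"ag2"}}
# mm2 consumes the first all_gather's result but also hides the second one.
deps_bad = dict(deps_ok, mm2={"w1"})
```

In the second graph, the merged wait would have to run after `mm2` (which hides `ag2`) while `mm2` itself needs the merged wait's result, so the two collectives cannot be bucketed without giving up that overlap.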

TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling

Pull Request resolved: #163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
@github-actions github-actions bot deleted the gh/eellison/825/head branch October 31, 2025 02:17

Labels

ci-no-td, ciflow/inductor, ciflow/trunk, Merged, merging, module: inductor, oncall: distributed, Reverted, topic: not user facing

5 participants