[inductor] do comm compute overlap at aten fx level #163215
eellison wants to merge 14 commits into gh/eellison/825/base
…posee analysis to avoid increasing exposed time in bucketing [ghstack-poisoned]
✅ No failures as of commit c9503a8 with merge base 3a7db34. See artifacts and rendered test results at hud.pytorch.org/pr/163215. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
Don't forget to post the scripts you were using to validate 2d llama!
if torch._inductor.config.test_configs.aten_fx_overlap_scheduling:
    from torch._inductor.fx_passes.overlap_scheduling import (
        schedule_overlap_bucketing,
It would be nice to have a test that also exercises this from outside of Inductor, to show that we aren't overly reliant on Inductor configs
tests are done on aten graphs
oh no merge conflicts
raise ValueError(f"node is not a collective kernel: {node}")

kernel_name = node.python_kernel_name
def get_collective_type_from_kernel_name(kernel_name: str) -> NCCL_COLL:
str matching seems pretty sus, any reason we're not matching on OpOverload?
this is a refactoring with existing logic, but i agree we should improve
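To make the suggestion above concrete, here is a minimal sketch of keying the lookup on the op object instead of scanning a kernel-name string. It is not the code in this PR: `Coll` and `Op` are hypothetical stand-ins for `NCCL_COLL` and `OpOverload` so the example runs without torch, and the three op names are only illustrative.

```python
from dataclasses import dataclass
from enum import Enum


class Coll(Enum):
    # Hypothetical stand-in for the NCCL_COLL enum referenced in the diff.
    ALL_REDUCE = 0
    ALL_GATHER = 1
    REDUCE_SCATTER = 2


@dataclass(frozen=True)
class Op:
    # Hypothetical stand-in for an OpOverload: hashable and compared by value.
    namespace: str
    name: str
    overload: str = "default"


ALL_REDUCE_OP = Op("_c10d_functional", "all_reduce")
ALL_GATHER_OP = Op("_c10d_functional", "all_gather_into_tensor")
REDUCE_SCATTER_OP = Op("_c10d_functional", "reduce_scatter_tensor")

# Exact lookup table: an unknown op fails loudly instead of being
# misclassified by a substring match on its name.
COLLECTIVE_TYPE_BY_OP = {
    ALL_REDUCE_OP: Coll.ALL_REDUCE,
    ALL_GATHER_OP: Coll.ALL_GATHER,
    REDUCE_SCATTER_OP: Coll.REDUCE_SCATTER,
}


def get_collective_type(op: Op) -> Coll:
    try:
        return COLLECTIVE_TYPE_BY_OP[op]
    except KeyError:
        raise ValueError(f"node is not a collective kernel: {op}") from None
```

With a table like this, adding a new collective is a one-line change, and the error path stays the same as in the existing helper.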
… and use exposure analysis to avoid increasing exposed time in bucketing"
TODO:
- finish up the bucketing logic post comm/compute reordering
- ghstack-ify some of the refactorings of related files
- pass all the tests in [test_compute_comm_reordering](https://github.com/pytorch/pytorch/blob/main/test/distributed/test_compute_comm_reordering.py) + add other tests
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
- Write up comments
- make sure inductor respects comm/compute overlapping done at fx level
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
[ghstack-poisoned]
This is the first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other misc:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
[ghstack-poisoned]
@eellison when do you think you'll be able to rebase the PR and resolve the merge conflicts? (Not a request, just trying to plan around this.)
This is the first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other misc:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
[ghstack-poisoned]
Preparatory refactor
Pull Request resolved: #163754
Approved by: https://github.com/IvanKobzarev
ghstack dependencies: #163215
In comm-compute overlap we will have a graph with:
```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
This PR adds `AugmentedGraphHelper`, which provides APIs for adding these extra dependencies and merged sets, and allows querying dependencies in the resulting augmented graph.
Pull Request resolved: #163959
Approved by: https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754
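The two ideas in this commit message (extra implicit dependencies, and merged sets that share the union of their deps) can be sketched in plain Python. This is an illustration only: `AugmentedDeps` and its method names are hypothetical and are not the `AugmentedGraphHelper` API, and nodes are plain strings rather than fx nodes.

```python
from collections import defaultdict


class AugmentedDeps:
    """Hypothetical sketch: explicit deps plus extra deps and merged sets."""

    def __init__(self, deps):
        # deps maps every node to the set of nodes it explicitly depends on.
        self.deps = {n: set(d) for n, d in deps.items()}
        self.extra = defaultdict(set)  # implicit deps added on top
        self.merged = {n: {n} for n in self.deps}  # node -> its merge set

    def add_extra_dep(self, node, dep):
        # Record that `node` must run after `dep`, without a data edge.
        self.extra[node].add(dep)

    def merge(self, a, b):
        # Treat a and b (and anything already merged with them) as one unit.
        group = self.merged[a] | self.merged[b]
        for n in group:
            self.merged[n] = group

    def direct_deps(self, node):
        # Union of explicit and extra deps over the whole merge set.
        out = set()
        for n in self.merged[node]:
            out |= self.deps[n] | self.extra[n]
        return out - self.merged[node]

    def depends_on(self, node, target):
        # Depth-first search over the augmented dependency edges.
        seen, stack = set(), [node]
        while stack:
            cur = stack.pop()
            for d in self.direct_deps(cur):
                if d in self.merged[target]:
                    return True
                if d not in seen:
                    seen.add(d)
                    stack.append(d)
        return False


g = AugmentedDeps({"ag": set(), "mm": set(), "wait": {"ag"}})
g.add_extra_dep("mm", "ag")    # hiding_compute -> all_gather
g.add_extra_dep("wait", "mm")  # wait -> hiding_compute
print(g.depends_on("wait", "mm"), g.depends_on("mm", "wait"))
```

After the two extra deps are added, `wait` depends on `mm` and `mm` depends on `ag`, so any later reordering that consults this structure keeps the matmul between the collective start and its wait.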
@pytorchbot revert -m "seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box" -c nosignal
@pytorchbot successfully started a revert job. Check the current status here.
This reverts commit b5d4d35. Reverted #163959 on behalf of https://github.com/yangw-dev due to seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box ([comment](#163215 (comment)))
This reverts commit e1bd5b6. Reverted #163754 on behalf of https://github.com/yangw-dev due to seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box ([comment](#163215 (comment)))
This reverts commit c9b5af9. Reverted #163215 on behalf of https://github.com/yangw-dev due to seems fails inductor/test_aten_comm_compute_reordering for macos test, see https://hud.pytorch.org/pytorch/pytorch/commit/c9b5af9a384e7ef5f95613abe1622f5f55133c3a#51526707590-box ([comment](#163215 (comment)))
@eellison your PR has been successfully reverted.
This is the first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other misc:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben
[ghstack-poisoned]
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team
Starting merge as part of PR stack under #163960
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
tl;dr performs bucketing while preserving comm-compute overlap.
In comm-compute overlap we will have a graph with:
```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```
There is no explicit dependency between the hiding compute and the collectives, but we want to add implicit dependencies from wait->hiding_compute, and from hiding_compute->all_gather to preserve overlap.
Additionally, while bucketing, we will merge collective starts and collective waits together. In this case, we will want to treat the two nodes as a single subgraph - each node in the merged set will have the union of all deps in the set.
We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, as long as the hiding-compute relationships are passed in.
TODO:
- need to instrument fx graph so inductor respects these relationships.
- the compile time of the bucketing search can be sped up significantly by limiting what portion of the graph we traverse through
- more memory aware handling
Pull Request resolved: #163960
Approved by: https://github.com/ruisizhang123, https://github.com/v0i0, https://github.com/IvanKobzarev
ghstack dependencies: #163215, #163754, #163959
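One way to read "bucketing while preserving comm-compute overlap" is as a legality check: two collective starts may only be merged if neither depends on the other once the hiding-compute edges are included. The sketch below illustrates that check under simplifying assumptions. `depends_on`, `can_bucket`, and the string node names are hypothetical, and the actual pass may apply further conditions that are not shown on this page.

```python
def depends_on(deps, node, target):
    """Return True if `node` transitively depends on `target`."""
    seen, stack = set(), [node]
    while stack:
        cur = stack.pop()
        for dep in deps.get(cur, ()):
            if dep == target:
                return True
            if dep not in seen:
                seen.add(dep)
                stack.append(dep)
    return False


def can_bucket(deps, start_a, start_b):
    # Merging is only safe if there is no path between the two starts
    # in either direction; otherwise the merged node would form a cycle.
    return not depends_on(deps, start_a, start_b) and not depends_on(
        deps, start_b, start_a
    )


# Explicit data dependencies: ag3 consumes the output of mm1.
explicit = {
    "ag1": set(),
    "mm1": set(),
    "wait1": {"ag1"},
    "ag3": {"mm1"},
    "wait3": {"ag3"},
}
# Hiding relationships passed in from the overlap pass: mm1 hides ag1,
# so mm1 must run after ag1 and wait1 must run after mm1.
hiding = {"mm1": {"ag1"}, "wait1": {"mm1"}}
augmented = {n: explicit[n] | hiding.get(n, set()) for n in explicit}

print(can_bucket(explicit, "ag1", "ag3"))   # legal on the plain graph
print(can_bucket(augmented, "ag1", "ag3"))  # blocked once hiding edges count
```

On the plain graph `ag1` and `ag3` look independent, but with the hiding edges `ag3` depends on `mm1`, which depends on `ag1`. Bucketing them would force `mm1` to run both after and before the merged start, so overlap could not be preserved.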
Stack from ghstack (oldest at bottom):
This is the first part of the stack that does comm/compute reordering, and then uses the exposure analysis to do bucketing.
Subsequent prs will handle:
- use of exposure analysis to do bucketing
- make sure inductor respects comm/compute overlapping done at fx level
- non-profiling mm estimation/rank broadcasting of profile results
Other misc:
- Validate accuracy of nccl estimations ( use ruisi's profiling instead ?)
For a llama 2d parallelism test, on forward, we overlap all but 2 of potentially hidden collectives. For backward, we overlap 217/269 of potentially hidden collectives. If you increase `compute_overlap_multipler` (for fudge factor of inaccurate comms estimation), that goes down to all but 16 of potentially hidden collectives.
fwd example: https://gist.github.com/eellison/76209c49d8829c5f1e323d34a3f040c3
bwd example: https://gist.github.com/eellison/6cfc2285df53a94cfa4012f5fdae5c51
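For readers unfamiliar with the exposure numbers above, here is a rough sketch of how a multiplier could change the count of exposed collectives. It assumes the multiplier scales the compute time credited toward hiding a collective. That reading is consistent with the numbers quoted above but is not confirmed on this page, and the function names and timings are made up for illustration.

```python
def exposed_time_ms(comm_ms, compute_ms, multiplier=1.0):
    # Assumption: compute credited toward hiding is compute_ms * multiplier.
    return max(0.0, comm_ms - compute_ms * multiplier)


def count_exposed(pairs, multiplier=1.0):
    # pairs: (estimated collective time, available hiding compute time) in ms.
    return sum(
        1 for comm_ms, compute_ms in pairs
        if exposed_time_ms(comm_ms, compute_ms, multiplier) > 0
    )


# Made-up estimates for three collectives.
pairs = [(1.0, 2.0), (3.0, 2.0), (5.0, 2.0)]
print(count_exposed(pairs))                   # multiplier of 1.0
print(count_exposed(pairs, multiplier=2.0))   # more compute credited
```

Under this assumption, raising the multiplier reduces how many collectives are counted as exposed, which acts as a fudge factor when the communication estimates are inaccurate.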
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben