
Add Comm-Compute Preserving Bucketer#163960

Closed
eellison wants to merge 10 commits intogh/eellison/831/basefrom
gh/eellison/831/head

Conversation

@eellison
Contributor

@eellison eellison commented Sep 26, 2025

Stack from ghstack (oldest at bottom):

tl;dr: performs bucketing while preserving comm-compute overlap.

After comm-compute overlap scheduling, we will have a graph like:

```
def foo(...):
    ag = all_gather(...)
    hiding_compute = mm(...)
    wait(ag)
```

There is no explicit dependency between the hiding compute and the collectives, but to preserve the overlap we want to add implicit dependencies from wait -> hiding_compute and from hiding_compute -> all_gather.
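These implicit edges can be sketched as extra entries in an auxiliary dependency map. This is a toy illustration under assumptions, not the PR's actual API: the `augment_deps` helper, the triple format, and the string node names are all invented here; real code would operate on `torch.fx.Node` objects.

```python
def augment_deps(deps, hiding):
    """deps: node -> set of explicit predecessors.
    hiding: iterable of (start, hiding_compute, wait) triples.
    Returns a copy of deps with the implicit edges
    wait -> hiding_compute and hiding_compute -> start added."""
    out = {n: set(p) for n, p in deps.items()}
    for start, compute, wait in hiding:
        # hiding_compute must come after the collective start...
        out.setdefault(compute, set()).add(start)
        # ...and the wait must come after the hiding compute.
        out.setdefault(wait, set()).add(compute)
    return out

deps = {"ag": set(), "mm": set(), "wait_ag": {"ag"}}
aug = augment_deps(deps, [("ag", "mm", "wait_ag")])
print(sorted(aug["wait_ag"]))  # ['ag', 'mm']
```

Any scheduling or bucketing decision made against the augmented map then cannot reorder the hiding compute out of the start/wait window.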

Additionally, while bucketing, we merge collective starts together and collective waits together. In this case we want to treat the merged nodes as a single subgraph: each node in the merged set takes the union of all dependencies in the set.
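The union rule for a merged set can be sketched as follows (an illustrative assumption about the bookkeeping, not the PR's code; node names are hypothetical):

```python
def merged_set_deps(deps, merged):
    """deps: node -> set of predecessors.
    merged: the set of nodes bucketed together.
    Each member of the merged set is treated as depending on the
    union of all members' dependencies, excluding the members themselves."""
    union = set()
    for n in merged:
        union |= deps.get(n, set())
    return union - set(merged)

deps = {"ag1": {"x"}, "ag2": {"y"}, "wait1": {"ag1"}, "wait2": {"ag2"}}
print(sorted(merged_set_deps(deps, {"ag1", "ag2"})))  # ['x', 'y']
```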

We perform bucketing while augmenting the graph with these relationships. This can be done separately from comm-compute overlap, so long as the hiding-compute relationships are passed in.
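One consequence of bucketing against the augmented graph is a legality check: two collectives cannot be merged if one is an ancestor of the other through the (explicit plus implicit) dependencies, since merging them would create a cycle. A minimal sketch, under the same toy dependency-map assumption as above:

```python
def ancestors(deps, node):
    """All transitive predecessors of node in deps (node -> set of preds)."""
    seen, stack = set(), [node]
    while stack:
        n = stack.pop()
        for p in deps.get(n, ()):
            if p not in seen:
                seen.add(p)
                stack.append(p)
    return seen

def can_bucket(deps, a, b):
    """Merging a and b is unsafe if either reaches the other."""
    return a not in ancestors(deps, b) and b not in ancestors(deps, a)

# ag1 hides mm, and ag2 (transitively) depends on mm, so merging
# ag1 and ag2 would force the hiding compute inside the bucket: a cycle.
deps = {"ag1": set(), "mm": {"ag1"}, "ag2": {"mm"}}
print(can_bucket(deps, "ag1", "ag2"))  # False
```

This is also where the compile-time TODO below bites: a naive `ancestors` walks the whole graph, so limiting the traversal region is the natural speedup.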

TODO:

  • instrument the fx graph so Inductor respects these relationships
  • speed up the compile time of the bucketing search significantly by limiting which portion of the graph we traverse
  • more memory-aware handling

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Sep 26, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/163960

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1072d8a with merge base 3a7db34:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 63342e8
Pull Request resolved: #163960
eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 16b9e45
Pull Request resolved: #163960

eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 9fe12da
Pull Request resolved: #163960
@eellison eellison requested a review from fmassa September 26, 2025 16:36
eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: b9532ae
Pull Request resolved: #163960
@eellison eellison added the topic: not user facing label Sep 26, 2025
@eellison eellison requested a review from ezyang September 26, 2025 17:11
eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 81a45da
Pull Request resolved: #163960
Contributor

@ruisizhang123 ruisizhang123 left a comment

LGTM, thank you!

eellison added a commit that referenced this pull request Sep 26, 2025
ghstack-source-id: 1ce95b2
Pull Request resolved: #163960
```
     insert_before: Optional[torch.fx.Node] = None,
     wait_insertion_point: Optional[torch.fx.Node] = None,
-) -> dict[torch.fx.Node, torch.fx.Node]:
+) -> tuple[list[torch.fx.Node], dict[torch.fx.Node, torch.fx.Node]]:
```
Contributor

severe tuple blindness lol

@ezyang
Contributor

ezyang commented Sep 29, 2025

I didn't do a detailed review, but ACKing the high-level approach

IvanKobzarev pushed a commit to IvanKobzarev/pytorch that referenced this pull request Sep 29, 2025
ghstack-source-id: 1ce95b2
Pull Request resolved: pytorch#163960
eellison added a commit that referenced this pull request Sep 29, 2025
ghstack-source-id: cbeac24
Pull Request resolved: #163960
@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

eellison added a commit that referenced this pull request Sep 29, 2025
ghstack-source-id: bcb9860
Pull Request resolved: #163960
@pytorch-bot pytorch-bot bot added the oncall: distributed label Sep 29, 2025
@pytorchmergebot
Collaborator

Merge failed

Reason: New commits were pushed while merging. Please rerun the merge command.

Details for Dev Infra team (raised by workflow job)

@eellison
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

ruisizhang123 added a commit to pytorch/torchtitan that referenced this pull request Oct 14, 2025
This PR adds the autobucketing pass at the aten level to SimpleFSDP. It runs autobucketing with the aot_eager backend, without Inductor. The aten fx autobucketing pass can be found in this PR: pytorch/pytorch#163960.

Key updates are:

1. Support a customized `aot_eger_autobucketing` backend to perform the autobucketing optimization.
2. In SimpleFSDP, the model_backend can be replaced by the user's customized passes using `compile.model_backend_override`.
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 15, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 16, 2025
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 29, 2025
@github-actions github-actions bot deleted the gh/eellison/831/head branch October 31, 2025 02:17
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 13, 2026
xrsrke pushed a commit to NousResearch/torchtitan that referenced this pull request Feb 25, 2026

Labels

ciflow/inductor, ciflow/trunk, Merged, module: inductor, oncall: distributed, topic: not user facing

6 participants