
[FSDP][Collectives] skipping reduce_scatter when world size is 1 #160136

Closed
anshul-si wants to merge 6 commits into gh/anshul-si/17/base from gh/anshul-si/17/head

Conversation

@anshul-si
Contributor

@anshul-si anshul-si commented Aug 7, 2025

Summary: In its current state, FSDP collectives issue CUDA synchronizations and communication ops regardless of the world size. However, now that replicate will use FSDP, there will be cases where the group size is 1 and these synchronizations and ops run needlessly. I have updated fsdp_collectives to skip the reduce_scatter in the foreach_reduce API when world_size == 1. I have edited a test that uses CommDebugMode to verify that the reduce_scatter is no longer issued, and I also updated an affected test that used 1-way FSDP by adjusting its CommDebugMode assertions. I have also added a test command.
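Conceptually, the change is an early exit before the collective: a reduce-scatter over a single-rank group is an identity, so the local gradient can be used directly and the CUDA synchronization and c10d call are skipped. A minimal sketch of the idea (hypothetical `reduce_grad` helper, not the actual `foreach_reduce` code):

```python
import torch
import torch.distributed as dist


def reduce_grad(grad_flat: torch.Tensor, group: dist.ProcessGroup) -> torch.Tensor:
    """Illustrative sketch of the world_size == 1 fast path."""
    world_size = group.size()
    if world_size == 1:
        # Single rank: there is nothing to reduce or scatter, so skip the
        # collective (and the synchronization it would require) entirely.
        return grad_flat
    # Multi-rank case: each rank receives its 1/world_size shard of the
    # averaged gradient, as in regular FSDP gradient reduction.
    output = grad_flat.new_empty(grad_flat.numel() // world_size)
    dist.reduce_scatter_tensor(output, grad_flat, op=dist.ReduceOp.AVG, group=group)
    return output
```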

Test Cases

  1. pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_single_worldsize1
  2. pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_tp_with_fsdp_offloading
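For the verification side, the pattern is to run a training step under CommDebugMode and assert that no reduce-scatter was recorded. A rough, self-contained sketch of that check (the CommDebugMode import path is the current public DTensor debug location and may differ by release; the model and input are placeholders):

```python
import torch
from torch.distributed.tensor.debug import CommDebugMode


def assert_no_reduce_scatter(model: torch.nn.Module, inp: torch.Tensor) -> None:
    comm_mode = CommDebugMode()
    with comm_mode:
        model(inp).sum().backward()
    # get_comm_counts() maps collective ops to call counts; with a world size
    # of 1 the backward pass should record no reduce_scatter calls.
    counts = comm_mode.get_comm_counts()
    reduce_scatter_calls = sum(
        n for op, n in counts.items() if "reduce_scatter" in str(op)
    )
    assert reduce_scatter_calls == 0, f"unexpected reduce_scatter calls: {counts}"
```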

Stack from ghstack (oldest at bottom):

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta

@pytorch-bot

pytorch-bot bot commented Aug 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/160136

Note: Links to docs will display an error until the docs builds have been completed.

❌ 2 New Failures, 8 Unrelated Failures

As of commit 421bed3 with merge base e6aa728:

NEW FAILURES - The following jobs have failed:

FLAKY - The following job failed but was likely due to flakiness present on trunk:

BROKEN TRUNK - The following jobs failed but were already present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the ciflow/inductor, oncall: distributed, and release notes: distributed (fsdp) labels Aug 7, 2025
anshul-si added a commit that referenced this pull request Aug 7, 2025
@anshul-si anshul-si requested review from mori360 and weifengpy August 7, 2025 20:41
@anshul-si
Contributor Author

@pytorchbot merge

@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here.

pytorchmergebot pushed a commit that referenced this pull request Aug 20, 2025
…vice movements (#160147)

**Summary:** In order to ensure that replicate acts as intended (a specialized version of HSDP), we need to make sure that it can pass the same training tests that fully_shard can. To this end, I have added three test cases: one to test input device movement and the other two to test parameter registration during the forward and backward passes of a model.

**Test Cases**
1. pytest test/distributed/_composable/test_replicate_training.py -k test_root_move_forward_input_to_device
2. pytest test/distributed/_composable/test_replicate_training.py -k TestReplicateRegisteredParams

Pull Request resolved: #160147
Approved by: https://github.com/weifengpy
ghstack dependencies: #160135, #160136
@jithunnair-amd
Collaborator

@pytorchbot revert -m "Sorry, but looks like this broke ROCm distributed CI" -c nosignal

@pragupta can provide some more triage details

@jithunnair-amd jithunnair-amd added the ciflow/periodic label Aug 21, 2025
@pytorchmergebot
Collaborator

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

pytorchmergebot added a commit that referenced this pull request Aug 21, 2025
…input device movements (#160147)"

This reverts commit a3a82e3.

Reverted #160147 on behalf of https://github.com/jithunnair-amd due to "Sorry, but looks like this broke ROCm distributed CI" (comment on #160136)
pytorchmergebot added a commit that referenced this pull request Aug 21, 2025
…s 1 (#160136)"

This reverts commit 3d126e1.

Reverted #160136 on behalf of https://github.com/jithunnair-amd due to "Sorry, but looks like this broke ROCm distributed CI" (comment on #160136)
@pytorchmergebot
Collaborator

@anshul-si your PR has been successfully reverted.

@pytorchmergebot pytorchmergebot added the Reverted and ci-no-td labels Aug 21, 2025
anshul-si added a commit to anshul-si/pytorch that referenced this pull request Aug 21, 2025
can-gaa-hou pushed a commit to can-gaa-hou/pytorch that referenced this pull request Aug 22, 2025
…orch#160136)

**Summary:** In its current state, FSDP collectives issue CUDA synchronizations and communication ops regardless of the world size. However, now that replicate will use FSDP, there will be cases where the group size is 1 and these synchronizations and ops run needlessly. I have updated fsdp_collectives to skip the reduce_scatter in the foreach_reduce API when world_size == 1. I have edited a test that uses CommDebugMode to verify that the reduce_scatter is no longer issued, and I also updated an affected test that used 1-way FSDP by adjusting its CommDebugMode assertions. I have also added a test command.

**Test Cases**
1. pytest test/distributed/_composable/fsdp/test_fully_shard_training.py -k test_train_parity_single_worldsize1
2. pytest test/distributed/_composable/test_composability/test_2d_composability.py -k test_tp_with_fsdp_offloading

Pull Request resolved: pytorch#160136
Approved by: https://github.com/weifengpy
ghstack dependencies: pytorch#160135
anshul-si added a commit to anshul-si/pytorch that referenced this pull request Aug 26, 2025
anshul-si added a commit to anshul-si/pytorch that referenced this pull request Aug 26, 2025
anshul-si added a commit to anshul-si/pytorch that referenced this pull request Sep 2, 2025
anshul-si added a commit to anshul-si/pytorch that referenced this pull request Sep 2, 2025
@anshul-si anshul-si closed this Sep 2, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
markc-614 pushed a commit to markc-614/pytorch that referenced this pull request Sep 17, 2025
@github-actions github-actions bot deleted the gh/anshul-si/17/head branch October 3, 2025 02:08

Labels

ci-no-td, ciflow/inductor, ciflow/periodic, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (fsdp), Reverted
