
Use SymmMem for reduce-scatter in FSDP #177111

Closed
kwen2501 wants to merge 5 commits into gh/kwen2501/325/base from gh/kwen2501/325/head

Conversation

kwen2501 (Collaborator) commented Mar 11, 2026

Stack from ghstack (oldest at bottom):

Summary

This change enables symmetric memory optimizations for reduce-scatter collectives, matching the behavior already available for all-gather.

Changes

1. **Added `SymmMemReduceScatter` class**: similar to `SymmMemAllGather`, this class:
   - Allocates tensors from the symmetric memory pool (via `SymmMemAllocMixin`)
   - Rendezvouses both input and output tensors before calling `reduce_scatter_tensor`, which allows NCCL to detect symmetric memory tensors and use the optimized symmetric kernel

2. **Updated `set_symm_mem()`**: now sets both all-gather and reduce-scatter to use symmetric memory implementations when `set_symm_mem_for_comm()` is called.

3. **Enhanced testing**: added a parametrized test to verify that the `ReduceOp.SUM` reduction mode works with symmetric memory.
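The allocate-then-rendezvous flow in item 1 can be sketched in plain Python. This is a toy model only: the pool, rendezvous check, and kernel-selection return value below are illustrative stand-ins, not the actual `torch.distributed` APIs (the real code uses `SymmMemAllocMixin`, a multi-rank rendezvous, and NCCL's `reduce_scatter_tensor`).

```python
class ToySymmMemPool:
    """Toy stand-in for a symmetric-memory allocator: tracks which
    buffers were handed out from the 'symmetric' pool."""

    def __init__(self):
        self._registered = set()

    def alloc(self, nbytes):
        buf = bytearray(nbytes)
        self._registered.add(id(buf))
        return buf

    def is_symmetric(self, buf):
        return id(buf) in self._registered


class ToySymmMemReduceScatter:
    """Mirrors the PR's flow: rendezvous BOTH the input and the output
    buffer before issuing the collective, so the backend can detect
    symmetric buffers and pick the optimized kernel, falling back to
    the regular path otherwise."""

    def __init__(self, pool):
        self.pool = pool

    def _rendezvous(self, buf):
        # The real implementation exchanges memory handles across ranks
        # here; the toy version only checks pool membership.
        return self.pool.is_symmetric(buf)

    def reduce_scatter(self, inp, out):
        if self._rendezvous(inp) and self._rendezvous(out):
            return "symmetric-kernel"  # e.g. ReduceScatter_LDMC in the log below
        return "regular-kernel"
```

The point of the sketch is the ordering: both tensors must come from the symmetric pool and be rendezvoused before the collective call, or the backend silently takes the regular path.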

Testing

NCCL INFO ReduceScatter [Symmetric]: 100681728 Bytes -> Kernel ReduceScatter_LDMC nchannels 2 nthreads 

Notes

  - Today, the symmetric kernel is enabled for `ReduceOp.SUM` only (when `set_force_sum_reduction_for_comms(True)` is used). For other `ReduceOp` values, NCCL falls back to the regular kernel.
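The fallback behavior in this note can be summarized as a small decision function. This is a hypothetical sketch of the dispatch, which in reality happens inside NCCL; the symmetric kernel name is taken from the NCCL log above, while the fallback name here is made up for illustration.

```python
def pick_reduce_scatter_kernel(reduce_op: str, force_sum_for_comms: bool) -> str:
    """Mirror the note: the symmetric kernel applies only to ReduceOp.SUM,
    and only when set_force_sum_reduction_for_comms(True) is in effect;
    every other ReduceOp takes the regular kernel."""
    if reduce_op == "SUM" and force_sum_for_comms:
        return "ReduceScatter_LDMC"  # symmetric kernel seen in the NCCL log
    return "ReduceScatter_regular"   # hypothetical name for the fallback path
```

For example, `pick_reduce_scatter_kernel("AVG", True)` takes the fallback branch even with force-sum enabled, matching the SUM-only restriction described above.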

[ghstack-poisoned]
pytorch-bot commented Mar 11, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/177111


❌ 1 New Failure, 4 Pending, 4 Unrelated Failures

As of commit e833582 with merge base 1ef51a6:

NEW FAILURE - The following job has failed:

FLAKY - The following jobs failed but were likely due to flakiness present on trunk:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot added the ciflow/inductor, ciflow/torchtitan (Run TorchTitan integration tests), and release notes: distributed (fsdp) (release notes category) labels Mar 11, 2026
kwen2501 added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: 84c806c
Pull-Request: #177111
kwen2501 requested a review from weifengpy March 11, 2026 03:51
kwen2501 added the module: symm_mem (Issues and PRs of Symmetric Memory) label Mar 11, 2026
[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 11, 2026
ghstack-source-id: cf9be9a
Pull-Request: #177111
[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 12, 2026
ghstack-source-id: e96a995
Pull-Request: #177111
kwen2501 added the ciflow/trunk (Trigger trunk jobs on your pull request) label Mar 12, 2026
kwen2501 (Collaborator, Author) commented:

@pytorchbot merge -f "Failures are from Inductor tests and unrelated"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


pytorchmergebot added a commit that referenced this pull request Mar 13, 2026
This reverts commit 332e4c7.

Reverted #177111 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](#176613 (comment)))
pytorchmergebot (Collaborator) commented:

@kwen2501 your PR has been reverted as part of the stack under #176613.

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 16, 2026
ghstack-source-id: 29270ca
Pull-Request: #177111
kwen2501 (Collaborator, Author) commented:

@pytorchbot merge -i

pytorchmergebot added a commit that referenced this pull request Mar 17, 2026
This reverts commit 0ae127b.

Reverted #177111 on behalf of https://github.com/yangw-dev due to internal test failed due to AssertionError: Unexpected methods found in class: {'set_symm_mem_for_comm'}, Missing methods: set(), please ask internal folks to help add set_symm_mem_for_comm in expected method D96767236 ([comment](#176613 (comment)))
pytorchmergebot (Collaborator) commented:

@kwen2501 your PR has been reverted as part of the stack under #176613.

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Mar 19, 2026
ghstack-source-id: c2f7600
Pull-Request: #177111
kwen2501 (Collaborator, Author) commented:

@pytorchbot merge -f "Failure is unrelated"

pytorchmergebot (Collaborator) commented:

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes). Please use -f as a last resort and instead consider -i/--ignore-current to continue the merge ignoring current failures. This will allow currently pending tests to finish and report signal before the merge.

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team


EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
Pull Request resolved: pytorch#177111
Approved by: https://github.com/Skylion007, https://github.com/weifengpy
ghstack dependencies: pytorch#176613
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
EmanueleCoradin pushed a commit to EmanueleCoradin/pytorch that referenced this pull request Mar 30, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
AaronWang04 pushed a commit to AaronWang04/pytorch that referenced this pull request Mar 31, 2026
nklshy-aws pushed a commit to nklshy-aws/pytorch that referenced this pull request Apr 7, 2026

Labels

- ci-no-td (Do not run TD on this PR)
- ciflow/inductor
- ciflow/torchtitan (Run TorchTitan integration tests)
- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- module: symm_mem (Issues and PRs of Symmetric Memory)
- open source
- release notes: distributed (fsdp) (release notes category)
- Reverted


5 participants