[SymmMem] Tiled reduce #162243

Closed

kwen2501 wants to merge 13 commits into gh/kwen2501/231/base from gh/kwen2501/231/head

Conversation

@kwen2501
Collaborator

@kwen2501 kwen2501 commented Sep 5, 2025

Stack from ghstack (oldest at bottom):

Added op: tile_reduce(Tensor input, Tensor(a!) out, int root, str group_name)

For now supports only:

  • NVSHMEM backed symmetric tensor;
  • 2D tensor and tile;
  • torch.float.

Testing on right-bottom quadrant:

```
rank 0:
tensor([[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.]], device='cuda:0')
PASSED
```
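The semantics of the op can be sketched on CPU in plain Python (a hypothetical simulation, not the NVSHMEM path; `tile_reduce_sim` and its signature are illustrative only): each rank contributes a same-shaped 2D tile, and the element-wise sum lands in the root's output tile.

```python
# Hypothetical CPU simulation of tile_reduce semantics (illustration only;
# the real op runs on NVSHMEM-backed symmetric memory on GPU).
def tile_reduce_sim(tiles, root):
    """Element-wise sum of one same-shaped 2D tile per rank.

    In the real op, only the root rank receives the result.
    """
    rows, cols = len(tiles[0]), len(tiles[0][0])
    out = [[0.0] * cols for _ in range(rows)]
    for tile in tiles:
        for r in range(rows):
            for c in range(cols):
                out[r][c] += tile[r][c]
    return out

# 4 ranks each contribute a 4x4 tile of 0.25 -> the root sees all 1s,
# matching the 1s in the right-bottom quadrant above.
world_size = 4
tiles = [[[0.25] * 4 for _ in range(4)] for _ in range(world_size)]
reduced = tile_reduce_sim(tiles, root=0)
```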

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @ezyang

[ghstack-poisoned]
@pytorch-bot

pytorch-bot Bot commented Sep 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162243

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

✅ No Failures

As of commit a622842 with merge base a707042 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Sep 5, 2025
ghstack-source-id: 541b03e
Pull-Request-resolved: #162243
@pytorch-bot pytorch-bot Bot added the ciflow/h100-symm-mem, oncall: distributed, and release notes: distributed (c10d) labels Sep 5, 2025
kwen2501 added a commit that referenced this pull request Sep 5, 2025
ghstack-source-id: 016c9e8
Pull-Request-resolved: #162243
@kwen2501 kwen2501 requested review from fegin and ngimel September 5, 2025 03:36
Comment thread torch/csrc/distributed/c10d/symm_mem/nvshmem_extension.cu Outdated
@ngimel
Collaborator

ngimel commented Sep 5, 2025

Can we have some benchmarks for inter and intra node? E.g. compared to copy + nccl?

@kwen2501
Collaborator Author

kwen2501 commented Sep 5, 2025

@ngimel I need to add some util code to create multiple teams to boost the bandwidth. Stack PR coming :)

Collaborator

@ngimel ngimel left a comment


Can you give a tl;dr how you expect multiple teams to improve perf?

// src_tensor and dst_tensor are already the tiles to operate on, thus we set
// the start_coord to 0
auto start_coord = nvshmemx::make_shape(0, 0);
nvshmemx::tile_sum_reduce<decltype(src_tensor), decltype(dst_tensor), Shape2D, nvshmemx::tile_coll_algo_t::NVLS_ONE_SHOT_PULL_NBI>(
Collaborator


One-shot algorithms are ok only for small sizes, for larger sizes they result in 4x more network traffic for 8 world size
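The "4x" figure can be recovered with a back-of-envelope cost model (assumed model, not measured numbers: N bytes per rank, one-shot pulls the full buffer from every peer, bandwidth-optimal two-shot does reduce-scatter plus all-gather):

```python
# Assumed per-rank traffic model for an N-byte buffer across W ranks
# (illustration of the review comment, not NVSHMEM internals).
def one_shot_bytes(n, w):
    # one-shot: each rank pulls the full buffer from every peer
    return n * (w - 1)

def two_shot_bytes(n, w):
    # reduce-scatter + all-gather, each phase moving n * (w - 1) / w
    return 2 * n * (w - 1) / w

# ratio = w / 2, i.e. 4x at world size 8
ratio = one_shot_bytes(1.0, 8) / two_shot_bytes(1.0, 8)
```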

Collaborator


The docs say "The users are expected to use tile_collective_wait routine to ensure completion of the non-blocking collectives." I don't see it here.

Collaborator Author

@kwen2501 kwen2501 Sep 5, 2025


That's the limitation of NVSHMEM today, only three algorithms are available:

tile_coll_algo_t::NVLS_ONE_SHOT_PUSH_NBI
tile_coll_algo_t::NVLS_ONE_SHOT_PULL_NBI
tile_coll_algo_t::NVLS_TWO_SHOT_PUSH_NBI

And I don't think TWO_SHOT would work for reduce.

One-shot reduce (not all-reduce) will not create extra traffic, but it would indeed create a hot-spot at the root GPU.

So this relates to whether the collective has access to intermediate buffers or not:

  • if not, reduce can only do one-shot, thus a hot-spot at the root and never bandwidth optimal;
  • if yes, then algorithms like ring are possible, thus bandwidth optimal.
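The second bullet can be sketched in a toy simulation (hypothetical `ring_reduce` helper, not part of this PR or of NVSHMEM): each chunk circles the ring once, accumulating a partial sum in an intermediate buffer at every hop, and finishes its lap at the root, so per-rank traffic stays near the buffer size instead of growing with world size.

```python
# Hypothetical ring reduce-to-root: needs intermediate buffers for the
# partial sums, which is exactly what the one-shot tile API lacks.
def ring_reduce(chunks, root):
    """chunks[r][c] is rank r's value for chunk c; result lands at root."""
    w = len(chunks)
    n_chunks = len(chunks[0])
    out = []
    for c in range(n_chunks):
        acc = 0.0
        r = (root + 1) % w   # start just after the root ...
        for _ in range(w):   # ... so the chunk ends its lap at the root
            acc += chunks[r][c]
            r = (r + 1) % w
        out.append(acc)
    return out
```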

Collaborator


If (for some not-very-small sizes) TWO_SHOT allreduce is faster than one-shot reduce, shouldn't we be using it? At a higher level, what are you trying to achieve? Is it inter-node or intra-node?

Collaborator Author


TWO_SHOT allreduce will modify non-root ranks' buffer, so not so much a 1:1 in terms of semantics.
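The semantic mismatch can be shown with toy helpers (hypothetical, illustration only): a reduce overwrites only the root's buffer, while a two-shot allreduce overwrites every rank's buffer.

```python
# Toy contrast between reduce and allreduce semantics (illustration only).
def reduce_to_root(bufs, root):
    # only the root's buffer is replaced by the element-wise sum
    total = [sum(vals) for vals in zip(*bufs)]
    out = [list(b) for b in bufs]
    out[root] = total
    return out

def allreduce(bufs):
    # every rank's buffer is replaced by the element-wise sum
    total = [sum(vals) for vals in zip(*bufs)]
    return [list(total) for _ in bufs]

bufs = [[1.0, 2.0], [3.0, 4.0]]
after_reduce = reduce_to_root(bufs, root=0)   # rank 1 keeps [3.0, 4.0]
after_allreduce = allreduce(bufs)             # both ranks hold [4.0, 6.0]
```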

@kwen2501
Collaborator Author

kwen2501 commented Sep 5, 2025

Can you give a tl;dr how you expect multiple teams to improve perf?

@ngimel today the PR launches only 1 CUDA block to work on the tile. If we want to scale to multiple blocks, e.g. 1 block per k rows, we'd need 1 team per block, because that's the semantics of the tile_sum_reduce API.
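The scaling idea can be sketched as follows (hypothetical helper; the 1-team-per-block mapping is from the comment above, the even row split is illustrative): partition the tile's rows into contiguous sub-tiles, one per CUDA block, with each block driving its own team.

```python
# Hypothetical row partitioning for a multi-block tile reduce:
# block b (with its own team) handles rows [start, end) of the tile.
def rows_per_block(n_rows, n_blocks):
    base, rem = divmod(n_rows, n_blocks)
    bounds = []
    start = 0
    for b in range(n_blocks):
        end = start + base + (1 if b < rem else 0)  # spread the remainder
        bounds.append((start, end))
        start = end
    return bounds

# e.g. an 8-row tile across 3 blocks/teams
parts = rows_per_block(8, 3)
```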

kwen2501 added a commit that referenced this pull request Sep 6, 2025
ghstack-source-id: acf6e0f
Pull-Request-resolved: #162243

use tile in arg name

teams_dev

wait

boundry

nblocks
kwen2501 added a commit that referenced this pull request Sep 7, 2025
ghstack-source-id: d4fadf4
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Sep 7, 2025
ghstack-source-id: c40d27c
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Sep 7, 2025
ghstack-source-id: 8e1b03c
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Sep 8, 2025
ghstack-source-id: cff954c
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Sep 9, 2025
ghstack-source-id: 1d96422
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Oct 2, 2025
ghstack-source-id: 5b415a9
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Oct 2, 2025
ghstack-source-id: 18cf071
Pull-Request-resolved: #162243

kwen2501 added a commit that referenced this pull request Oct 3, 2025
ghstack-source-id: 1a81ab2
Pull-Request-resolved: #162243

use tile in arg name

teams_dev

wait

boundry

nblocks

reduce op

Each block has a sub-tile

Empty start_coord and boundary

Add benchmark
* receiving the reduced tensor. */
TORCH_CHECK(reduce_op == "sum", "tile_reduce: only sum is supported for now");
TORCH_CHECK(in_tile.dim() == 2 && out_tile.dim() == 2, "Only 2D tensors are supported");
TORCH_CHECK_EQ(in_tile.dtype(), out_tile.dtype());
Collaborator


add user-friendly error messages here

Collaborator Author

@kwen2501 kwen2501 Oct 7, 2025


I had a look at the macro expansion of TORCH_CHECK_EQ; it looks friendly enough?

#define TORCH_CHECK_OP(val1, val2, op)                                        \
  FATAL_IF(((val1)op(val2))) << "Check failed: " #val1 " " #op " " #val2 " (" \
                             << (val1) << " vs. " << (val2) << ") "

@kwen2501
Collaborator Author

kwen2501 commented Oct 7, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot Bot added the ciflow/trunk Trigger trunk jobs on your pull request label Oct 7, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

@pytorchmergebot
Collaborator

Starting merge as part of PR stack under #164757

pytorchmergebot pushed a commit that referenced this pull request Oct 8, 2025
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):

Perform multiple tile reductions concurrently, with each tile reduced to a separate root.

- The number of concurrent reductions can be smaller than world size, i.e. roots can be a subset of all ranks. But all ranks are still required to call into this API.

- Currently supports NVLink SHARP scope only.

Pull Request resolved: #164757
Approved by: https://github.com/weifengpy, https://github.com/fegin
ghstack dependencies: #162243
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
Pull Request resolved: pytorch#162243
Approved by: https://github.com/ngimel
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025
@github-actions github-actions Bot deleted the gh/kwen2501/231/head branch November 7, 2025 02:15

Labels

ciflow/h100-symm-mem, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (c10d)
