
[SymmMem] Add team pool to hold duplicated teams for the same rank group #162320

Closed

kwen2501 wants to merge 8 commits into gh/kwen2501/232/base from gh/kwen2501/232/head

Conversation

kwen2501 (Collaborator) commented Sep 6, 2025

Stack from ghstack (oldest at bottom):

When multiple threadblocks call device-side collectives concurrently, NVSHMEM requires that each call be made on a separate team struct; see [Collective operations scopes and active sets](https://docs.nvidia.com/nvshmem/api/gen/api/collectives.html#collective-operations-scopes-and-active-sets).

This PR adds a utility `get_n_teams` for creating duplicated NVSHMEM teams over the same rank group, i.e. a team pool, so that they can be used on the device side.
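For context, a minimal sketch of how such a pool can be built with the public NVSHMEM API; the helper `make_team_pool` and the kernel are illustrative, not the actual `get_n_teams` implementation, and `dst`/`src` are assumed to be symmetric (nvshmem_malloc'ed) buffers:

```cpp
#include <nvshmem.h>
#include <vector>

// Illustrative only: split the same (start, stride, size) rank group out of
// NVSHMEM_TEAM_WORLD n times. Team creation is collective, so every PE in the
// parent team must make the same calls. Each call yields a distinct team
// handle over the identical rank group.
std::vector<nvshmem_team_t> make_team_pool(int start, int stride, int size, int n) {
  std::vector<nvshmem_team_t> pool(n, NVSHMEM_TEAM_INVALID);
  for (int i = 0; i < n; ++i) {
    nvshmem_team_split_strided(
        NVSHMEM_TEAM_WORLD, start, stride, size,
        /*config=*/nullptr, /*config_mask=*/0, &pool[i]);
  }
  return pool;
}

// Hypothetical device kernel: each threadblock takes its own team from the
// pool (resident in device memory) so that concurrent device-side collectives
// never share a team struct.
__global__ void per_block_reduce(nvshmem_team_t* team_pool, int* dst, int* src) {
  nvshmem_team_t team = team_pool[blockIdx.x];
  if (threadIdx.x == 0) {
    // One collective call per team at a time, as NVSHMEM requires.
    nvshmem_int_sum_reduce(team, dst + blockIdx.x, src + blockIdx.x, 1);
  }
}
```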

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim

[ghstack-poisoned]
pytorch-bot (Bot) commented Sep 6, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162320

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 1 Pending

As of commit 2f3b55a with merge base 7a83cf4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pytorch-bot (Bot) added the ciflow/h100-symm-mem, oncall: distributed, and release notes: distributed (c10d) labels on Sep 6, 2025
kwen2501 added a commit that referenced this pull request Sep 6, 2025
[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Sep 6, 2025
Comment thread on torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp (outdated)
[ghstack-poisoned]
if (it == team_pool_devptrs_.end()) {
  // If not, allocate a new pool in device memory
  C10_CUDA_CHECK(cudaMalloc((void**)&team_pool_dev, pool_bytes));
  team_pool_devptrs_[group_name] = team_pool_dev;
Collaborator:
You should think about structuring the code in such a way that when a group is destroyed, its team manager entries are also freed; right now you have a cudaMalloc (and possibly other resources) that's leaking.

kwen2501 (Collaborator, Author):
Added in the destructor now.

Skylion007 (Collaborator) commented Sep 7, 2025:
Wouldn't it be cleaner to use a std::unique_ptr to accomplish this (with a custom deleter passed to the unique_ptr)?

kwen2501 (Collaborator, Author):
Normally yes (thanks for the suggestion)! But here I have specific comments + warning message, so I kind of prefer writing it out explicitly.

Collaborator:
@kwen2501 True, but comments have way more footguns than RAII.

Collaborator:
Agree, using standard RAII techniques is preferable.
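For reference, a minimal sketch of the RAII pattern being suggested (hypothetical; not the code this PR ultimately landed): a `std::unique_ptr` whose custom deleter frees the device allocation.

```cpp
#include <cuda_runtime.h>
#include <memory>

// Hypothetical sketch: the deleter runs automatically when the map entry
// (or the owning manager) is destroyed, so nothing leaks on normal paths.
struct CudaFreeDeleter {
  void operator()(void* p) const noexcept {
    // Best-effort free: at static-destruction time the CUDA context may
    // already be gone, so the returned error code is deliberately ignored.
    (void)cudaFree(p);
  }
};

using DevicePtr = std::unique_ptr<void, CudaFreeDeleter>;

DevicePtr alloc_device_buffer(size_t bytes) {
  void* raw = nullptr;
  if (cudaMalloc(&raw, bytes) != cudaSuccess) {
    return nullptr;
  }
  return DevicePtr(raw);
}
```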

[ghstack-poisoned]
  }
} catch (...) {
  // Ignore the error
  std::cerr << "Failed to free the team pool in device memory, skipping\n";
Collaborator:
Why not use the logging utilities we already have in torch?

kwen2501 (Collaborator, Author):
I was worried that those logging utilities might have been destructed at this point. Can you shed some light on this?

Comment on lines +91 to +109
~TeamManager() {
  // Free the team pools in device memory
  // Note that we do it in a best-effort manner because the team pool is
  // managed by a static TeamManager and the destruction order of static
  // objects is undetermined. If the destructor is called after the CUDA
  // context is destroyed, cudaFree would fail.
  try {
    // cudaFree generally implies a device synchronization, meaning it will
    // block until all preceding CUDA operations on the device have completed
    // before freeing the memory. Thus we don't need to worry about freeing
    // the memory before CUDA kernels complete.
    for (auto& [_, team_pool_dev] : team_pool_devptrs_) {
      c10::cuda::CUDACachingAllocator::raw_delete(team_pool_dev);
    }
  } catch (...) {
    // Ignore the error
    std::cerr << "Failed to free the team pool in device memory, skipping\n";
  }
}
kwen2501 (Collaborator, Author) commented Sep 7, 2025:
@ngimel I am doing the free in a best-effort manner due to the undetermined destruction order of static objects and the CUDA context. In my test runs, I never see the cerr message being printed, though, so I guess we are lucky enough.

Groups are a bit static today too, so I guess it would make little difference if we were to implement a callback from the group destructor.
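For comparison, a common way to sidestep the static-destruction-order problem entirely (hypothetical here; not what this PR does) is to heap-allocate the singleton and never destroy it:

```cpp
// Hypothetical alternative: intentionally leak the singleton. Its destructor
// never runs, so no CUDA calls are issued during static teardown; the OS
// reclaims the device memory at process exit.
struct TeamManager {  // stand-in for the real manager in this PR
  ~TeamManager() { /* would free the device pools here */ }
};

TeamManager& get_team_manager() {
  static TeamManager* manager = new TeamManager();  // leaked on purpose
  return *manager;
}
```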

Comment thread torch/csrc/distributed/c10d/symm_mem/nvshmem_team_manager.hpp Outdated
[ghstack-poisoned]
ngimel (Collaborator) left a comment:

I take it there will be tests for this somewhere down the stack?

kwen2501 (Collaborator, Author) commented Sep 9, 2025:

All the existing tests will exercise the host-side `get_team` API.
The `tile_reduce` feature on top will exercise the device-side `get_n_teams` API.

kwen2501 (Collaborator, Author) commented Sep 9, 2025:

@pytorchbot merge

pytorch-bot (Bot) added the ciflow/trunk label on Sep 9, 2025
pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

pytorchmergebot (Collaborator):

Starting merge as part of PR stack under #162394

pytorchmergebot pushed a commit that referenced this pull request Sep 9, 2025
NVSHMEM put/get APIs take a global PE rather than its team-local counterpart, so we need to do a translation within the kernel.

Also added a sub-group test for dispatch and combine, mimicking the Expert Parallel cases.

Pull Request resolved: #162394
Approved by: https://github.com/ngimel, https://github.com/fegin
ghstack dependencies: #162320
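As an aside, the translation mentioned in that commit maps a team-relative PE to a global PE. A host-side sketch of the same mapping using NVSHMEM's team routines (illustrative; the commit performs the translation inside the kernel):

```cpp
#include <nvshmem.h>
#include <vector>

// Illustrative: precompute a local->global PE table for `team`; the table can
// then be copied to device memory and indexed by the team-local peer in kernels.
std::vector<int> build_pe_table(nvshmem_team_t team) {
  int n = nvshmem_team_n_pes(team);
  std::vector<int> table(n);
  for (int pe = 0; pe < n; ++pe) {
    // Map team-relative PE `pe` to its rank in NVSHMEM_TEAM_WORLD.
    table[pe] = nvshmem_team_translate_pe(team, pe, NVSHMEM_TEAM_WORLD);
  }
  return table;
}
```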
github-actions (Bot) deleted the gh/kwen2501/232/head branch on October 10, 2025 02:09

Labels

ciflow/h100-symm-mem, ciflow/trunk, Merged, oncall: distributed, release notes: distributed (c10d)