[SymmMem] Add team pool to hold duplicated teams for the same rank group #162320

kwen2501 wants to merge 8 commits into gh/kwen2501/232/base
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/162320
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 Pending as of commit 2f3b55a with merge base 7a83cf4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
  if (it == team_pool_devptrs_.end()) {
    // If not, allocate a new pool in device memory
    C10_CUDA_CHECK(cudaMalloc((void**)&team_pool_dev, pool_bytes));
    team_pool_devptrs_[group_name] = team_pool_dev;
You should think about structuring the code in such a way that when a group is destroyed, its team manager entries are also freed; right now you have a cudaMalloc (and possibly other resources) that is leaking.
Added in the destructor now.
Wouldn't it be cleaner to use std::unique_ptrs to accomplish this (with a custom destructor passed to the unique_ptr constructor)?
Normally yes (thanks for the suggestion)! But here I have specific comments + warning message, so I kind of prefer writing it out explicitly.
@kwen2501 True, but comments have way more footguns than RAII.
Agree, using standard RAII techniques is preferable
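For reference, a minimal sketch of the RAII approach being suggested, assuming a `std::unique_ptr` with a custom deleter owning the cudaMalloc'd pool. The names (`CudaFreeDeleter`, `DevicePtr`, `make_team_pool`) are illustrative and not from the PR:

```cpp
// Minimal sketch of the suggested RAII approach; names are illustrative.
#include <cuda_runtime.h>
#include <cstddef>
#include <iostream>
#include <memory>
#include <stdexcept>

struct CudaFreeDeleter {
  void operator()(void* p) const noexcept {
    // Best-effort free; the CUDA context may already be gone during static teardown.
    if (p && cudaFree(p) != cudaSuccess) {
      std::cerr << "Failed to free the team pool in device memory, skipping\n";
    }
  }
};

using DevicePtr = std::unique_ptr<void, CudaFreeDeleter>;

DevicePtr make_team_pool(size_t pool_bytes) {
  void* raw = nullptr;
  if (cudaMalloc(&raw, pool_bytes) != cudaSuccess) {
    throw std::runtime_error("cudaMalloc failed for team pool");
  }
  return DevicePtr(raw);  // freed automatically when the owning entry is destroyed
}
```

Storing `DevicePtr` values in the `team_pool_devptrs_` map would then free each pool automatically when its entry (or the map itself) is destroyed, while keeping the warning message inside the deleter.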
    }
  } catch (...) {
    // Ignore the error
    std::cerr << "Failed to free the team pool in device memory, skipping\n";
Why not use our logging utility we already have in TORCH?
I was worried that those logging utilities might have been destructed at this point. Can you shed some light on this?
  ~TeamManager() {
    // Free the team pools in device memory
    // Note that we do it in a best effort manner because the team pool is
    // managed by a static TeamManager and the destruction order of static
    // objects is undetermined. If the destructor is called after the CUDA
    // context is destroyed, cudaFree would fail.
    try {
      // cudaFree generally implies a device synchronization, meaning it will
      // block until all preceding CUDA operations on the device have completed
      // before freeing the memory. Thus we don't need to worry about freeing
      // the memory before CUDA kernels complete.
      for (auto& [_, team_pool_dev] : team_pool_devptrs_) {
        c10::cuda::CUDACachingAllocator::raw_delete(team_pool_dev);
      }
    } catch (...) {
      // Ignore the error
      std::cerr << "Failed to free the team pool in device memory, skipping\n";
    }
  }
@ngimel I am doing the free in a best-effort manner due to the undetermined destruction order of static objects and the CUDA context. In my test runs, I never see the cerr message being printed though, so I guess we are lucky enough.
Groups are a bit static today too, so I guess it would make little difference if we were to implement a callback from the group destructor.
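As a sketch of that alternative (freeing a group's pool entry when the group goes away rather than in the static destructor), something like the following could work; the `on_group_destroyed` hook and the free-standing map are hypothetical stand-ins, while `raw_delete` and `team_pool_devptrs_` come from the PR:

```cpp
#include <c10/cuda/CUDACachingAllocator.h>
#include <string>
#include <unordered_map>

// Hedged sketch of the "free on group destruction" alternative; the hook name
// and this local map are illustrative stand-ins for the PR's TeamManager state.
std::unordered_map<std::string, void*> team_pool_devptrs_;

void on_group_destroyed(const std::string& group_name) {
  auto it = team_pool_devptrs_.find(group_name);
  if (it == team_pool_devptrs_.end()) {
    return;
  }
  // While groups still exist, the caching allocator and CUDA context are alive,
  // so this free avoids the static-destruction-order hazard entirely.
  c10::cuda::CUDACachingAllocator::raw_delete(it->second);
  team_pool_devptrs_.erase(it);
}
```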
ngimel left a comment:
I take it, there will be tests for it some time down the stack?
All the existing tests will exercise the host-side
@pytorchbot merge

Merge started: Your change will be merged once all checks pass (ETA 0-4 Hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.

Starting merge as part of PR stack under #162394
NVSHMEM put/get APIs take a global PE instead of its local counterpart, so we need to do a translation within the kernel. Also added a sub-group test for dispatch and combine mimicking the Expert Parallel cases. Pull Request resolved: #162394 Approved by: https://github.com/ngimel, https://github.com/fegin ghstack dependencies: #162320
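For context on that translation, here is a hedged CUDA sketch (not the PR's actual kernel), assuming the device-callable `nvshmem_team_translate_pe`; the kernel name and parameters are illustrative:

```cpp
#include <nvshmem.h>
#include <nvshmemx.h>

// Hedged sketch: translate a team-relative peer index to a global PE before a
// device-side put. Kernel name, symm_buf, and peer_in_team are illustrative.
__global__ void put_to_team_peer(nvshmem_team_t team, float* symm_buf,
                                 size_t nbytes, int peer_in_team) {
  // Point-to-point put/get APIs address peers by their global PE number
  // (rank within NVSHMEM_TEAM_WORLD), so translate the team-local rank first.
  int global_pe = nvshmem_team_translate_pe(team, peer_in_team, NVSHMEM_TEAM_WORLD);
  if (blockIdx.x == 0 && threadIdx.x == 0) {
    // Write our local symmetric buffer into the peer's copy of the same buffer.
    nvshmem_putmem(symm_buf, symm_buf, nbytes, global_pe);
  }
}
```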
…oup (pytorch#162320) When multiple threadblocks call device-side collectives concurrently, NVSHMEM requires each call to be made on a separate team struct, see [Collective operations scopes and active sets](https://docs.nvidia.com/nvshmem/api/gen/api/collectives.html?highlight=nvshmem_barrier_all#collective-operations-scopes-and-active-sets). This PR adds a util `get_n_teams` for creating duplicated nvshmem teams for the same rank group, i.e. a team pool, so that we can use them on the device side. Pull Request resolved: pytorch#162320 Approved by: https://github.com/ngimel
Stack from ghstack (oldest at bottom):
When multiple threadblocks call device-side collectives concurrently, NVSHMEM requires each call to be made on a separate team struct, see Collective operations scopes and active sets.
This PR adds a util `get_n_teams` for creating duplicated nvshmem teams for the same rank group, i.e. a team pool, so that we can use them on the device side.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @ezyang @msaroufim
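For illustration, a minimal host-side sketch of what creating such a duplicated-team pool could look like with stock NVSHMEM APIs (not the PR's actual `get_n_teams` implementation); the helper name and the start/stride/size parameters are assumptions:

```cpp
#include <nvshmem.h>
#include <nvshmemx.h>
#include <vector>

// Hedged sketch: build a pool of duplicated NVSHMEM teams over the same ranks.
// Names and parameters are illustrative, not the PR's implementation.
std::vector<nvshmem_team_t> make_team_pool(int start, int stride, int size,
                                           int n_teams) {
  std::vector<nvshmem_team_t> pool(n_teams, NVSHMEM_TEAM_INVALID);
  for (int i = 0; i < n_teams; ++i) {
    // Each split is a collective over NVSHMEM_TEAM_WORLD and yields an
    // independent team handle covering the same group of PEs, so concurrent
    // device-side collectives from different threadblocks can each use their
    // own team.
    nvshmem_team_split_strided(NVSHMEM_TEAM_WORLD, start, stride, size,
                               /*config=*/nullptr, /*config_mask=*/0, &pool[i]);
  }
  return pool;
}
```

The team handles from such a pool would then be copied to device memory (the `team_pool_devptrs_` allocation discussed above) so each threadblock can pick its own team for a device-side collective.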