[SymmMem] Deprecate enable_symm_mem_for_group #172163
kwen2501 wants to merge 3 commits into gh/kwen2501/305/base
Conversation

✅ No failures as of commit 945de0e with merge base 8cfe6f1 (per Dr. CI). See artifacts and rendered test results at hud.pytorch.org/pr/172163.
```cpp
// For logging only
static int exchanged_n_times = 0;
auto global_rank = get_group_info("0").rank;
auto store = group_info.store;
// Exchange rank to global rank mapping for this group.
// If it is already available, skip the exchange.
if (group_info.rank_to_global_rank.empty()) {
  group_info.rank_to_global_rank =
      storeExchange.all_gather(store, rank_, world_size_, global_rank);
  exchanged_n_times++;
  if (rank_ == 0) {
    LOG(INFO) << "[rank " << rank_ << ']'
              << " rank_to_global_rank: " << group_info.rank_to_global_rank
              << ", group_name: " << group_name_
              << ", exchanged_n_times: " << exchanged_n_times;
  }
}

TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty());
rank_to_global_rank_ = group_info.rank_to_global_rank;
```
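The store-based exchange above can be sketched in Python. Here a plain dict stands in for the process group's key-value store, and `all_gather` is a hypothetical helper mirroring the `storeExchange.all_gather` call, not the actual PyTorch API:

```python
# Minimal sketch of a store-based all-gather, assuming a dict-like
# key-value store shared by all ranks (stand-in for the real store).
def all_gather(store, rank, world_size, value):
    # Each rank publishes its value under a rank-specific key...
    store[f"gather/{rank}"] = value
    # ...then reads every rank's entry to build the full mapping.
    # (A real implementation would block until all keys exist.)
    return [store[f"gather/{r}"] for r in range(world_size)]

# Simulate 4 ranks in one process: rank r's "global rank" is r + 8,
# as if this group were a slice of a larger world.
store = {}
for r in range(1, 4):
    store[f"gather/{r}"] = r + 8
rank_to_global_rank = all_gather(store, rank=0, world_size=4, value=8)
# rank_to_global_rank now maps group-local rank -> global rank.
```

Because the result is cached on the group info, the exchange runs once per group rather than once per handle, which is what the `exchanged_n_times` counter is logging.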
NCCL has no need for this mapping.
```cpp
auto group_info = get_group_info("0");
auto store = group_info.store;
```
Removing dead code
```cpp
rank_to_global_rank_dev_ = reinterpret_cast<int*>(
    c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_));
AT_CUDA_CHECK(cudaMemcpy(
    rank_to_global_rank_dev_,
    rank_to_global_rank_.data(),
    sizeof(int) * world_size_,
    cudaMemcpyHostToDevice));
```
This device-side mapping doesn't need to be repeatedly allocated per-handle. I made it per-group above.
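The per-group (rather than per-handle) caching described here can be sketched in Python with a memoized factory; `get_rank_to_global_rank_buffer` and the group-name key are illustrative stand-ins, not the actual PyTorch internals:

```python
from functools import lru_cache

@lru_cache(maxsize=None)
def get_rank_to_global_rank_buffer(group_name: str) -> list:
    # Stand-in for the expensive device-side allocation + copy.
    # lru_cache keys on group_name, so the buffer is created once
    # per group, no matter how many handles request it.
    return [0, 1, 2, 3]

a = get_rank_to_global_rank_buffer("0")
b = get_rank_to_global_rank_buffer("0")
# a and b are the same cached object: one allocation per group.
```

The same idea in the C++ code is moving the buffer from the per-handle object to the per-group info struct.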
```cpp
// Note this field is not automatically populated by set_group_info(). If a
// SymmetricMemory implementation needs to use it, it must be populated by a
// call to exchange_global_ranks() first.
std::vector<int> rank_to_global_rank;
```
This is only needed by NVSHMEM. I made it an internal impl.
Skylion007 left a comment:
Definitely better than before
```cpp
};

// A map from group name to rank-to-global rank mapping
static std::unordered_map<std::string, std::vector<int>>
    rank_to_global_rank_map{};
```
Any reason this map needs to be referenced outside the allocator instance?
It is referenced by two classes now: NVSHMEMPeerAllocInfo and NVSHMEMSymmetricMemory.
Bad code smell right there for a global static map
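One way to address the flagged global static map is to hang the cache off the allocator object itself, so its lifetime is tied to the allocator instance rather than the process. A hypothetical Python sketch (class and method names are illustrative, not PyTorch's):

```python
class AllocatorSketch:
    """Owns the group-name -> rank mapping cache as instance state."""

    def __init__(self):
        # Cache lives on the instance, not at module scope.
        self._rank_to_global_rank = {}

    def get_mapping(self, group_name, exchange):
        # Run the (expensive) exchange only on first request per group.
        if group_name not in self._rank_to_global_rank:
            self._rank_to_global_rank[group_name] = exchange()
        return self._rank_to_global_rank[group_name]

alloc = AllocatorSketch()
calls = []

def exchange():
    calls.append(1)  # count how many times the exchange actually runs
    return [0, 1, 2, 3]

m1 = alloc.get_mapping("0", exchange)
m2 = alloc.get_mapping("0", exchange)
```

Classes that need the mapping would then query the allocator instead of touching a shared static, which keeps ownership in one place.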
```python
@deprecated(
    "`enable_symm_mem_for_group` is deprecated. There is no need to call this function anymore."
)
```
Specify the warning type in the `deprecated` decorator as `FutureWarning`, since the function will likely be removed soon.
@pytorchbot merge
Resolves pytorch#171827. `enable_symm_mem_for_group` is for getting access to the store of a group. But the store can also be retrieved by `ProcessGroup.getStore()` in C++, so it makes little sense to require users to call `enable_symm_mem_for_group`. Pull Request resolved: pytorch#172163. Approved by: https://github.com/Skylion007
Pull Request resolved: #172185. Approved by: https://github.com/Skylion007, https://github.com/dzmitry-huba. ghstack dependencies: #172163

Resolves #172050. Two motivations:
- Give better UX and perf to users who explicitly use `symm_mem.empty()`.
- Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously. Pull Request resolved: #172292. Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba. ghstack dependencies: #172163

Fixes #172398. `NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER` is available in NCCL 2.29. Pull Request resolved: #172400. Approved by: https://github.com/dzmitry-huba, https://github.com/fduwjj. ghstack dependencies: #172163
Stack from ghstack (oldest at bottom):

Resolves #171827

`enable_symm_mem_for_group` is for getting access to the store of a group. But the store can also be retrieved by `ProcessGroup.getStore()` in C++, so it makes little sense to require users to call `enable_symm_mem_for_group`.

cc @Skylion007. Thanks for pointing out the inconvenience.