
[SymmMem] Deprecate enable_symm_mem_for_group #172163

Closed
kwen2501 wants to merge 3 commits into gh/kwen2501/305/base from gh/kwen2501/305/head

Conversation

Collaborator

@kwen2501 kwen2501 commented Jan 10, 2026

Stack from ghstack (oldest at bottom):

Resolves #171827

`enable_symm_mem_for_group` exists to give access to the store of a group.
But the store can also be retrieved via `ProcessGroup.getStore()` in C++, so it makes little sense to require users to call `enable_symm_mem_for_group`.

cc @Skylion007 . Thanks for pointing out the inconvenience.
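The rationale can be illustrated with a toy model (plain Python, not the actual c10d types; `Store`, `ProcessGroup`, and `rendezvous` below are illustrative stand-ins, not PyTorch APIs): if the group object already exposes its store, the symmetric-memory layer can fetch the store itself, so a separate user-facing registration call is redundant.

```python
# Toy sketch: why a user-side registration step is unnecessary when the
# group already exposes its store. Not PyTorch code.

class Store:
    """Minimal key-value store, standing in for a c10d store."""
    def __init__(self):
        self._kv = {}
    def set(self, k, v):
        self._kv[k] = v
    def get(self, k):
        return self._kv[k]

class ProcessGroup:
    """Holds its store and exposes it, mirroring ProcessGroup.getStore() in C++."""
    def __init__(self, store):
        self._store = store
    def get_store(self):
        return self._store

def rendezvous(group):
    # New flow: the symmetric-memory layer pulls the store from the group
    # directly; no prior enable_symm_mem_for_group-style call is needed.
    store = group.get_store()
    store.set("rank0/handle", 0xDEAD)  # publish a dummy allocation handle
    return store.get("rank0/handle")

pg = ProcessGroup(Store())
print(hex(rendezvous(pg)))  # → 0xdead
```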

[ghstack-poisoned]
@pytorch-bot

pytorch-bot bot commented Jan 10, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/172163

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 945de0e with merge base 8cfe6f1:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

kwen2501 added a commit that referenced this pull request Jan 10, 2026
[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jan 10, 2026
Comment on lines -111 to -131
// For logging only
static int exchanged_n_times = 0;
auto global_rank = get_group_info("0").rank;
auto store = group_info.store;
// Exchange rank to global rank mapping for this group.
// If it is already available, skip the exchange.
if (group_info.rank_to_global_rank.empty()) {
group_info.rank_to_global_rank =
storeExchange.all_gather(store, rank_, world_size_, global_rank);
exchanged_n_times++;
if (rank_ == 0) {
LOG(INFO) << "[rank " << rank_ << ']'
<< " rank_to_global_rank: " << group_info.rank_to_global_rank
<< ", group_name: " << group_name_
<< ", exchanged_n_times: " << exchanged_n_times;
}
}

TORCH_INTERNAL_ASSERT(!group_info.rank_to_global_rank.empty());
rank_to_global_rank_ = group_info.rank_to_global_rank;

Collaborator Author

NCCL has no need for this mapping.
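For reference, the store-based exchange being removed here can be sketched in plain Python (a toy serial model of `storeExchange.all_gather`; the `Store` class and key scheme are illustrative, not c10d APIs): each group rank publishes its global rank under a well-known key, then reads back the full mapping.

```python
# Toy sketch of a store-based rank-to-global-rank exchange. In real code
# the reads would block until peers have written; this model runs serially.

class Store:
    def __init__(self):
        self._kv = {}
    def set(self, k, v):
        self._kv[k] = v
    def get(self, k):
        return self._kv[k]

def exchange_global_ranks(store, rank, world_size, global_rank):
    store.set(f"rank_{rank}", global_rank)  # publish my mapping entry
    return [store.get(f"rank_{r}") for r in range(world_size)]

store = Store()
# Simulate a 4-rank subgroup whose global ranks are offset by 8.
for r in range(4):
    store.set(f"rank_{r}", r + 8)
print(exchange_global_ranks(store, rank=0, world_size=4, global_rank=8))
# → [8, 9, 10, 11]
```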

Comment on lines -286 to -287
auto group_info = get_group_info("0");
auto store = group_info.store;
Collaborator Author

Removing dead code

Comment on lines -122 to -129

rank_to_global_rank_dev_ = reinterpret_cast<int*>(
c10::cuda::CUDACachingAllocator::raw_alloc(sizeof(int) * world_size_));
AT_CUDA_CHECK(cudaMemcpy(
rank_to_global_rank_dev_,
rank_to_global_rank_.data(),
sizeof(int) * world_size_,
cudaMemcpyHostToDevice));
Collaborator Author

@kwen2501 kwen2501 Jan 10, 2026

This device-side mapping doesn't need to be repeatedly allocated per-handle. I made it per-group above.

Comment on lines -148 to -151
// Note this field is not automatically populated by set_group_info(). If a
// SymmetricMemory implementation needs to use it, it must be populated by a
// call to exchange_global_ranks() first.
std::vector<int> rank_to_global_rank;
Collaborator Author

@kwen2501 kwen2501 Jan 10, 2026

This is only needed by NVSHMEM. I made it an internal impl.

@kwen2501 kwen2501 added the release notes: distributed (symm_mem) label on Jan 10, 2026
Collaborator

@Skylion007 Skylion007 left a comment

Definitely better than before

};

// A map from group name to rank-to-global rank mapping
static std::unordered_map<std::string, std::vector<int>> rank_to_global_rank_map{};
Collaborator

Any reason this map needs to be referenced outside the allocator instance?

Collaborator Author

It is referenced by two classes now:
NVSHMEMPeerAllocInfo and NVSHMEMSymmetricMemory

Collaborator

Bad code smell right there for a global static map



@deprecated(
"`enable_symm_mem_for_group` is deprecated. There is no need to call this function anymore."
Collaborator

Specify the warning type in the deprecated decorator as FutureWarning, since this function will likely be removed soon.

Collaborator Author

Done, thanks
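The suggestion above can be sketched with a minimal stand-in decorator (written here only to show the effect of passing `category=FutureWarning`; it is not the actual `typing_extensions.deprecated` implementation, which accepts a `category` keyword for the same purpose):

```python
import functools
import warnings

def deprecated(msg, *, category=DeprecationWarning):
    """Minimal stand-in for a deprecation decorator; `category` selects the
    warning class emitted on each call."""
    def wrap(fn):
        @functools.wraps(fn)
        def inner(*args, **kwargs):
            warnings.warn(msg, category=category, stacklevel=2)
            return fn(*args, **kwargs)
        return inner
    return wrap

@deprecated(
    "`enable_symm_mem_for_group` is deprecated. There is no need to call "
    "this function anymore.",
    category=FutureWarning,  # visible by default, signals imminent removal
)
def enable_symm_mem_for_group(group_name):
    pass  # body kept as a no-op in this sketch

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    enable_symm_mem_for_group("0")
print(caught[0].category.__name__)  # → FutureWarning
```

Unlike DeprecationWarning, FutureWarning is shown by default even outside `__main__`, which is why it suits APIs slated for near-term removal.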

[ghstack-poisoned]
kwen2501 added a commit that referenced this pull request Jan 10, 2026
@kwen2501
Collaborator Author

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk label on Jan 10, 2026
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging
Check the merge workflow status here

hinriksnaer pushed a commit to hinriksnaer/pytorch that referenced this pull request Jan 12, 2026
Resolves pytorch#171827

`enable_symm_mem_for_group` exists to give access to the store of a group.
But the store can also be retrieved via `ProcessGroup.getStore()` in C++, so it makes little sense to require users to call `enable_symm_mem_for_group`.

Pull Request resolved: pytorch#172163
Approved by: https://github.com/Skylion007
skpark-rh pushed a commit to skpark-rh/pytorch that referenced this pull request Jan 12, 2026
Resolves pytorch#171827

`enable_symm_mem_for_group` exists to give access to the store of a group.
But the store can also be retrieved via `ProcessGroup.getStore()` in C++, so it makes little sense to require users to call `enable_symm_mem_for_group`.

Pull Request resolved: pytorch#172163
Approved by: https://github.com/Skylion007
pytorchmergebot pushed a commit that referenced this pull request Jan 12, 2026
pytorchmergebot pushed a commit that referenced this pull request Jan 14, 2026
Resolves #172050

Two motivations:
- Give better UX and perf to users who explicitly use `symm_mem.empty()`.
- Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously.
Pull Request resolved: #172292
Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba
ghstack dependencies: #172163
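The reuse idea behind the stacked MemPool change can be shown with a framework-free sketch (names like `SymmPool` and `empty` are illustrative, not PyTorch APIs): freed buffers go back into a per-size free list, so a later request of the same size reuses a cached buffer instead of performing a fresh allocation and handle exchange.

```python
# Toy pool: symm_mem.empty()-style allocation with automatic reuse of
# freed buffers, so callers (or Inductor-generated code) need no manual
# bookkeeping. bytearray stands in for a symmetric allocation.

class SymmPool:
    def __init__(self):
        self._free = {}        # size -> list of cached buffers
        self.fresh_allocs = 0  # counts real (non-reused) allocations

    def empty(self, size):
        cached = self._free.get(size)
        if cached:
            return cached.pop()  # reuse: no new allocation or rendezvous
        self.fresh_allocs += 1
        return bytearray(size)

    def free(self, buf):
        # Return the buffer to the pool instead of releasing it.
        self._free.setdefault(len(buf), []).append(buf)

pool = SymmPool()
a = pool.empty(1024)
pool.free(a)
b = pool.empty(1024)              # served from the pool, not a new alloc
print(pool.fresh_allocs, a is b)  # → 1 True
```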
pytorchmergebot pushed a commit that referenced this pull request Jan 14, 2026
pytorchmergebot pushed a commit that referenced this pull request Jan 14, 2026
Fixes #172398

`NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER` is available since NCCL 2.29.
Pull Request resolved: #172400
Approved by: https://github.com/dzmitry-huba, https://github.com/fduwjj
ghstack dependencies: #172163
mattteochen pushed a commit to mattteochen/pytorch that referenced this pull request Jan 15, 2026
Resolves pytorch#172050

Two motivations:
- Give better UX and perf to users who explicitly use `symm_mem.empty()`.
- Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously.
Pull Request resolved: pytorch#172292
Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba
ghstack dependencies: pytorch#172163
mattteochen pushed a commit to mattteochen/pytorch that referenced this pull request Jan 15, 2026
mattteochen pushed a commit to mattteochen/pytorch that referenced this pull request Jan 15, 2026
Fixes pytorch#172398

`NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER` is available since NCCL 2.29.
Pull Request resolved: pytorch#172400
Approved by: https://github.com/dzmitry-huba, https://github.com/fduwjj
ghstack dependencies: pytorch#172163
pytorchmergebot pushed a commit that referenced this pull request Jan 16, 2026
Resolves #172050

Two motivations:
- Give better UX and perf to users who explicitly use `symm_mem.empty()`.
- Simplify the code generated by Inductor, i.e. `symm_mem.empty()` would automatically reuse memory, rather than requiring Inductor to bookkeep it.

The MemPool infra for all CUDA backends (`CUDA`, `NVSHMEM`, `NCCL`) has been built previously.
Pull Request resolved: #172292
Approved by: https://github.com/ngimel, https://github.com/dzmitry-huba
ghstack dependencies: #172163
pytorchmergebot pushed a commit that referenced this pull request Jan 16, 2026
@github-actions github-actions bot deleted the gh/kwen2501/305/head branch February 10, 2026 02:24
gderossi pushed a commit to gderossi/pytorch that referenced this pull request Feb 10, 2026
Fixes pytorch#172398

`NCCL_DEV_COMM_REQUIREMENTS_INITIALIZER` is available since NCCL 2.29.
Pull Request resolved: pytorch#172400
Approved by: https://github.com/dzmitry-huba, https://github.com/fduwjj
ghstack dependencies: pytorch#172163

Labels

- ciflow/h100-symm-mem
- ciflow/trunk (Trigger trunk jobs on your pull request)
- Merged
- open source
- release notes: distributed (c10d) (release notes category)
- release notes: distributed (symm_mem) (release note label for symmetric memory)


Development

Successfully merging this pull request may close these issues.

4 participants