
Commit 1c83214

kwen2501 authored and pytorchmergebot committed
[SymmMem] Add multimem support for NCCL and NVSHMEM (pytorch#172185)
Pull Request resolved: pytorch#172185
Approved by: https://github.com/Skylion007, https://github.com/dzmitry-huba
ghstack dependencies: pytorch#172163
1 parent 1cd056e commit 1c83214

5 files changed: 87 additions & 18 deletions
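
What the change enables, end to end: after rendezvous, a SymmMem handle now exposes a usable multicast (multimem) pointer on NCCL- and NVSHMEM-backed allocations wherever NVLink SHARP (NVLS) is available. Below is a minimal usage sketch modeled on the tests added in this commit; the script layout and the torchrun launch are assumptions, while the symm_mem calls are the ones the new tests exercise (the NCCL path needs NCCL >= 2.29):

# Hedged sketch based on the new tests; not itself part of this commit.
# Launch with: torchrun --nproc-per-node=2 multicast_demo.py
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem
from torch._C._autograd import DeviceType
from torch._C._distributed_c10d import _SymmetricMemory

rank = int(os.environ["RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")

symm_mem.set_backend("NCCL")  # or "NVSHMEM"
# Eagerly create the communicator, mirroring the NCCL test below.
dist.all_reduce(torch.ones(1, device=device))

tensor = symm_mem.empty(1, device=device)
handle = symm_mem.rendezvous(tensor, dist.group.WORLD.group_name)

# multicast_ptr is non-zero only where NVLS multicast is supported;
# otherwise the new code paths degrade gracefully to 0 / nullptr.
if _SymmetricMemory.has_multicast_support(DeviceType.CUDA, rank):
    print(f"rank {rank}: multicast_ptr = {handle.multicast_ptr:#x}")
else:
    print(f"rank {rank}: no multicast support on this device")

dist.destroy_process_group()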


test/distributed/test_nccl.py

Lines changed: 25 additions & 0 deletions
@@ -411,6 +411,31 @@ def test_mempool_compute_ops(self):
         expected = torch.mm(x, w) * self.world_size
         self.assertEqual(y, expected)

+    @skip_but_pass_in_sandcastle_if(TEST_WITH_ROCM, "Skip NCCL tests for ROCm")
+    @skip_but_pass_in_sandcastle_if(IS_WINDOWS, "NCCL doesn't support Windows")
+    @skip_if_lt_x_gpu(2)
+    @requires_nccl_version(
+        (2, 29), "NCCL Symmetric Memory multicast support from nccl 2.29"
+    )
+    def test_multicast_ptr(self) -> None:
+        """
+        Get the multicast pointer
+        """
+        from torch._C._autograd import DeviceType
+        from torch._C._distributed_c10d import _SymmetricMemory
+
+        symm_mem.set_backend("NCCL")
+        torch.cuda.set_device(self.rank)
+        c10d.all_reduce(torch.ones(1, device=self.device))
+        group_name = c10d.group.WORLD.group_name
+
+        tensor = symm_mem.empty(1, device=self.device)
+        handle = symm_mem.rendezvous(tensor, group_name)
+        if _SymmetricMemory.has_multicast_support(DeviceType.CUDA, self.device.index):
+            self.assertNotEqual(handle.multicast_ptr, 0)
+        else:
+            self.assertEqual(handle.multicast_ptr, 0)
+

 instantiate_device_type_tests(TestNCCL, globals(), only_for="cuda")
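
A non-zero multicast_ptr is what multimem-based kernels key off. As a follow-on to the sketch above, the existing multimem all-reduce op from PyTorch's SymmMem op library (not added by this commit; treat the exact op signature as an assumption) can then run on an NCCL-backed buffer:

# Hedged follow-on to the sketch above (same torchrun setup, same rank
# and device variables). Op name from the SymmMem op library.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

t = symm_mem.empty(2048, dtype=torch.bfloat16, device=device)
handle = symm_mem.rendezvous(t, dist.group.WORLD.group_name)
t.fill_(rank + 1)

if handle.multicast_ptr != 0:
    # A single multimem reduction over the multicast address replaces a
    # ring/tree of unicast transfers.
    torch.ops.symm_mem.multimem_all_reduce_(t, "sum", dist.group.WORLD.group_name)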

test/distributed/test_nvshmem.py

Lines changed: 33 additions & 4 deletions
@@ -3,6 +3,7 @@
 # To run:
 # python test/distributed/test_nvshmem.py

+
 import os

 import torch
@@ -31,12 +32,26 @@ def requires_nvshmem():
     )


+def has_nvls_support():
+    if not symm_mem.is_nvshmem_available():
+        return False
+
+    if os.environ.get("NVSHMEM_DISABLE_NVLS", "0") == "1":
+        return False
+
+    # Set NVSHMEM as SymmMem backend before running the check
+    symm_mem.set_backend("NVSHMEM")
+    from torch._C._autograd import DeviceType
+    from torch._C._distributed_c10d import _SymmetricMemory
+
+    return _SymmetricMemory.has_multicast_support(DeviceType.CUDA, 0)
+
+
 def requires_nvls():
-    """Skip test if NVLS (NVLink Switch) is not available."""
-    nvls_disabled = os.environ.get("NVSHMEM_DISABLE_NVLS", "0") == "1"
+    """Skip test if NVLS (NVLink SHARP) is not available."""
     return skip_but_pass_in_sandcastle_if(
-        nvls_disabled,
-        "Test requires NVLS which is disabled via NVSHMEM_DISABLE_NVLS=1",
+        not has_nvls_support(),
+        "Test requires NVLink SHARP support",
     )


@@ -225,6 +240,20 @@ def test_get_remote_tensors(self) -> None:
         for peer, tensor in enumerate(remote_tensors):
             self.assertEqual(tensor, peer)

+    def test_multicast_ptr(self) -> None:
+        """
+        Get the multicast pointer
+        """
+        self._init_device()
+        group_name = dist.group.WORLD.group_name
+
+        tensor = symm_mem.empty(1, device=self.device)
+        handle = symm_mem.rendezvous(tensor, group_name)
+        if has_nvls_support():
+            self.assertNotEqual(handle.multicast_ptr, 0)
+        else:
+            self.assertEqual(handle.multicast_ptr, 0)
+
    @skipIfRocm
    def test_nvshmem_put(self) -> None:
        self._init_device()
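
Because requires_nvls() now keys off has_nvls_support() rather than just the NVSHMEM_DISABLE_NVLS flag, tests skip on hardware without NVLink SHARP instead of failing there. A hypothetical multicast-dependent test would be gated like this (test name and body are placeholders):

# Hypothetical test using the reworked decorator.
@requires_nvls()
def test_uses_multicast(self) -> None:
    self._init_device()
    ...  # runs only when NVLink SHARP multicast is actually available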

torch/csrc/distributed/c10d/symm_mem/CUDASymmetricMemory.cu

Lines changed: 1 addition & 1 deletion
@@ -142,7 +142,7 @@ bool CUDASymmetricMemory::has_multicast_support() {
 }

 void* CUDASymmetricMemory::get_multicast_ptr() {
-  return pai_->mc_addr_;
+  return static_cast<char*>(pai_->mc_addr_) + offset_;
 }

 size_t CUDASymmetricMemory::get_offset() {
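
This one-line fix aligns get_multicast_ptr() with the buffer accessors: an allocation may start at a non-zero offset_ within the mapped multicast segment (see get_offset() just below), and the old code returned the segment base rather than the caller's buffer. Adding offset_ makes the multicast pointer address the same slice of memory as the unicast buffer pointer.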

torch/csrc/distributed/c10d/symm_mem/NCCLSymmetricMemory.cu

Lines changed: 14 additions & 6 deletions
@@ -138,6 +138,12 @@ class NCCLPeerAllocInfo : public c10::intrusive_ptr_target {
        arr_size,
        cudaMemcpyDeviceToHost));
 #endif
+
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 29, 0)
+    C10D_NCCL_CHECK(
+        ncclGetLsaMultimemDevicePointer(buffer_win_, offset_, &mc_addr_),
+        "Failed to get multicast pointer");
+#endif
  }

  // Exact copy is not needed / supported
@@ -159,6 +165,8 @@ class NCCLPeerAllocInfo : public c10::intrusive_ptr_target {
  std::string group_name_;
  ncclWindow_t buffer_win_;
  ncclWindow_t signal_handle_;
+  // Multicast address
+  void* mc_addr_ = nullptr;

  friend class NCCLSymmetricMemory;
 };
@@ -195,13 +203,14 @@ size_t NCCLSymmetricMemory::get_buffer_size() {
 }

 bool NCCLSymmetricMemory::has_multicast_support() {
-  // TODO
-  return false;
+  return pai_->mc_addr_ != nullptr;
 }

 void* NCCLSymmetricMemory::get_multicast_ptr() {
-  // TODO
-  return nullptr;
+  if (!has_multicast_support()) {
+    return nullptr;
+  }
+  return static_cast<char*>(pai_->mc_addr_) + offset_;
 }

 void NCCLSymmetricMemory::barrier(int channel, size_t timeout_ms) {
@@ -312,8 +321,7 @@ class NCCLSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
  }

  bool has_multicast_support(int device_idx) override {
-    // TODO
-    return false;
+    return device_has_multicast_support(device_idx);
  }

  c10::DeviceType supported_device_type() override {
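
Note the graceful degradation on the NCCL side: with NCCL older than 2.29 the ncclGetLsaMultimemDevicePointer() call is compiled out, mc_addr_ stays nullptr, and has_multicast_support() / get_multicast_ptr() keep returning false / nullptr, exactly as before this commit.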

torch/csrc/distributed/c10d/symm_mem/NVSHMEMSymmetricMemory.cu

Lines changed: 14 additions & 7 deletions
@@ -138,6 +138,10 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target {
        signal_pads_.data(),
        arr_size,
        cudaMemcpyHostToDevice));
+
+    // Initialize multicast address
+    // On unsupported platforms, this API returns a nullptr.
+    mc_addr_ = nvshmemx_mc_ptr(NVSHMEM_TEAM_WORLD, base_ptr_);
  }

 private:
@@ -151,6 +155,8 @@ class NVSHMEMPeerAllocInfo : public c10::intrusive_ptr_target {
  void** signal_pads_dev_;
  // Whether the world is within CUDA P2P only, not network
  bool world_within_cuda_p2p_;
+  // Multicast address
+  void* mc_addr_;

  friend class NVSHMEMSymmetricMemory;
 };
@@ -206,13 +212,15 @@ class NVSHMEMSymmetricMemory : public SymmetricMemory {
  }

  bool has_multicast_support() override {
-    // TODO
-    return false;
+    // On unsupported platforms, this API returns a nullptr.
+    return pai_->mc_addr_ != nullptr;
  }

  void* get_multicast_ptr() override {
-    // TODO
-    return nullptr;
+    if (!has_multicast_support()) {
+      return nullptr;
+    }
+    return static_cast<char*>(pai_->mc_addr_) + offset_;
  }

  size_t get_offset() override {
@@ -432,9 +440,8 @@ class NVSHMEMSymmetricMemoryAllocator : public SymmetricMemoryAllocator {
  };

  bool has_multicast_support(int device_idx) override {
-    // TODO
-    return false;
-  };
+    return device_has_multicast_support(device_idx);
+  }

  c10::DeviceType supported_device_type() override {
    return c10::DeviceType::CUDA;
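
With this change both allocators answer the device-level query through device_has_multicast_support(), the same CUDA capability check the CUDA backend uses, so _SymmetricMemory.has_multicast_support(DeviceType.CUDA, idx) behaves consistently across backends; handle-level support additionally requires that the backend actually obtained a multicast mapping (mc_addr_ != nullptr).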
