Skip to content

Commit 8e7898a

Browse files
kwen2501 and Skylion007
authored and committed
Add NCCL comm suspend, resume and memory stats (#176300)
Added three new APIs: `backend.suspend()`: free the memory held by the backend/communicator `backend.resume()`: restore the memory needed by the backend/communicator `backend.memory_stats()`: return memory usage info of the backend. ``` pytest -vs test/distributed/test_c10d_nccl.py -k test_get_memory_stats pytest -vs test/distributed/test_c10d_nccl.py -k test_suspend pytest -vs test/distributed/test_c10d_nccl.py -k test_resume ``` Pull Request resolved: #176300 Approved by: https://github.com/Skylion007 Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
1 parent 2979373 commit 8e7898a

8 files changed

Lines changed: 184 additions & 26 deletions

File tree

cmake/Modules/FindNCCL.cmake

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -57,31 +57,6 @@ if(NCCL_FOUND) # obtaining NCCL version and some sanity checks
5757
include(CheckCXXSymbolExists)
5858
check_cxx_symbol_exists(NCCL_VERSION_CODE nccl.h NCCL_VERSION_DEFINED)
5959

60-
# this condition check only works for non static NCCL linking
61-
if (NCCL_VERSION_DEFINED AND NOT USE_STATIC_NCCL)
62-
set(file "${PROJECT_BINARY_DIR}/detect_nccl_version.cc")
63-
file(WRITE ${file} "
64-
#include <iostream>
65-
#include <nccl.h>
66-
int main()
67-
{
68-
std::cout << NCCL_MAJOR << '.' << NCCL_MINOR << '.' << NCCL_PATCH << std::endl;
69-
int x;
70-
ncclGetVersion(&x);
71-
return x == NCCL_VERSION_CODE;
72-
}
73-
")
74-
try_run(NCCL_VERSION_MATCHED compile_result ${PROJECT_BINARY_DIR} ${file}
75-
RUN_OUTPUT_VARIABLE NCCL_VERSION_FROM_HEADER
76-
CMAKE_FLAGS "-DINCLUDE_DIRECTORIES=${NCCL_INCLUDE_DIRS}"
77-
LINK_LIBRARIES ${NCCL_LIBRARIES})
78-
if (NOT NCCL_VERSION_MATCHED)
79-
message(FATAL_ERROR "Found NCCL header version and library version do not match! \
80-
(include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES}) Please set NCCL_INCLUDE_DIR and NCCL_LIB_DIR manually.")
81-
endif()
82-
message(STATUS "NCCL version: ${NCCL_VERSION_FROM_HEADER}")
83-
endif ()
84-
8560
set (CMAKE_REQUIRED_INCLUDES ${OLD_CMAKE_REQUIRED_INCLUDES})
8661
message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIRS}, library: ${NCCL_LIBRARIES})")
8762
mark_as_advanced(NCCL_ROOT_DIR NCCL_INCLUDE_DIRS NCCL_LIBRARIES)

test/distributed/test_c10d_nccl.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1975,6 +1975,69 @@ def test_block_current_stream(self):
19751975
work.wait()
19761976
torch.cuda.synchronize()
19771977

1978+
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_suspend(self):
    """Verify that suspend() can be invoked on the NCCL backend and is reported."""
    file_store = c10d.FileStore(self.file_name, self.world_size)
    dev = torch.device(f"cuda:{self.rank}")
    process_group = self._create_process_group_nccl(
        file_store, self.opts(), device_id=dev
    )

    # A large collective forces NCCL to allocate internal memory.
    payload = torch.zeros(1024 * 1024 * 512, device=dev)
    dist.all_reduce(payload)

    nccl_backend = process_group._get_backend(dev)
    nccl_backend.suspend()
    # The backend should now report its memory as suspended.
    self.assertEqual(nccl_backend.memory_stats()["suspended"], 1)
1994+
1995+
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_get_memory_stats(self):
    """Test that memory_stats() returns a dict containing the expected keys.

    Runs a large collective first so the NCCL communicator allocates
    internal memory, then checks the backend exposes the four documented
    counters.
    """
    store = c10d.FileStore(self.file_name, self.world_size)
    device = torch.device(f"cuda:{self.rank}")
    pg = self._create_process_group_nccl(store, self.opts(), device_id=device)

    # Run a large collective to cause NCCL to allocate internal memory
    dist.all_reduce(torch.zeros(1024 * 1024 * 512, device=device))

    backend = pg._get_backend(device)
    stats = backend.memory_stats()
    self.assertIsInstance(stats, dict)
    # The backend reports these four counters (see NCCLComm::getMemoryStats).
    for key in ("suspend", "suspended", "persist", "total"):
        self.assertIn(key, stats)
    # NOTE: removed a leftover debug `print(stats)` that polluted test output.
2012+
2013+
@requires_nccl()
@skip_but_pass_in_sandcastle_if(not TEST_MULTIGPU, "NCCL test requires 2+ GPUs")
def test_resume(self):
    """Exercise a full suspend()/resume() cycle and confirm collectives still run."""
    file_store = c10d.FileStore(self.file_name, self.world_size)
    dev = torch.device(f"cuda:{self.rank}")
    process_group = self._create_process_group_nccl(
        file_store, self.opts(), device_id=dev
    )
    nccl_backend = process_group._get_backend(dev)

    # A large collective forces NCCL to allocate internal memory.
    dist.all_reduce(torch.zeros(1024 * 1024 * 512, device=dev))

    # Release the communicator's memory, then bring it back.
    nccl_backend.suspend()
    nccl_backend.resume()
    # Nothing should remain in the suspended state after resume().
    self.assertEqual(nccl_backend.memory_stats()["suspended"], 0)

    # The communicator must still be usable: an all-reduce of ones
    # sums to world_size on every rank.
    buf = torch.ones(1024, device=dev, dtype=torch.float32)
    dist.all_reduce(buf)
    self.assertEqual(
        buf,
        torch.full((1024,), self.world_size, device=dev, dtype=torch.float32),
    )
2040+
19782041

19792042
class DistributedDataParallelTest(
19802043
test_c10d_common.CommonDistributedDataParallelTest, MultiProcessTestCase

torch/csrc/distributed/c10d/Backend.hpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#pragma once
22

33
#include <memory>
4+
#include <string>
5+
#include <unordered_map>
46
#include <utility>
57
#include <vector>
68

@@ -509,6 +511,26 @@ class TORCH_API Backend : public torch::CustomClassHolder {
509511
// normal shutdown.
510512
virtual void shutdown() {}
511513

514+
// APIs related to memory offload
// suspend() asks the backend to free the memory held by its
// communicator(s). The base implementation always throws; backends that
// support memory offload override it.
virtual void suspend() {
  TORCH_CHECK(
      false,
      c10::str("Backend ", getBackendName(), " does not support suspend"));
}
520+
521+
// resume() restores the memory a prior suspend() released so that
// collectives can run again. The base implementation always throws;
// backends that support memory offload override it.
virtual void resume() {
  TORCH_CHECK(
      false,
      c10::str("Backend ", getBackendName(), " does not support resume"));
}
526+
527+
// Returns the backend's memory-usage counters keyed by stat name. The
// base implementation always throws; backends that support memory
// offload override it.
virtual std::unordered_map<std::string, uint64_t> getMemoryStats() {
  TORCH_CHECK(
      false,
      c10::str(
          "Backend ", getBackendName(), " does not support getMemoryStats"));
}
533+
512534
protected:
513535
// Implementations of this interface need to call this to setup
514536
// appropriate logging etc.

torch/csrc/distributed/c10d/NCCLUtils.cpp

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,54 @@ std::string NCCLComm::repr() const {
595595
return c10::str((void*)ncclComm_);
596596
}
597597

598+
// Frees the GPU memory held by this communicator via NCCL's memory
// offload API. Only available when built against NCCL with
// NCCL_HAS_COMM_OFFLOAD (2.29.7+); throws otherwise. Pair with resume()
// before issuing further collectives.
void NCCLComm::suspend() {
#ifdef NCCL_HAS_COMM_OFFLOAD
  // Serialize against other users of this comm and make sure the CUDA
  // call targets this communicator's device.
  LockType lock(mutex_);
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
  auto comm = getNcclComm();
  C10D_NCCL_CHECK(ncclCommSuspend(comm, NCCL_SUSPEND_MEM), std::nullopt);
#else
  TORCH_CHECK(false, "suspend() requires NCCL 2.29.7 or later");
#endif
}
608+
609+
// Restores the memory released by a prior suspend() so the communicator
// can be used again. Only available when built against NCCL with
// NCCL_HAS_COMM_OFFLOAD (2.29.7+); throws otherwise.
void NCCLComm::resume() {
#ifdef NCCL_HAS_COMM_OFFLOAD
  // Serialize against other users of this comm and make sure the CUDA
  // call targets this communicator's device.
  LockType lock(mutex_);
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
  auto comm = getNcclComm();
  C10D_NCCL_CHECK(ncclCommResume(comm), std::nullopt);
#else
  TORCH_CHECK(false, "resume() requires NCCL 2.29.7 or later");
#endif
}
619+
620+
// Queries NCCL for this communicator's GPU memory counters and returns
// them keyed as {"suspend", "suspended", "persist", "total"}. Only
// available when built against NCCL with NCCL_HAS_COMM_OFFLOAD (2.29.7+);
// throws otherwise. Throws (via C10D_NCCL_CHECK) if any query fails.
std::unordered_map<std::string, uint64_t> NCCLComm::getMemoryStats() {
#ifdef NCCL_HAS_COMM_OFFLOAD
  // Serialize against other users of this comm and make sure the CUDA
  // calls target this communicator's device.
  LockType lock(mutex_);
  at::cuda::OptionalCUDAGuard gpuGuard(deviceIndex_);
  auto comm = getNcclComm();
  // One declaration per line, zero-initialized: the out-params are never
  // left holding garbage, even transiently.
  uint64_t suspend = 0;
  uint64_t suspended = 0;
  uint64_t persist = 0;
  uint64_t total = 0;
  C10D_NCCL_CHECK(
      ncclCommMemStats(comm, ncclStatGpuMemSuspend, &suspend), std::nullopt);
  C10D_NCCL_CHECK(
      ncclCommMemStats(comm, ncclStatGpuMemSuspended, &suspended),
      std::nullopt);
  C10D_NCCL_CHECK(
      ncclCommMemStats(comm, ncclStatGpuMemPersist, &persist), std::nullopt);
  C10D_NCCL_CHECK(
      ncclCommMemStats(comm, ncclStatGpuMemTotal, &total), std::nullopt);
  return {
      {"suspend", suspend},
      {"suspended", suspended},
      {"persist", persist},
      {"total", total},
  };
#else
  TORCH_CHECK(false, "getMemoryStats() requires NCCL 2.29.7 or later");
#endif
}
645+
598646
#if (defined(IS_NCCLX) || defined(USE_ROCM)) && defined(NCCL_COMM_DUMP)
599647
std::unordered_map<std::string, std::string> NCCLComm::ncclCommDump() {
600648
std::unordered_map<std::string, std::string> dump;

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ static_assert(
9494
#define NCCL_HAS_COMM_SHRINK
9595
#endif
9696

97+
#if NCCL_VERSION_CODE >= NCCL_VERSION(2, 29, 7)
98+
#define NCCL_HAS_COMM_OFFLOAD
99+
#endif
100+
97101
// Macro to throw on a non-successful NCCL return value.
98102
#define C10D_NCCL_CHECK(cmd, failureReason) \
99103
do { \
@@ -376,6 +380,13 @@ class NCCLComm {
376380

377381
std::string repr() const;
378382

383+
// APIs related to memory offload (require NCCL 2.29.7+ at runtime)
384+
void suspend();
385+
386+
void resume();
387+
388+
std::unordered_map<std::string, uint64_t> getMemoryStats();
389+
379390
friend class ProcessGroupNCCL;
380391

381392
protected:

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3367,6 +3367,30 @@ uint64_t ProcessGroupNCCL::getCommSplitCounter() const {
33673367
return ret;
33683368
}
33693369

3370+
void ProcessGroupNCCL::suspend() {
3371+
auto device = at::Device(at::kCUDA, guessDeviceId());
3372+
std::string deviceKey = getKeyFromDevice(device);
3373+
auto ncclComm = getNCCLComm(deviceKey);
3374+
TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized.");
3375+
ncclComm->suspend();
3376+
}
3377+
3378+
void ProcessGroupNCCL::resume() {
3379+
auto device = at::Device(at::kCUDA, guessDeviceId());
3380+
std::string deviceKey = getKeyFromDevice(device);
3381+
auto ncclComm = getNCCLComm(deviceKey);
3382+
TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized.");
3383+
ncclComm->resume();
3384+
}
3385+
3386+
std::unordered_map<std::string, uint64_t> ProcessGroupNCCL::getMemoryStats() {
3387+
auto device = at::Device(at::kCUDA, guessDeviceId());
3388+
std::string deviceKey = getKeyFromDevice(device);
3389+
auto ncclComm = getNCCLComm(deviceKey);
3390+
TORCH_CHECK(ncclComm != nullptr, "NCCL communicator not initialized.");
3391+
return ncclComm->getMemoryStats();
3392+
}
3393+
33703394
namespace {
33713395

33723396
// Check validity of tensor

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,13 @@ class TORCH_API ProcessGroupNCCL : public Backend {
10511051

10521052
void setEnableNanCheck(bool enableNanCheck);
10531053

1054+
// APIs related to memory offload (require NCCL 2.29.7+ at runtime)
1055+
void suspend() override;
1056+
1057+
void resume() override;
1058+
1059+
std::unordered_map<std::string, uint64_t> getMemoryStats() override;
1060+
10541061
protected:
10551062
uint64_t getWatchdogHeartbt() const;
10561063

torch/csrc/distributed/c10d/init.cpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3156,7 +3156,15 @@ The hook must have the following signature:
31563156
py::arg("device"),
31573157
py::call_guard<py::gil_scoped_release>())
31583158
.def_property_readonly(
3159-
"mem_allocator", &::c10d::Backend::getMemAllocator);
3159+
"mem_allocator", &::c10d::Backend::getMemAllocator)
3160+
.def("suspend", &::c10d::Backend::suspend)
3161+
.def("resume", &::c10d::Backend::resume)
3162+
.def("memory_stats", &::c10d::Backend::getMemoryStats, R"(
3163+
Get the memory statistics of the backend.
3164+
3165+
Returns:
3166+
A dictionary containing the memory statistics.
3167+
)");
31603168

31613169
// base Backend::Options binding
31623170
// TODO: Maybe we can consider how to merge this with

0 commit comments

Comments
 (0)