
Commit 9b53d31

Wanchao Liang authored and pytorchmergebot committed
Implement gather primitive for ProcessGroupNCCL (#66745)
Summary:
Pull Request resolved: #66745

This PR implements NCCL gather and adds gather to ProcessGroupNCCL using the NCCL send/recv API. NCCL does not provide a gather primitive directly, so gather has to be implemented on top of NCCL's send/recv API.

1. In ProcessGroupNCCL.cpp, the outputTensors are first flattened, then inputTensors and outputFlattened are passed by the collective class to the gather() function in nccl.cpp.
2. In nccl.cpp, gather is implemented using ncclSend/ncclRecv: every rank sends its inputTensor to the root rank, and the root rank receives these inputTensors in a loop.

ghstack-source-id: 147754838

Test Plan:
test_gather_ops
test_gather_checks
test_gather_stress

Reviewed By: pritamdamania87

Differential Revision: D29616361

fbshipit-source-id: b500d9b8e67113194c5cc6575fb0e5d806dc7782
(cherry picked from commit d560ee7)
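For reference, here is a minimal usage sketch (not part of this commit's diff) of what the change enables from the Python side: calling gather through torch.distributed with the NCCL backend, where only the root rank supplies a gather_list. The run() helper, the init_process_group arguments, and the launch/environment details are illustrative assumptions, not code from this PR.

import torch
import torch.distributed as dist

def run(rank, world_size):
    # Assumes MASTER_ADDR/MASTER_PORT are set and one process per GPU.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

    # Each rank contributes one CUDA tensor holding its own rank id.
    inp = torch.tensor([rank], device="cuda")

    root = 0
    if rank == root:
        # Only the root passes a gather_list: world_size tensors with the
        # same shape/dtype as the input, all on this rank's GPU.
        gather_list = [torch.empty_like(inp) for _ in range(world_size)]
        dist.gather(inp, gather_list=gather_list, dst=root)
        # gather_list now holds [tensor([0]), tensor([1]), ...] on the root GPU.
    else:
        # Non-root ranks pass no gather_list; they only send to the root.
        dist.gather(inp, dst=root)

    dist.destroy_process_group()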
1 parent 0a8b391 commit 9b53d31

7 files changed

Lines changed: 276 additions & 38 deletions


docs/source/distributed.rst

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ MPI supports CUDA only if the implementation used to build PyTorch supports it.
 +----------------+-----+-----+-----+-----+-----+-----+
 | all_gather |||| ? |||
 +----------------+-----+-----+-----+-----+-----+-----+
-| gather |||| ? || |
+| gather |||| ? || |
 +----------------+-----+-----+-----+-----+-----+-----+
 | scatter |||| ? |||
 +----------------+-----+-----+-----+-----+-----+-----+

test/distributed/test_c10d_nccl.py

Lines changed: 136 additions & 2 deletions
@@ -422,7 +422,7 @@ def test_allgather_ops(self):
 
         def allgather(output_ts, input_ts):
             work = pg.allgather(output_ts, input_ts)
-            work.wait()
+            return work.wait()
 
         tensors = [torch.empty(2, 2).fill_(2).cuda(device=i) for i in local_device_ids]
         output_tensors = []
@@ -435,7 +435,7 @@ def allgather(output_ts, input_ts):
             output_tensors.append([t.cuda(device=gpu) for t in output_per_gpu])
             expected_output.append([t.cuda(device=gpu) for t in expected_per_gpu])
 
-        allgather(output_tensors, tensors)
+        result = allgather(output_tensors, tensors)
 
         # Verification
         self.assertEqual(output_tensors, expected_output)
@@ -495,6 +495,140 @@ def allgather_base(output_t, input_t):
             # fails the check because the dtype is different
             allgather_base(output_t, tensor)
 
+    @requires_nccl()
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_gather_ops(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_ids = self.rank_to_GPU[self.rank]
+        num_gpus = len(local_device_ids)
+
+        def gather(output_t, input_t, rootRank):
+            opts = c10d.GatherOptions()
+            opts.rootRank = rootRank
+            if rootRank == self.rank:
+                work = pg.gather(output_t, input_t, opts)
+            else:
+                work = pg.gather([], input_t, opts)
+            work.wait()
+
+        # init input
+        tensors = []
+        for device_id in local_device_ids:
+            tensors.append(torch.tensor([self.rank]).cuda(device_id))
+
+        # init output
+        output_ts = []
+        for idx in range(num_gpus):
+            gpu_idx = local_device_ids[idx]
+            output_ts.append([])
+            for rank in range(self.world_size):
+                output_ts[idx].append(torch.tensor([-1]).cuda(gpu_idx))
+
+        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
+        for rank in range(self.world_size):
+            gather(output_ts, tensors, rank)
+            if rank == self.rank:
+                self.assertEqual(expected, output_ts)
+
+    @requires_nccl()
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_gather_stress(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_ids = self.rank_to_GPU[self.rank]
+        num_gpus = len(local_device_ids)
+
+        def gather(output_t, input_t, rootRank):
+            opts = c10d.GatherOptions()
+            opts.rootRank = rootRank
+            if rootRank == self.rank:
+                work = pg.gather(output_t, input_t, opts)
+            else:
+                work = pg.gather([], input_t, opts)
+            work.wait()
+
+        stress_length = 1000
+
+        # init input
+        tensors = []
+        for i in range(stress_length):
+            tensors.append([])
+            for device_id in local_device_ids:
+                tensors[i].append(torch.tensor([self.rank]).cuda(device_id))
+
+        # init output
+        output_ts = []
+        for i in range(stress_length):
+            output_ts.append([[] for _ in range(num_gpus)])
+            for idx, ls in enumerate(output_ts[i]):
+                gpu_idx = local_device_ids[idx]
+                for _ in range(self.world_size):
+                    ls.append(torch.tensor([-1]).cuda(gpu_idx))
+
+        expected = [[torch.tensor([rank]) for rank in range(self.world_size)]]
+        for i in range(stress_length):
+            for rank in range(self.world_size):
+                gather(output_ts[i], tensors[i], rank)
+                # Verification
+                if rank == self.rank:
+                    self.assertEqual(output_ts[i], expected)
+
+    @requires_nccl()
+    @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
+    def test_gather_checks(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        pg = self._create_process_group_nccl(store, self.opts())
+        local_device_ids = self.rank_to_GPU[self.rank]
+        num_gpus = len(local_device_ids)
+
+        # init input
+        tensors = []
+        for device_id in local_device_ids:
+            tensors.append(torch.tensor([self.rank]).cuda(device_id))
+
+        # init output
+        output_ts = []
+        for idx in range(num_gpus):
+            gpu_idx = local_device_ids[idx]
+            output_ts.append([])
+            for rank in range(self.world_size):
+                output_ts[idx].append(torch.tensor([-1]).cuda(gpu_idx))
+
+        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
+            opts = c10d.GatherOptions()
+            opts.rootRank = -1
+            pg.gather(output_ts, tensors, opts)
+
+        with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+            pg.gather(output_ts, tensors, 0)
+
+        with self.assertRaisesRegex(RuntimeError, "invalid root rank"):
+            opts = c10d.GatherOptions()
+            opts.rootRank = self.world_size
+            pg.gather(output_ts, tensors, opts)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Tensor list must be nonempty"
+        ):
+            opts = c10d.GatherOptions()
+            opts.rootRank = 0
+            pg.gather(output_ts, [], opts)
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Tensors must be on distinct GPU devices"
+        ):
+            # init input
+            tensors2 = []
+            for device_id in local_device_ids:
+                tensors2.append(torch.tensor([self.rank]).cuda(device_id))
+                tensors2.append(torch.tensor([self.rank]).cuda(device_id))
+
+            opts = c10d.GatherOptions()
+            opts.rootRank = 0
+            pg.gather(output_ts, tensors2, opts)
+
+
     @requires_nccl()
     @sandcastle_skip_if(torch.cuda.device_count() < 2, "NCCL test requires 2+ GPUs")
     def test_reduce_scatter_base_basics(self):

torch/csrc/cuda/nccl.cpp

Lines changed: 48 additions & 0 deletions
@@ -809,6 +809,54 @@ void recv(
 #endif
 }
 
+
+void gather(
+    const at::Tensor& inputs,
+    std::vector<at::Tensor>& outputs,
+    ncclComm_t _comm,
+    at::cuda::CUDAStream& stream,
+    int32_t root) {
+#ifdef USE_NCCL
+#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && (NCCL_MAJOR * 10 + NCCL_MINOR) >= 27
+  using namespace torch::cuda::nccl::detail;
+
+  auto comm = to_nccl_comm(_comm);
+  int numranks, cur_rank;
+  NCCL_CHECK(ncclCommCount(comm, &numranks));
+  NCCL_CHECK(ncclCommUserRank(comm, &cur_rank));
+
+  size_t count = inputs.numel();
+  auto type = to_nccl_data_type(inputs);
+  const auto* sendbuff = reinterpret_cast<char*>(inputs.data_ptr());
+
+  NCCL_CHECK(ncclGroupStart());
+
+  if (cur_rank == root)
+  {
+    for (int r = 0; r < numranks; r++)
+    {
+      if (r != root) {
+        auto* recvbuff = reinterpret_cast<char*>(outputs[r].data_ptr());
+        NCCL_CHECK(ncclRecv(recvbuff, count, type, r, comm, stream));
+      } else {
+        // on its own rank, simply copy from the input
+        outputs[r].copy_(inputs);
+      }
+    }
+  } else {
+    NCCL_CHECK(ncclSend(sendbuff, count, type, root, comm, stream));
+  }
+  NCCL_CHECK(ncclGroupEnd());
+
+#else
+  AT_ERROR("gather is only supported for NCCL lib version >= 2.7.0");
+#endif
+#else
+  AT_ERROR("PyTorch built without NCCL support");
+#endif
+}
+
+
 } // namespace nccl
 } // namespace cuda
 } // namespace torch

torch/csrc/cuda/nccl.h

Lines changed: 7 additions & 0 deletions
@@ -150,6 +150,13 @@ TORCH_CUDA_CPP_API void all_gather(
     const stream_list& streams = {},
     const comm_list& user_comms = {});
 
+TORCH_CUDA_CPP_API void gather(
+    const at::Tensor& inputs,
+    std::vector<at::Tensor>& outputs,
+    ncclComm_t comm,
+    at::cuda::CUDAStream& stream,
+    int32_t root = 0);
+
 TORCH_CUDA_CPP_API void all2all_single_equal_split(
     at::Tensor& input,
     at::Tensor& output,

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 70 additions & 4 deletions
@@ -2213,10 +2213,76 @@ void ProcessGroupNCCL::groupEnd() {
 }
 
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::gather(
-    std::vector<std::vector<at::Tensor>>& /* unused */,
-    std::vector<at::Tensor>& /* unused */,
-    const GatherOptions& /* unused */) {
-  TORCH_CHECK(false, "ProcessGroupNCCL does not support gather");
+    std::vector<std::vector<at::Tensor>>& outputTensors,
+    std::vector<at::Tensor>& inputTensors,
+    const GatherOptions& opts) {
+  static auto invalidArgument = [](const std::string& msg) {
+    TORCH_CHECK(false, "ProcessGroupNCCL::gather: " + msg);
+  };
+
+  assertRootRank(invalidArgument, opts.rootRank, size_);
+  check_gpu_tensors_different_devices(inputTensors);
+  assertSingleElementInput(invalidArgument, inputTensors);
+
+  // @lint-ignore CLANGTIDY
+  auto tensor = inputTensors.back();
+  RECORD_PARAM_COMMS(
+      rank_, // rank
+      "gather", // colName
+      tensor.numel(), // inSize
+      tensor.numel() *
+          this->getSize(), // outSize
+      tensor.scalar_type(), // dType
+      std::vector<int64_t>(), // inSplitSizes
+      std::vector<int64_t>()); // outSplitSize
+
+  std::vector<at::Tensor> outputs;
+
+  if (getRank() == opts.rootRank) {
+    if (outputTensors.size() != 1) {
+      std::stringstream ss;
+      ss << "requires a single-element output list containing a list with "
+         << getSize() << " tensors.";
+      invalidArgument(ss.str());
+    } else if (outputTensors[0].size() != static_cast<size_t>(getSize())) {
+      std::stringstream ss;
+      ss << "Incorrect output list size " << outputTensors[0].size()
+         << ". Output list size should be " << getSize()
+         << ", same as size of the process group.";
+      invalidArgument(ss.str());
+    }
+
+    const auto& options = inputTensors[0].options();
+    const auto& sizes = inputTensors[0].sizes();
+    assertTypeAndSizesMatch(invalidArgument, outputTensors[0], options, sizes);
+    outputs = outputTensors[0];
+  } else {
+    // if not in the root rank, initialize outputs as empty list
+    if (outputTensors.size() != 0) {
+      invalidArgument("requires empty output on non-root");
+    }
+    outputs = {};
+  }
+
+  return collective(
+      inputTensors,
+      outputs,
+      [&](at::Tensor& /* unused */,
+          at::Tensor& /* unused */,
+          ncclComm_t comm,
+          at::cuda::CUDAStream& stream) {
+        const auto root = opts.rootRank;
+        if (getRank() == root) {
+          for(auto output: outputs) {
+            c10::cuda::CUDACachingAllocator::recordStream(
+                output.storage().data_ptr(), stream);
+          }
+        }
+        torch::cuda::nccl::gather(inputTensors[0], outputs, comm, stream, root);
+        return ncclSuccess;
+      },
+      OpType::GATHER,
+      "nccl:gather");
 }
 
 c10::intrusive_ptr<ProcessGroup::Work> ProcessGroupNCCL::scatter(

torch/testing/_internal/common_distributed.py

Lines changed: 0 additions & 1 deletion
@@ -69,7 +69,6 @@ class DistTestCases:
     # Backends that do not support a specific collective
     skip_collective = {}
     skip_collective["allgather_coalesced"] = {"nccl", "mpi"}
-    skip_collective["gather"] = {"nccl"}
    skip_collective["scatter"] = {"nccl"}
     skip_collective["reduce"] = set()
     skip_collective["sendrecv anysource"] = {"nccl"}

0 commit comments
