@@ -240,8 +240,9 @@ std::ostream& operator<<(
 ProcessGroupNCCL::WorkNCCL::WorkNCCL(
     const std::vector<at::Device>& devices,
     int rank,
-    OpType opType)
-    : Work(rank, opType),
+    OpType opType,
+    const char* profilingTitle)
+    : Work(rank, opType, profilingTitle),
       devices_(devices),
       workStartTime_(std::chrono::steady_clock::now()) {
   // Creates the CUDA event wrappers
@@ -986,8 +987,9 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
 std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
     std::vector<at::Device> devices,
     int rank,
-    OpType opType) {
-  return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType);
+    OpType opType,
+    const char* profilingTitle) {
+  return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType, profilingTitle);
 }
 
 std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
@@ -1031,7 +1033,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     Fn fn,
     PreProcess pre,
     PostProcess post,
-    OpType opType) {
+    OpType opType,
+    const char* profilingTitle) {
   const auto devices = getDeviceList(inputs);
   const auto key = getKeyFromDevices(devices);
   auto& ncclComms = getNCCLComm(key, devices, opType);
@@ -1040,13 +1043,25 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
   syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);
 
   // Work itself will create the CUDA events on all GPUs of tensors
-  auto work = initWork(devices, rank_, opType);
+  bool can_profile = outputs.size() == 1;
+  auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr);
 
   // Store references to outputs and futureNCCLCallbackStream to be used by
   // WorkNCCL::getFuture.
   work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
   work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_;
 
+  if (work->recordFunctionEndCallback_) {
+    // recordFunctionEndCallback_ is normally invoked in finish() by the base
+    // class, but since WorkNCCL never calls finish(), schedule the callback to
+    // run when the work is done instead.
+    // Note that when can_profile is false, profilingTitle is not passed down,
+    // so recordFunctionEndCallback_ is never set.
+    work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_));
+  }
+
+
+
   at::cuda::OptionalCUDAGuard gpuGuard;
 
   pre(ncclStreams_[key]);
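The hunk above moves the profiling end callback onto the work's future because WorkNCCL completes asynchronously and never goes through the base class's finish() path. A minimal standalone sketch of that pattern, assuming only what the hunk shows (plain C++, not PyTorch source; FakeWork, recordFunctionEndCallback, and completeFuture are illustrative stand-ins for WorkNCCL, recordFunctionEndCallback_, and the future becoming ready):

// Standalone sketch (not PyTorch code): the end callback is detached from the
// synchronous finish() path and re-attached so it fires when the async work
// completes.
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct FakeWork {
  // Mirrors recordFunctionEndCallback_: only set when a profiling title was given.
  std::function<void()> recordFunctionEndCallback;
  std::vector<std::function<void()>> futureCallbacks;

  void addFutureCallback(std::function<void()> cb) {
    futureCallbacks.push_back(std::move(cb));
  }

  // Stands in for the future becoming ready; finish() is never called.
  void completeFuture() {
    for (auto& cb : futureCallbacks) {
      cb();
    }
  }
};

int main() {
  FakeWork work;
  work.recordFunctionEndCallback = [] { std::cout << "profiler range closed\n"; };

  if (work.recordFunctionEndCallback) {
    // Same move-and-reattach step as in ProcessGroupNCCL::collective() above.
    work.addFutureCallback(std::move(work.recordFunctionEndCallback));
  }

  work.completeFuture(); // the profiling range closes when the async work completes
  return 0;
}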
@@ -1175,14 +1190,16 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
     std::vector<at::Tensor>& inputs,
     std::vector<at::Tensor>& outputs,
     Fn fn,
-    OpType opType) {
+    OpType opType,
+    const char* profilingTitle) {
   return collective(
       inputs,
       outputs,
       fn,
       [](std::vector<at::cuda::CUDAStream>&) {},
       [](std::vector<at::cuda::CUDAStream>&) {},
-      opType);
+      opType,
+      profilingTitle);
 }
 
 template <typename Fn>
@@ -1221,7 +1238,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
             comm,
             stream.stream());
       },
-      OpType::ALLREDUCE);
+      OpType::ALLREDUCE,
+      "nccl:all_reduce");
 }
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce_coalesced(
@@ -1252,7 +1270,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
             comm,
             stream.stream());
       },
-      OpType::BROADCAST);
+      OpType::BROADCAST,
+      "nccl:broadcast");
 }
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
@@ -1278,7 +1297,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
             comm,
             stream.stream());
       },
-      OpType::REDUCE);
+      OpType::REDUCE,
+      "nccl:reduce");
 }
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
@@ -1322,7 +1342,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
           }
         }
       },
-      OpType::ALLGATHER);
+      OpType::ALLGATHER,
+      "nccl:all_gather");
 }
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced(
@@ -1375,7 +1396,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
         }
       },
       [&](std::vector<at::cuda::CUDAStream>& ncclStreams) {},
-      OpType::REDUCE_SCATTER);
+      OpType::REDUCE_SCATTER,
+      "nccl:reduce_scatter");
 }
 
 std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
@@ -1448,7 +1470,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
               stream);
           return ncclSuccess;
         },
-        OpType::ALLTOALL_BASE);
+        OpType::ALLTOALL_BASE,
+        "nccl:all_to_all");
   } else {
     c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
     c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
@@ -1484,7 +1507,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
             comm,
             stream.stream());
         },
-        OpType::ALLTOALL_BASE);
+        OpType::ALLTOALL_BASE,
+        "nccl:all_to_all");
   }
 }
 
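Taken together, the hunks simply thread a const char* title from each public collective down to the Work base class, dropping it when a collective has more than one output. A compressed standalone model of that flow, with simplified signatures that are not the real PyTorch types:

// Standalone model of how the profiling title added in this diff travels:
// allreduce() -> collective() -> initWork() -> WorkNCCL() -> Work base class.
#include <cstddef>
#include <cstdio>
#include <memory>

struct Work {
  Work(int rank, const char* profilingTitle) {
    if (profilingTitle != nullptr) {
      std::printf("rank %d opens profiler range: %s\n", rank, profilingTitle);
    }
  }
};

std::shared_ptr<Work> initWork(int rank, const char* profilingTitle) {
  return std::make_shared<Work>(rank, profilingTitle);
}

std::shared_ptr<Work> collective(int rank, std::size_t numOutputs, const char* profilingTitle) {
  // Mirrors the can_profile guard: only single-output collectives are profiled.
  bool can_profile = numOutputs == 1;
  return initWork(rank, can_profile ? profilingTitle : nullptr);
}

int main() {
  collective(/*rank=*/0, /*numOutputs=*/1, "nccl:all_reduce"); // same literal as the allreduce hunk
  collective(/*rank=*/0, /*numOutputs=*/4, "nccl:all_gather"); // title dropped, no profiler range
  return 0;
}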