Skip to content

Commit 160db3d

Browse files
mrzzd authored and facebook-github-bot committed
Adding profiling capability to c++ ddp collective functions (#46471)
Summary: Pull Request resolved: #46471 ghstack-source-id: 116018837 Test Plan: Added unit tests: buck test mode/dev-nosan caffe2/test/distributed:distributed_gloo_fork buck test mode/dev-nosan caffe2/test/distributed:distributed_nccl_fork Reviewed By: rohan-varma Differential Revision: D23948397 fbshipit-source-id: 6d93a370aff26bf96c39e5d78a2492c5142a9156
1 parent 1aeefcd commit 160db3d

10 files changed

Lines changed: 252 additions & 66 deletions

File tree

torch/csrc/distributed/c10d/init.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1086,8 +1086,7 @@ that adds a prefix to each key inserted to the store.
10861086
&::c10d::ProcessGroup::Work::wait,
10871087
py::arg("timeout") = kNoTimeout,
10881088
py::call_guard<py::gil_scoped_release>())
1089-
.def(
1090-
"get_future",
1089+
.def("get_future",
10911090
[](::c10d::ProcessGroup::Work& work)
10921091
-> std::shared_ptr<jit::PythonFutureWrapper> {
10931092
return std::make_shared<jit::PythonFutureWrapper>(work.getFuture());

torch/lib/c10d/ProcessGroup.cpp

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
#include <c10d/ProcessGroup.hpp>
2+
#include <ATen/ThreadLocalState.h>
3+
24

35
#include <c10/util/Logging.h>
46

@@ -51,10 +53,20 @@ bool isP2POp(OpType opType) {
5153
opType == OpType::RECVANYSOURCE;
5254
}
5355

54-
ProcessGroup::Work::Work() : rank_(-1), opType_(OpType::UNKNOWN) {}
5556

56-
ProcessGroup::Work::Work(int rank, OpType opType)
57-
: rank_(rank), opType_(opType) {}
57+
ProcessGroup::Work::Work(int rank, OpType opType, const char* profilingTitle)
58+
: rank_(rank), opType_(opType) {
59+
if (profilingTitle != nullptr) {
60+
auto recordingFunction = std::make_shared<at::RecordFunction>(at::RecordScope::USER_SCOPE);
61+
if (recordingFunction->active) {
62+
recordingFunction->before(profilingTitle, {});
63+
std::function<void()> end_handler = [this, recordingFunction]() {
64+
recordingFunction->end();
65+
};
66+
recordFunctionEndCallback_ = at::wrapPropagateTLSState(end_handler);
67+
}
68+
}
69+
}
5870

5971
OpType ProcessGroup::Work::retrieveOpType() {
6072
return opType_;
@@ -123,6 +135,10 @@ void ProcessGroup::Work::finish(std::exception_ptr exception) {
123135
std::unique_lock<std::mutex> lock(mutex_);
124136
completed_ = true;
125137
exception_ = exception;
138+
if (recordFunctionEndCallback_) {
139+
recordFunctionEndCallback_();
140+
recordFunctionEndCallback_ = nullptr;
141+
}
126142
lock.unlock();
127143
cv_.notify_all();
128144
}
@@ -131,6 +147,10 @@ void ProcessGroup::Work::finishAndThrow(std::exception_ptr exception) {
131147
std::unique_lock<std::mutex> lock(mutex_);
132148
completed_ = true;
133149
exception_ = exception;
150+
if (recordFunctionEndCallback_) {
151+
recordFunctionEndCallback_();
152+
recordFunctionEndCallback_ = nullptr;
153+
}
134154
if (exception_) {
135155
std::rethrow_exception(exception_);
136156
}

torch/lib/c10d/ProcessGroup.hpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ class ProcessGroup {
7777
// this will be bound using pybind.
7878
class Work {
7979
public:
80-
Work();
81-
82-
Work(int rank, OpType opType);
80+
Work(int rank = -1, OpType opType = OpType::UNKNOWN, const char* profilingTitle = nullptr);
8381

8482
virtual ~Work();
8583

@@ -156,6 +154,10 @@ class ProcessGroup {
156154

157155
// Operation type that this work object refers to.
158156
OpType opType_;
157+
158+
// When profiling, the callback to record end of operation event. This
159+
// callback needs to be called when collective operation is complete.
160+
std::function<void()> recordFunctionEndCallback_;
159161
};
160162

161163
explicit ProcessGroup(int rank, int size);

torch/lib/c10d/ProcessGroupGloo.cpp

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,8 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork {
677677
int rootRank,
678678
int rootTensor,
679679
uint32_t tag)
680-
: context(context),
680+
: ProcessGroupGloo::AsyncWork("gloo:broadcast"),
681+
context(context),
681682
inputs(inputs),
682683
rootRank(rootRank),
683684
rootTensor(rootTensor),
@@ -823,7 +824,8 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork {
823824
std::vector<at::Tensor>& inputs,
824825
ReduceOp reduceOp,
825826
uint32_t tag)
826-
: context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {}
827+
: ProcessGroupGloo::AsyncWork("gloo:all_reduce"),
828+
context(context), inputs(inputs), reduceOp(reduceOp), tag(tag) {}
827829

828830
std::shared_ptr<gloo::Context> context;
829831
std::vector<at::Tensor> inputs;
@@ -1431,7 +1433,8 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork {
14311433
int rootTensor,
14321434
ReduceOp reduceOp,
14331435
uint32_t tag)
1434-
: context(context),
1436+
: ProcessGroupGloo::AsyncWork("gloo:reduce"),
1437+
context(context),
14351438
inputs(inputs),
14361439
rootRank(rootRank),
14371440
rootTensor(rootTensor),
@@ -1595,7 +1598,8 @@ class AsyncAllgatherWork : public ProcessGroupGloo::AsyncWork {
15951598
std::vector<std::vector<at::Tensor>>& outputs,
15961599
std::vector<at::Tensor>& inputs,
15971600
uint32_t tag)
1598-
: context(context), outputs(outputs), inputs(inputs), tag(tag) {}
1601+
: ProcessGroupGloo::AsyncWork("gloo:all_gather"),
1602+
context(context), outputs(outputs), inputs(inputs), tag(tag) {}
15991603

16001604
std::shared_ptr<gloo::Context> context;
16011605
std::vector<std::vector<at::Tensor>> outputs;
@@ -1792,7 +1796,8 @@ class AsyncAllgatherCoalescedWork : public ProcessGroupGloo::AsyncWork {
17921796
std::vector<std::vector<at::Tensor>>& output_lists,
17931797
std::vector<at::Tensor>& input_list,
17941798
uint32_t tag)
1795-
: context(context),
1799+
: ProcessGroupGloo::AsyncWork("gloo:all_gather"),
1800+
context(context),
17961801
output_lists(output_lists),
17971802
input_list(input_list),
17981803
tag(tag) {}
@@ -1921,7 +1926,8 @@ class AsyncGatherWork : public ProcessGroupGloo::AsyncWork {
19211926
std::vector<at::Tensor>& inputs,
19221927
int root,
19231928
uint32_t tag)
1924-
: context(context),
1929+
: ProcessGroupGloo::AsyncWork("gloo:gather"),
1930+
context(context),
19251931
outputs(outputs),
19261932
inputs(inputs),
19271933
root(root),
@@ -2125,7 +2131,8 @@ class AsyncScatterWork : public ProcessGroupGloo::AsyncWork {
21252131
std::vector<std::vector<at::Tensor>>& inputs,
21262132
int root,
21272133
uint32_t tag)
2128-
: context(context),
2134+
: ProcessGroupGloo::AsyncWork("gloo:scatter"),
2135+
context(context),
21292136
outputs(outputs),
21302137
inputs(inputs),
21312138
root(root),
@@ -2319,7 +2326,8 @@ class AsyncAlltoallWork : public ProcessGroupGloo::AsyncWork {
23192326
std::vector<int64_t>& outputCounts,
23202327
std::vector<int64_t>& inputCounts,
23212328
uint32_t tag)
2322-
: context(context),
2329+
: ProcessGroupGloo::AsyncWork("gloo:all_to_all"),
2330+
context(context),
23232331
outputTensor(outputTensor),
23242332
inputTensor(inputTensor),
23252333
outputCounts(std::move(outputCounts)),
@@ -2576,7 +2584,8 @@ class AsyncBarrierWork : public ProcessGroupGloo::AsyncWork {
25762584
const std::shared_ptr<gloo::Context>& context,
25772585
std::vector<std::weak_ptr<AsyncWork>> priorWork,
25782586
uint32_t tag)
2579-
: context(context), priorWork(std::move(priorWork)), tag(tag) {}
2587+
: ProcessGroupGloo::AsyncWork("gloo:barrier"),
2588+
context(context), priorWork(std::move(priorWork)), tag(tag) {}
25802589

25812590
std::shared_ptr<gloo::Context> context;
25822591
std::vector<std::weak_ptr<AsyncWork>> priorWork;

torch/lib/c10d/ProcessGroupGloo.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ class ProcessGroupGloo : public ProcessGroup {
6868
//
6969
class AsyncWork : public ProcessGroup::Work {
7070
public:
71+
AsyncWork(const char* profilingTitle = nullptr): ProcessGroup::Work(-1, OpType::UNKNOWN, profilingTitle) {}
72+
7173
static void execute(std::shared_ptr<AsyncWork> work) {
7274
std::exception_ptr eptr;
7375
try {

torch/lib/c10d/ProcessGroupNCCL.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,9 @@ std::ostream& operator<<(
240240
ProcessGroupNCCL::WorkNCCL::WorkNCCL(
241241
const std::vector<at::Device>& devices,
242242
int rank,
243-
OpType opType)
244-
: Work(rank, opType),
243+
OpType opType,
244+
const char* profilingTitle)
245+
: Work(rank, opType, profilingTitle),
245246
devices_(devices),
246247
workStartTime_(std::chrono::steady_clock::now()) {
247248
// Creates the CUDA event wrappers
@@ -986,8 +987,9 @@ std::vector<at::Tensor> flatten_for_scatter_gather(
986987
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> ProcessGroupNCCL::initWork(
987988
std::vector<at::Device> devices,
988989
int rank,
989-
OpType opType) {
990-
return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType);
990+
OpType opType,
991+
const char* profilingTitle) {
992+
return std::make_shared<ProcessGroupNCCL::WorkNCCL>(devices, rank, opType, profilingTitle);
991993
}
992994

993995
std::vector<at::Tensor> ProcessGroupNCCL::WorkNCCL::result() {
@@ -1031,7 +1033,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
10311033
Fn fn,
10321034
PreProcess pre,
10331035
PostProcess post,
1034-
OpType opType) {
1036+
OpType opType,
1037+
const char* profilingTitle) {
10351038
const auto devices = getDeviceList(inputs);
10361039
const auto key = getKeyFromDevices(devices);
10371040
auto& ncclComms = getNCCLComm(key, devices, opType);
@@ -1040,13 +1043,25 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
10401043
syncStreams(devices, ncclEvents_[key], ncclStreams_[key]);
10411044

10421045
// Work itself will create the CUDA events on all GPUs of tensors
1043-
auto work = initWork(devices, rank_, opType);
1046+
bool can_profile = outputs.size() == 1;
1047+
auto work = initWork(devices, rank_, opType, can_profile ? profilingTitle : nullptr);
10441048

10451049
// Store references to outputs and futureNCCLCallbackStream to be used by
10461050
// WorkNCCL::getFuture.
10471051
work->outputs_ = std::make_shared<std::vector<at::Tensor>>(outputs);
10481052
work->futureNCCLCallbackStreams_ = futureNCCLCallbackStreams_;
10491053

1054+
if (work->recordFunctionEndCallback_) {
1055+
// recordFunctionEndCallback_ is normally called in finish() function by
1056+
// base class, but since finish is not called by WorkNCCL, we schedule this
1057+
// function to be run when work is done.
1058+
// Note when can_profile is false, profilingTitle is not provided and so,
1059+
// recordFunctionEndCallback_ is not set.
1060+
work->getFuture()->addCallback(std::move(work->recordFunctionEndCallback_));
1061+
}
1062+
1063+
1064+
10501065
at::cuda::OptionalCUDAGuard gpuGuard;
10511066

10521067
pre(ncclStreams_[key]);
@@ -1175,14 +1190,16 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::collective(
11751190
std::vector<at::Tensor>& inputs,
11761191
std::vector<at::Tensor>& outputs,
11771192
Fn fn,
1178-
OpType opType) {
1193+
OpType opType,
1194+
const char* profilingTitle) {
11791195
return collective(
11801196
inputs,
11811197
outputs,
11821198
fn,
11831199
[](std::vector<at::cuda::CUDAStream>&) {},
11841200
[](std::vector<at::cuda::CUDAStream>&) {},
1185-
opType);
1201+
opType,
1202+
profilingTitle);
11861203
}
11871204

11881205
template <typename Fn>
@@ -1221,7 +1238,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce(
12211238
comm,
12221239
stream.stream());
12231240
},
1224-
OpType::ALLREDUCE);
1241+
OpType::ALLREDUCE,
1242+
"nccl:all_reduce");
12251243
}
12261244

12271245
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allreduce_coalesced(
@@ -1252,7 +1270,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::broadcast(
12521270
comm,
12531271
stream.stream());
12541272
},
1255-
OpType::BROADCAST);
1273+
OpType::BROADCAST,
1274+
"nccl:broadcast");
12561275
}
12571276

12581277
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
@@ -1278,7 +1297,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce(
12781297
comm,
12791298
stream.stream());
12801299
},
1281-
OpType::REDUCE);
1300+
OpType::REDUCE,
1301+
"nccl:reduce");
12821302
}
12831303

12841304
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
@@ -1322,7 +1342,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather(
13221342
}
13231343
}
13241344
},
1325-
OpType::ALLGATHER);
1345+
OpType::ALLGATHER,
1346+
"nccl:all_gather");
13261347
}
13271348

13281349
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::allgather_coalesced(
@@ -1375,7 +1396,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::reduce_scatter(
13751396
}
13761397
},
13771398
[&](std::vector<at::cuda::CUDAStream>& ncclStreams) {},
1378-
OpType::REDUCE_SCATTER);
1399+
OpType::REDUCE_SCATTER,
1400+
"nccl:reduce_scatter");
13791401
}
13801402

13811403
std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::barrier(
@@ -1448,7 +1470,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
14481470
stream);
14491471
return ncclSuccess;
14501472
},
1451-
OpType::ALLTOALL_BASE);
1473+
OpType::ALLTOALL_BASE,
1474+
"nccl:all_to_all");
14521475
} else {
14531476
c10d::checkSplitSizes(inputSplitSizes, inputTensor, size_);
14541477
c10d::checkSplitSizes(outputSplitSizes, outputTensor, size_);
@@ -1484,7 +1507,8 @@ std::shared_ptr<ProcessGroup::Work> ProcessGroupNCCL::alltoall_base(
14841507
comm,
14851508
stream.stream());
14861509
},
1487-
OpType::ALLTOALL_BASE);
1510+
OpType::ALLTOALL_BASE,
1511+
"nccl:all_to_all");
14881512
}
14891513
}
14901514

torch/lib/c10d/ProcessGroupNCCL.hpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ class ProcessGroupNCCL : public ProcessGroup {
6868
public std::enable_shared_from_this<WorkNCCL> {
6969
public:
7070
// Constructor takes a list of CUDA devices
71-
WorkNCCL(const std::vector<at::Device>& devices, int rank, OpType opType);
71+
WorkNCCL(const std::vector<at::Device>& devices, int rank, OpType opType, const char* profilingTitle = nullptr);
7272
// Copy constructor doing partial copy without outputs_. Cleanup thread
7373
// monitors and removes finished works. However it will deadlock when
7474
// destructs outputs_ tensors who are view tensors in autograd graph.
@@ -518,7 +518,8 @@ class ProcessGroupNCCL : public ProcessGroup {
518518
virtual std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
519519
std::vector<at::Device> devices,
520520
int rank,
521-
OpType opType);
521+
OpType opType,
522+
const char* profilingTitle=nullptr);
522523

523524
private:
524525
// Helper that encapsulates work shared across all collective communication
@@ -532,15 +533,17 @@ class ProcessGroupNCCL : public ProcessGroup {
532533
std::vector<at::Tensor>& input,
533534
std::vector<at::Tensor>& output,
534535
Fn fn,
535-
OpType opType);
536+
OpType opType,
537+
const char* profilingTitle = nullptr);
536538
template <typename Fn, typename PreProcess, typename PostProcess>
537539
std::shared_ptr<ProcessGroup::Work> collective(
538540
std::vector<at::Tensor>& input,
539541
std::vector<at::Tensor>& output,
540542
Fn fn,
541543
PreProcess pre,
542544
PostProcess post,
543-
OpType opType);
545+
OpType opType,
546+
const char* profilingTitle = nullptr);
544547

545548
// Helper that encapsulates work shared across point-to-point communication
546549
// primitives. It is the same structure as the helper used for collective

torch/lib/c10d/test/ProcessGroupNCCLErrorsTest.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ class ProcessGroupNCCLSimulateErrors : public c10d::ProcessGroupNCCL {
5959
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
6060
std::vector<at::Device> devices,
6161
int rank,
62-
c10d::OpType opType) override {
62+
c10d::OpType opType,
63+
const char* profilingTitle) override {
6364
return std::make_shared<WorkNCCLSimulateErrors>(
6465
devices, simulate_error_, rank, opType);
6566
}
@@ -115,7 +116,8 @@ class ProcessGroupNCCLTimedOutErrors : public ProcessGroupNCCLSimulateErrors {
115116
std::shared_ptr<ProcessGroupNCCL::WorkNCCL> initWork(
116117
std::vector<at::Device> devices,
117118
int rank,
118-
c10d::OpType opType) override {
119+
c10d::OpType opType,
120+
const char* profilingTitle) override {
119121
return std::make_shared<WorkNCCLTimedoutErrors>(
120122
devices, set_timedout_error_, rank, opType);
121123
}

0 commit comments

Comments (0)