Skip to content

Commit 120f934

Browse files
committed
Enable GPU-to-GPU comm in TensorPipeAgent
Pull Request resolved: #44418 This commit uses TensorPipe's cuda_ipc channel to conduct cross-process same-machine GPU-to-GPU communication. On the sender side, `TensorPipeAgent` grabs a stream to each device used by the message, lets these streams wait for the current streams, and passes the streams to TensorPipe `CudaBuffer`. On the receiver side, it also grabs a stream for each device used in the message, and uses these streams to receive tensors and run user functions. After that, these streams are then used for sending the response back to the sender. When receiving the response, the sender will grab a new set of streams and use them for TensorPipe's `CudaBuffer`. If device maps are provided, `TensorPipeAgent::send` will return a derived class of `CUDAFuture`, which is specifically tailored for RPC messages. TODOs: 1. Enable sending CUDA RPC to the same process. 2. Add a custom CUDA stream pool. 3. Once TensorPipe has addressed the error for `cudaPointerGetAttributes()`, remove the `cuda:0` context initialization code in `backend_registry.py`. 4. Once TensorPipe can detect availability of peer access, enable all tests on platforms without peer access. Differential Revision: [D23626207](https://our.internmc.facebook.com/intern/diff/D23626207/) **NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D23626207/)! ghstack-source-id: 119821241
1 parent 2a60314 commit 120f934

10 files changed

Lines changed: 699 additions & 93 deletions

File tree

aten/src/ATen/cuda/CUDAFuture.h

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
namespace at { namespace cuda {
2323

24-
struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
24+
struct TORCH_CUDA_API CUDAFuture : at::ivalue::Future {
2525
public:
2626
using at::ivalue::Future::Future;
2727

@@ -106,22 +106,7 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
106106
}
107107
}
108108

109-
private:
110-
// The device that was current when markCompleted was called, which we'll
111-
// restore when invoking callbacks.
112-
c10::DeviceIndex currentDevice_;
113-
114-
// The events that correspond to the completion of the async I/O kernels. They
115-
// are recorded on the appropriate streams when the future is marked completed
116-
// and can then be queried/waited/blocked on. There is one event for each
117-
// distinct device on which the value's tensors reside.
118-
std::vector<at::cuda::CUDAEvent> cudaEvents_;
119-
120-
// A cached version of the data ptrs extracted from the value when the future
121-
// is first marked completed.
122-
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;
123-
124-
std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
109+
virtual std::vector<std::reference_wrapper<const at::DataPtr>> extractDataPtrs(
125110
const at::IValue& value) {
126111
at::IValue::HashAliasedIValues sub_values;
127112
// Prefer getSubValues() over visit() as the latter is a silent no-op for
@@ -136,6 +121,21 @@ struct TORCH_CUDA_API CUDAFuture final : at::ivalue::Future {
136121
}
137122
return data_ptrs;
138123
}
124+
125+
private:
126+
// The device that was current when markCompleted was called, which we'll
127+
// restore when invoking callbacks.
128+
c10::DeviceIndex currentDevice_;
129+
130+
// The events that correspond to the completion of the async I/O kernels. They
131+
// are recorded on the appropriate streams when the future is marked completed
132+
// and can then be queried/waited/blocked on. There is one event for each
133+
// distinct device on which the value's tensors reside.
134+
std::vector<at::cuda::CUDAEvent> cudaEvents_;
135+
136+
// A cached version of the data ptrs extracted from the value when the future
137+
// is first marked completed.
138+
std::vector<std::reference_wrapper<const at::DataPtr>> dataPtrs_;
139139
};
140140

141141
} // namespace cuda

caffe2/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
302302
add_dependencies(process_group_agent torch c10d)
303303

304304
add_library(tensorpipe_agent
305+
"${TORCH_SRC_DIR}/csrc/distributed/rpc/macros.h"
305306
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.cpp"
306307
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_agent.h"
307308
"${TORCH_SRC_DIR}/csrc/distributed/rpc/tensorpipe_utils.cpp"
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
#pragma once

// Feature-test macro for "real" CUDA builds: defined only when the build has
// CUDA support (USE_CUDA) and is NOT targeting the ROCm/HIP toolchain
// (__HIP_PLATFORM_HCC__ is HIP's platform marker). Call sites can then write
// a single #ifdef USE_CUDA_NOT_ROCM instead of repeating the compound test.
#ifdef USE_CUDA
#ifndef __HIP_PLATFORM_HCC__
#define USE_CUDA_NOT_ROCM
#endif
#endif

torch/csrc/distributed/rpc/tensorpipe_agent.cpp

Lines changed: 86 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
#include <torch/csrc/distributed/rpc/tensorpipe_utils.h>
1111
#include <torch/csrc/distributed/rpc/utils.h>
1212

13+
#ifdef USE_CUDA_NOT_ROCM
14+
#include <ATen/cuda/CUDAMultiStreamGuard.h>
15+
#endif
16+
1317
namespace torch {
1418
namespace distributed {
1519
namespace rpc {
@@ -201,6 +205,30 @@ C10_REGISTER_CREATOR(
201205

202206
} // namespace
203207

208+
namespace {
209+
210+
// This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and
211+
// CPU-only environments. When CUDA is not available, all methods are no-ops.
212+
struct MultiStreamGuard {
213+
MultiStreamGuard(const MultiStreamGuard& other) = delete;
214+
MultiStreamGuard(MultiStreamGuard&& other) = delete;
215+
MultiStreamGuard& operator=(const MultiStreamGuard& rhs) = delete;
216+
MultiStreamGuard& operator=(MultiStreamGuard&& rhs) = delete;
217+
218+
#ifndef USE_CUDA_NOT_ROCM
219+
explicit MultiStreamGuard(
220+
const std::shared_ptr<LazyStreamContext>& /* unused */) {}
221+
#else
222+
explicit MultiStreamGuard(const std::shared_ptr<LazyStreamContext>& ctx)
223+
: guard(ctx->getReservedStreams()) {}
224+
225+
private:
226+
at::cuda::CUDAMultiStreamGuard guard;
227+
#endif
228+
};
229+
230+
} // namespace
231+
204232
////////////////////////// MetricsTracker /////////////////////////////////
205233

206234
TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker(
@@ -412,26 +440,31 @@ void TensorPipeAgent::onListenerAccepted(
412440

413441
void TensorPipeAgent::pipeRead(
414442
const std::shared_ptr<tensorpipe::Pipe>& pipe,
415-
std::function<void(const tensorpipe::Error&, Message&&)> fn) noexcept {
443+
std::function<void(
444+
const tensorpipe::Error&,
445+
Message&&,
446+
std::shared_ptr<LazyStreamContext>)> fn) noexcept {
416447
pipe->readDescriptor([fn{std::move(fn)}, pipe](
417448
const tensorpipe::Error& error,
418449
tensorpipe::Message tpMessage) mutable {
419450
if (error) {
420-
fn(error, Message());
451+
fn(error, Message(), nullptr);
421452
return;
422453
}
423454

424-
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage);
455+
auto ctx = createLazyStreamContext();
456+
TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);
425457

426458
pipe->read(
427459
std::move(tpMessage),
428460
[tpBuffers{
429461
std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
430-
fn{std::move(fn)}](
462+
fn{std::move(fn)},
463+
ctx{std::move(ctx)}](
431464
const tensorpipe::Error& error,
432465
tensorpipe::Message tpMessage) mutable {
433466
if (error) {
434-
fn(error, Message());
467+
fn(error, Message(), nullptr);
435468
return;
436469
}
437470

@@ -440,7 +473,7 @@ void TensorPipeAgent::pipeRead(
440473
Message rpcMessage = tensorpipeDeserialize(
441474
std::move(tpMessage), std::move(*tpBuffers));
442475

443-
fn(error, std::move(rpcMessage));
476+
fn(error, std::move(rpcMessage), std::move(ctx));
444477
});
445478
});
446479
}
@@ -449,18 +482,20 @@ void TensorPipeAgent::pipeWrite(
449482
const std::shared_ptr<tensorpipe::Pipe>& pipe,
450483
Message&& rpcMessage,
451484
std::vector<c10::DeviceIndex>&& devices,
485+
std::shared_ptr<LazyStreamContext> ctx,
452486
std::function<void(const tensorpipe::Error&)> fn) noexcept {
453487
tensorpipe::Message tpMessage;
454488
TensorpipeWriteBuffers tpBuffers;
455489

456490
std::tie(tpMessage, tpBuffers) =
457-
tensorpipeSerialize(std::move(rpcMessage), std::move(devices));
491+
tensorpipeSerialize(std::move(rpcMessage), std::move(devices), ctx);
458492

459493
pipe->write(
460494
std::move(tpMessage),
461495
[tpBuffers{
462496
std::make_shared<TensorpipeWriteBuffers>(std::move(tpBuffers))},
463-
fn{std::move(fn)}](
497+
fn{std::move(fn)},
498+
ctx{std::move(ctx)}](
464499
const tensorpipe::Error& error, tensorpipe::Message /* unused */) {
465500
fn(error);
466501
});
@@ -469,7 +504,8 @@ void TensorPipeAgent::pipeWrite(
469504
void TensorPipeAgent::sendCompletedResponseMessage(
470505
std::shared_ptr<tensorpipe::Pipe>& pipe,
471506
std::shared_ptr<JitFuture>& futureResponseMessage,
472-
uint64_t messageId) {
507+
uint64_t messageId,
508+
std::shared_ptr<LazyStreamContext> ctx) {
473509
if (!rpcAgentRunning_.load()) {
474510
LOG(WARNING) << "RPC agent for " << workerInfo_.name_
475511
<< " won't send response to request #" << messageId << " to "
@@ -496,6 +532,7 @@ void TensorPipeAgent::sendCompletedResponseMessage(
496532
pipe,
497533
std::move(responseMessage),
498534
std::move(devices),
535+
std::move(ctx),
499536
[this, pipe, messageId](const tensorpipe::Error& error) {
500537
if (error) {
501538
LOG(WARNING)
@@ -515,7 +552,8 @@ void TensorPipeAgent::sendCompletedResponseMessage(
515552
pipe,
516553
createExceptionResponse(
517554
futureResponseMessage->tryRetrieveErrorMessage(), messageId),
518-
{},
555+
/* devices */ {},
556+
std::move(ctx),
519557
[this, pipe, messageId](const tensorpipe::Error& error) {
520558
if (error) {
521559
LOG(WARNING)
@@ -537,7 +575,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
537575
pipeRead(
538576
pipe,
539577
[this, pipe](
540-
const tensorpipe::Error& error, Message&& requestMessage) mutable {
578+
const tensorpipe::Error& error,
579+
Message&& requestMessage,
580+
std::shared_ptr<LazyStreamContext> ctx) mutable {
541581
if (error) {
542582
// FIXME This is not a correct way to check whether this error was
543583
// "intentionally" caused by the remote end shutting down. We should
@@ -570,7 +610,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
570610
threadPool_.run([this,
571611
pipe,
572612
messageId,
573-
requestMessage{std::move(requestMessage)}]() mutable {
613+
requestMessage{std::move(requestMessage)},
614+
ctx{std::move(ctx)}]() mutable {
615+
// create guards again as this function runs on a different thread
616+
MultiStreamGuard guard(ctx);
574617
VLOG(1) << "RPC agent for " << workerInfo_.name_
575618
<< " is running request #" << messageId << " from "
576619
<< pipe->getRemoteName() << " in thread pool";
@@ -588,17 +631,20 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
588631
if (futureResponseMessage->completed()) {
589632
decreaseCallCount(serverActiveCalls_);
590633
sendCompletedResponseMessage(
591-
pipe, futureResponseMessage, messageId);
634+
pipe, futureResponseMessage, messageId, std::move(ctx));
592635
} else {
593636
// Not complete yet
594637
increaseCallCount(serverActiveAsyncCalls_);
595-
futureResponseMessage->addCallback(
596-
[this, pipe, futureResponseMessage, messageId]() mutable {
597-
decreaseCallCount(serverActiveCalls_);
598-
decreaseCallCount(serverActiveAsyncCalls_);
599-
sendCompletedResponseMessage(
600-
pipe, futureResponseMessage, messageId);
601-
});
638+
futureResponseMessage->addCallback([this,
639+
pipe,
640+
futureResponseMessage,
641+
messageId,
642+
ctx{std::move(ctx)}]() mutable {
643+
decreaseCallCount(serverActiveCalls_);
644+
decreaseCallCount(serverActiveAsyncCalls_);
645+
sendCompletedResponseMessage(
646+
pipe, futureResponseMessage, messageId, std::move(ctx));
647+
});
602648
}
603649

604650
VLOG(1) << "RPC agent for " << workerInfo_.name_
@@ -641,7 +687,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
641687
ClientPipe& clientPipe = it->second;
642688
auto& pendingResponseMessage = clientPipe.pendingResponseMessage_;
643689

644-
auto futureResponseMessage = std::make_shared<AtomicJitFuture>();
690+
auto futureResponseMessage = std::make_shared<AtomicJitFuture>(
691+
reverseDeviceMaps_.empty() && opts_.deviceMaps.empty());
645692
uint64_t messageId = nextMessageID_++;
646693
requestMessage.setId(messageId);
647694
pendingResponseMessage[messageId] = futureResponseMessage;
@@ -686,10 +733,13 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
686733
VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
687734
<< messageId << " to " << clientPipe.pipe_->getRemoteName();
688735

736+
auto ctx = createLazyStreamContext();
737+
ctx->waitForCurrentStreams(requestMessage.tensors());
689738
pipeWrite(
690739
clientPipe.pipe_,
691740
std::move(requestMessage),
692741
std::move(devices),
742+
std::move(ctx),
693743
[this, &clientPipe, messageId](const tensorpipe::Error& error) mutable {
694744
if (error) {
695745
if (error.isOfType<tensorpipe::PipeClosedError>() &&
@@ -716,7 +766,9 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
716766
pipeRead(
717767
clientPipe.pipe_,
718768
[this, &clientPipe](
719-
const tensorpipe::Error& error, Message&& responseMessage) {
769+
const tensorpipe::Error& error,
770+
Message&& responseMessage,
771+
std::shared_ptr<LazyStreamContext> ctx) {
720772
if (error) {
721773
if (error.isOfType<tensorpipe::PipeClosedError>() &&
722774
!rpcAgentRunning_.load()) {
@@ -777,7 +829,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
777829
} else {
778830
markFutureAsComplete(
779831
std::move(futureResponseMessage),
780-
std::move(responseMessage));
832+
std::move(responseMessage),
833+
std::move(ctx));
781834
}
782835
});
783836
});
@@ -1029,14 +1082,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) {
10291082

10301083
void TensorPipeAgent::markFutureAsComplete(
10311084
std::shared_ptr<AtomicJitFuture> atomicFuture,
1032-
Message message) {
1085+
Message message,
1086+
std::shared_ptr<LazyStreamContext> ctx) {
10331087
if (!atomicFuture->isComplete.test_and_set()) {
10341088
// Completing the future will run its callbacks, which could execute
10351089
// arbitrary user code. To prevent blocking or stalling the TensorPipe event
10361090
// loops, we defer this to a worker thread.
10371091
threadPool_.run([this,
10381092
atomicFuture{std::move(atomicFuture)},
1039-
message{std::move(message)}]() mutable {
1093+
message{std::move(message)},
1094+
ctx{std::move(ctx)}]() mutable {
1095+
MultiStreamGuard guard(ctx);
10401096
atomicFuture->jitFuture->markCompleted(
10411097
IValue(c10::make_intrusive<Message>(std::move(message))));
10421098
// The future's callbacks may schedule further RPCs, increasing the count.
@@ -1096,6 +1152,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
10961152
std::vector<c10::DeviceIndex> deviceIndices;
10971153
deviceIndices.reserve(message.tensors().size());
10981154
const auto& deviceMap = iter->second;
1155+
bool hasCudaTensor = false;
10991156
for (const auto& t : message.tensors()) {
11001157
if (t.device().is_cpu()) {
11011158
deviceIndices.push_back(-1);
@@ -1108,8 +1165,12 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
11081165
t.device(),
11091166
" but received a tensor on that device.");
11101167
deviceIndices.push_back(deviceIter->second);
1168+
hasCudaTensor = true;
11111169
}
11121170
}
1171+
if (!hasCudaTensor) {
1172+
deviceIndices.clear();
1173+
}
11131174
return deviceIndices;
11141175
}
11151176
}

0 commit comments

Comments
 (0)