1010#include < torch/csrc/distributed/rpc/tensorpipe_utils.h>
1111#include < torch/csrc/distributed/rpc/utils.h>
1212
13+ #ifdef USE_CUDA_NOT_ROCM
14+ #include < ATen/cuda/CUDAMultiStreamGuard.h>
15+ #endif
16+
1317namespace torch {
1418namespace distributed {
1519namespace rpc {
@@ -201,6 +205,30 @@ C10_REGISTER_CREATOR(
201205
202206} // namespace
203207
208+ namespace {
209+
210+ // This is a wrapper of CUDAMultiStreamGuard to run in both CUDA-enabled and
211+ // CPU-only environments. When CUDA is not available, all methods are no-ops.
212+ struct MultiStreamGuard {
213+ MultiStreamGuard (const MultiStreamGuard& other) = delete ;
214+ MultiStreamGuard (MultiStreamGuard&& other) = delete ;
215+ MultiStreamGuard& operator =(const MultiStreamGuard& rhs) = delete ;
216+ MultiStreamGuard& operator =(MultiStreamGuard&& rhs) = delete ;
217+
218+ #ifndef USE_CUDA_NOT_ROCM
219+ explicit MultiStreamGuard (
220+ const std::shared_ptr<LazyStreamContext>& /* unused */ ) {}
221+ #else
222+ explicit MultiStreamGuard (const std::shared_ptr<LazyStreamContext>& ctx)
223+ : guard(ctx->getReservedStreams ()) {}
224+
225+ private:
226+ at::cuda::CUDAMultiStreamGuard guard;
227+ #endif
228+ };
229+
230+ } // namespace
231+
204232// //////////////////////// MetricsTracker /////////////////////////////////
205233
206234TensorPipeAgent::TimeSeriesMetricsTracker::TimeSeriesMetricsTracker (
@@ -412,26 +440,31 @@ void TensorPipeAgent::onListenerAccepted(
412440
// Reads a single incoming RPC message from `pipe` and delivers it to `fn`.
//
// The read follows tensorpipe's two-step protocol:
//  1. readDescriptor() reports the layout of the incoming message, so target
//     buffers can be allocated (tensorpipeAllocate). A fresh stream context
//     from createLazyStreamContext() is threaded through allocation and later
//     handed to `fn` — presumably so the receiver can synchronize with any
//     device streams used for the transfer (semantics of LazyStreamContext
//     are defined elsewhere; confirm against tensorpipe_utils).
//  2. read() fills those buffers; on completion the message is deserialized
//     and forwarded to `fn` together with the stream context.
//
// On error in either step, `fn` is invoked with the error, an empty Message,
// and a null stream context. `fn` is invoked exactly once per call.
// The pipe is captured by the callbacks to keep it alive for the duration
// of the read; the buffers are kept alive via a shared_ptr captured by the
// inner read() callback.
void TensorPipeAgent::pipeRead(
    const std::shared_ptr<tensorpipe::Pipe>& pipe,
    std::function<void(
        const tensorpipe::Error&,
        Message&&,
        std::shared_ptr<LazyStreamContext>)> fn) noexcept {
  pipe->readDescriptor([fn{std::move(fn)}, pipe](
                           const tensorpipe::Error& error,
                           tensorpipe::Message tpMessage) mutable {
    if (error) {
      fn(error, Message(), nullptr);
      return;
    }

    auto ctx = createLazyStreamContext();
    TensorpipeReadBuffers tpBuffers = tensorpipeAllocate(tpMessage, ctx);

    pipe->read(
        std::move(tpMessage),
        [tpBuffers{
             std::make_shared<TensorpipeReadBuffers>(std::move(tpBuffers))},
         fn{std::move(fn)},
         ctx{std::move(ctx)}](
            const tensorpipe::Error& error,
            tensorpipe::Message tpMessage) mutable {
          if (error) {
            fn(error, Message(), nullptr);
            return;
          }

          // FIXME This does some unpickling, which could be a bit expensive:
          // perhaps it would be best to perform it in a worker thread?
          Message rpcMessage = tensorpipeDeserialize(
              std::move(tpMessage), std::move(*tpBuffers));

          fn(error, std::move(rpcMessage), std::move(ctx));
        });
  });
}
@@ -449,18 +482,20 @@ void TensorPipeAgent::pipeWrite(
449482 const std::shared_ptr<tensorpipe::Pipe>& pipe,
450483 Message&& rpcMessage,
451484 std::vector<c10::DeviceIndex>&& devices,
485+ std::shared_ptr<LazyStreamContext> ctx,
452486 std::function<void (const tensorpipe::Error&)> fn) noexcept {
453487 tensorpipe::Message tpMessage;
454488 TensorpipeWriteBuffers tpBuffers;
455489
456490 std::tie (tpMessage, tpBuffers) =
457- tensorpipeSerialize (std::move (rpcMessage), std::move (devices));
491+ tensorpipeSerialize (std::move (rpcMessage), std::move (devices), ctx );
458492
459493 pipe->write (
460494 std::move (tpMessage),
461495 [tpBuffers{
462496 std::make_shared<TensorpipeWriteBuffers>(std::move (tpBuffers))},
463- fn{std::move (fn)}](
497+ fn{std::move (fn)},
498+ ctx{std::move (ctx)}](
464499 const tensorpipe::Error& error, tensorpipe::Message /* unused */ ) {
465500 fn (error);
466501 });
@@ -469,7 +504,8 @@ void TensorPipeAgent::pipeWrite(
469504void TensorPipeAgent::sendCompletedResponseMessage (
470505 std::shared_ptr<tensorpipe::Pipe>& pipe,
471506 std::shared_ptr<JitFuture>& futureResponseMessage,
472- uint64_t messageId) {
507+ uint64_t messageId,
508+ std::shared_ptr<LazyStreamContext> ctx) {
473509 if (!rpcAgentRunning_.load ()) {
474510 LOG (WARNING) << " RPC agent for " << workerInfo_.name_
475511 << " won't send response to request #" << messageId << " to "
@@ -496,6 +532,7 @@ void TensorPipeAgent::sendCompletedResponseMessage(
496532 pipe,
497533 std::move (responseMessage),
498534 std::move (devices),
535+ std::move (ctx),
499536 [this , pipe, messageId](const tensorpipe::Error& error) {
500537 if (error) {
501538 LOG (WARNING)
@@ -515,7 +552,8 @@ void TensorPipeAgent::sendCompletedResponseMessage(
515552 pipe,
516553 createExceptionResponse (
517554 futureResponseMessage->tryRetrieveErrorMessage (), messageId),
518- {},
555+ /* devices */ {},
556+ std::move (ctx),
519557 [this , pipe, messageId](const tensorpipe::Error& error) {
520558 if (error) {
521559 LOG (WARNING)
@@ -537,7 +575,9 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
537575 pipeRead (
538576 pipe,
539577 [this , pipe](
540- const tensorpipe::Error& error, Message&& requestMessage) mutable {
578+ const tensorpipe::Error& error,
579+ Message&& requestMessage,
580+ std::shared_ptr<LazyStreamContext> ctx) mutable {
541581 if (error) {
542582 // FIXME This is not a correct way to check whether this error was
543583 // "intentionally" caused by the remote end shutting down. We should
@@ -570,7 +610,10 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
570610 threadPool_.run ([this ,
571611 pipe,
572612 messageId,
573- requestMessage{std::move (requestMessage)}]() mutable {
613+ requestMessage{std::move (requestMessage)},
614+ ctx{std::move (ctx)}]() mutable {
615+ // create guards again as this function runs on a different thread
616+ MultiStreamGuard guard (ctx);
574617 VLOG (1 ) << " RPC agent for " << workerInfo_.name_
575618 << " is running request #" << messageId << " from "
576619 << pipe->getRemoteName () << " in thread pool" ;
@@ -588,17 +631,20 @@ void TensorPipeAgent::respond(std::shared_ptr<tensorpipe::Pipe>& pipe) {
588631 if (futureResponseMessage->completed ()) {
589632 decreaseCallCount (serverActiveCalls_);
590633 sendCompletedResponseMessage (
591- pipe, futureResponseMessage, messageId);
634+ pipe, futureResponseMessage, messageId, std::move (ctx) );
592635 } else {
593636 // Not complete yet
594637 increaseCallCount (serverActiveAsyncCalls_);
595- futureResponseMessage->addCallback (
596- [this , pipe, futureResponseMessage, messageId]() mutable {
597- decreaseCallCount (serverActiveCalls_);
598- decreaseCallCount (serverActiveAsyncCalls_);
599- sendCompletedResponseMessage (
600- pipe, futureResponseMessage, messageId);
601- });
638+ futureResponseMessage->addCallback ([this ,
639+ pipe,
640+ futureResponseMessage,
641+ messageId,
642+ ctx{std::move (ctx)}]() mutable {
643+ decreaseCallCount (serverActiveCalls_);
644+ decreaseCallCount (serverActiveAsyncCalls_);
645+ sendCompletedResponseMessage (
646+ pipe, futureResponseMessage, messageId, std::move (ctx));
647+ });
602648 }
603649
604650 VLOG (1 ) << " RPC agent for " << workerInfo_.name_
@@ -641,7 +687,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
641687 ClientPipe& clientPipe = it->second ;
642688 auto & pendingResponseMessage = clientPipe.pendingResponseMessage_ ;
643689
644- auto futureResponseMessage = std::make_shared<AtomicJitFuture>();
690+ auto futureResponseMessage = std::make_shared<AtomicJitFuture>(
691+ reverseDeviceMaps_.empty () && opts_.deviceMaps .empty ());
645692 uint64_t messageId = nextMessageID_++;
646693 requestMessage.setId (messageId);
647694 pendingResponseMessage[messageId] = futureResponseMessage;
@@ -686,10 +733,13 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
686733 VLOG (1 ) << " RPC agent for " << workerInfo_.name_ << " is sending request #"
687734 << messageId << " to " << clientPipe.pipe_ ->getRemoteName ();
688735
736+ auto ctx = createLazyStreamContext ();
737+ ctx->waitForCurrentStreams (requestMessage.tensors ());
689738 pipeWrite (
690739 clientPipe.pipe_ ,
691740 std::move (requestMessage),
692741 std::move (devices),
742+ std::move (ctx),
693743 [this , &clientPipe, messageId](const tensorpipe::Error& error) mutable {
694744 if (error) {
695745 if (error.isOfType <tensorpipe::PipeClosedError>() &&
@@ -716,7 +766,9 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
716766 pipeRead (
717767 clientPipe.pipe_ ,
718768 [this , &clientPipe](
719- const tensorpipe::Error& error, Message&& responseMessage) {
769+ const tensorpipe::Error& error,
770+ Message&& responseMessage,
771+ std::shared_ptr<LazyStreamContext> ctx) {
720772 if (error) {
721773 if (error.isOfType <tensorpipe::PipeClosedError>() &&
722774 !rpcAgentRunning_.load ()) {
@@ -777,7 +829,8 @@ std::shared_ptr<JitFuture> TensorPipeAgent::send(
777829 } else {
778830 markFutureAsComplete (
779831 std::move (futureResponseMessage),
780- std::move (responseMessage));
832+ std::move (responseMessage),
833+ std::move (ctx));
781834 }
782835 });
783836 });
@@ -1029,14 +1082,17 @@ void TensorPipeAgent::decreaseCallCount(int32_t& count) {
10291082
10301083void TensorPipeAgent::markFutureAsComplete (
10311084 std::shared_ptr<AtomicJitFuture> atomicFuture,
1032- Message message) {
1085+ Message message,
1086+ std::shared_ptr<LazyStreamContext> ctx) {
10331087 if (!atomicFuture->isComplete .test_and_set ()) {
10341088 // Completing the future will run its callbacks, which could execute
10351089 // arbitrary user code. To prevent blocking or stalling the TensorPipe event
10361090 // loops, we defer this to a worker thread.
10371091 threadPool_.run ([this ,
10381092 atomicFuture{std::move (atomicFuture)},
1039- message{std::move (message)}]() mutable {
1093+ message{std::move (message)},
1094+ ctx{std::move (ctx)}]() mutable {
1095+ MultiStreamGuard guard (ctx);
10401096 atomicFuture->jitFuture ->markCompleted (
10411097 IValue (c10::make_intrusive<Message>(std::move (message))));
10421098 // The future's callbacks may schedule further RPCs, increasing the count.
@@ -1096,6 +1152,7 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
10961152 std::vector<c10::DeviceIndex> deviceIndices;
10971153 deviceIndices.reserve (message.tensors ().size ());
10981154 const auto & deviceMap = iter->second ;
1155+ bool hasCudaTensor = false ;
10991156 for (const auto & t : message.tensors ()) {
11001157 if (t.device ().is_cpu ()) {
11011158 deviceIndices.push_back (-1 );
@@ -1108,8 +1165,12 @@ std::vector<c10::DeviceIndex> TensorPipeAgent::getDevicesForTensors(
11081165 t.device (),
11091166 " but received a tensor on that device." );
11101167 deviceIndices.push_back (deviceIter->second );
1168+ hasCudaTensor = true ;
11111169 }
11121170 }
1171+ if (!hasCudaTensor) {
1172+ deviceIndices.clear ();
1173+ }
11131174 return deviceIndices;
11141175 }
11151176}
0 commit comments