Future callbacks in RPC should capture and restore autograd context id #38439

@mrshenli

In early versions of RPC, each server-side thread blocked until its request was fully processed. Hence we use a guard to clear the thread_local autograd context.

ClearAutogradContextGuard guard;
processRpc(*rpc, messageType, id, retFuture);

However, since we are gradually making RPC non-blocking on the server side, the above guard is no longer correct. For example, in FORWARD_AUTOGRAD_REQ, the bottom half of the processing is done in a callback, which could run on a different thread. If it does run on a different thread, the autograd context id set on the original thread is no longer valid there.

autogradContainer.setCurrentContextId(autogradContext->contextId());
// Process the original RPC.
auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
// Make an overall future for the wrapped response.
auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
// Kick off processing for the nested future and get a Future<T> to the
// result.
processRpc(
    rpcWithAutograd.wrappedRpc(),
    wrappedMessageType,
    messageId,
    wrappedRpcResponseFuture);
auto fromWorkerId = rpcWithAutograd.fromWorkerId();
// The original future needs to be marked as completed when the wrapped
// one completes, with the autograd context information wrapped.
// Uses weak_ptr so we can std::move the value.
wrappedRpcResponseFuture->addCallback(
    [responseFuture,
     messageId,
     fromWorkerId,
     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
      auto wrappedRpcResponseFuture = weak.lock();
      TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture);
      if (wrappedRpcResponseFuture->hasError()) {
        // Propagate error to responseFuture if we had one.
        responseFuture->setError(
            wrappedRpcResponseFuture->error()->what());
      } else {
        auto msg = getMessageWithAutograd(
            fromWorkerId,
            std::move(*wrappedRpcResponseFuture).moveValue(),
            MessageType::FORWARD_AUTOGRAD_RESP);
        msg.setId(messageId);
        responseFuture->markCompleted(std::move(msg));
      }
    });
return;
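
The hazard can be shown in isolation with plain standard C++. This is only a minimal sketch and does not use the real RPC code: `currentContextId` is a hypothetical stand-in for the thread_local autograd context id, and `contextIdSeenByOtherThread` is a made-up helper. A value set on one thread is simply not visible to code that runs on another thread.

```cpp
#include <cstdint>
#include <thread>

// Hypothetical stand-in for the thread_local autograd context id.
// -1 means "no context set on this thread".
thread_local int64_t currentContextId = -1;

// Sets the id on the calling thread, then reads it from a second thread,
// the way a callback would if it ran on a different thread.
int64_t contextIdSeenByOtherThread(int64_t id) {
  currentContextId = id;
  int64_t seen = 0;
  // The worker has its own copy of the thread_local variable, still at
  // its initial value.
  std::thread worker([&seen] { seen = currentContextId; });
  worker.join();
  return seen;
}
```

Calling `contextIdSeenByOtherThread(42)` returns -1 rather than 42, while the calling thread still sees 42.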

This could be the cause of the recent flakiness in the distributed autograd tests.
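
One possible shape of the fix named in the title, sketched with standard C++ only: capture the context id when the callback is registered, and restore it with an RAII guard on whichever thread runs the callback. All names here (`currentContextId`, `ContextIdGuard`, `wrapWithContextId`, `contextIdSeenInWrappedCallback`) are hypothetical and are not the actual PyTorch API; the real change would go through the autograd container and the future's `addCallback`.

```cpp
#include <cstdint>
#include <functional>
#include <thread>
#include <utility>

// Hypothetical stand-in for the thread_local autograd context id.
thread_local int64_t currentContextId = -1;

// RAII guard (hypothetical): installs `id` on the current thread and puts
// the previous value back when it goes out of scope.
class ContextIdGuard {
 public:
  explicit ContextIdGuard(int64_t id) : prev_(currentContextId) {
    currentContextId = id;
  }
  ~ContextIdGuard() { currentContextId = prev_; }
  ContextIdGuard(const ContextIdGuard&) = delete;
  ContextIdGuard& operator=(const ContextIdGuard&) = delete;

 private:
  int64_t prev_;
};

// Captures the registering thread's context id now, and restores it around
// the callback on whichever thread eventually runs it.
std::function<void()> wrapWithContextId(std::function<void()> cb) {
  const int64_t captured = currentContextId;
  return [captured, cb = std::move(cb)]() {
    ContextIdGuard guard(captured);
    cb();
  };
}

// Registers a wrapped callback on the calling thread, runs it on a second
// thread, and reports the id the callback observed.
int64_t contextIdSeenInWrappedCallback(int64_t id) {
  currentContextId = id;
  int64_t seen = -1;
  auto cb = wrapWithContextId([&seen] { seen = currentContextId; });
  std::thread worker(cb);
  worker.join();
  return seen;
}
```

With the wrapper in place, `contextIdSeenInWrappedCallback(7)` returns 7 even though the callback body runs on a different thread. The guard also restores the worker thread's previous value afterwards, so a pooled thread does not leak the id into the next request it handles.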

cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar

Labels

high priority
module: rpc (Related to RPC, distributed autograd, RRef, and distributed optimizer)
triaged (This issue has been looked at a team member, and triaged and prioritized into an appropriate module)
