In early versions of RPC, all threads on the server side block until the request is fully processed. Hence we use a guard to clear the thread_local autograd context.
From `pytorch/torch/csrc/distributed/rpc/request_callback_impl.cpp` (lines 563 to 564 in 3d02798):

```cpp
ClearAutogradContextGuard guard;
processRpc(*rpc, messageType, id, retFuture);
```
However, since we are gradually making the RPC server side non-blocking, the above guard is no longer correct. For example, for FORWARD_AUTOGRAD_REQ, the bottom half of the processing is done in a callback, which could run on a different thread. If it does run on a different thread, the autograd context id is no longer valid.
From `pytorch/torch/csrc/distributed/rpc/request_callback_impl.cpp` (lines 445 to 483 in 3d02798):

```cpp
autogradContainer.setCurrentContextId(autogradContext->contextId());

// Process the original RPC.
auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
// Make an overall future for the wrapped response.
auto wrappedRpcResponseFuture = std::make_shared<FutureMessage>();
// Kick off processing for the nested future and get a Future<T> to the
// result.
processRpc(
    rpcWithAutograd.wrappedRpc(),
    wrappedMessageType,
    messageId,
    wrappedRpcResponseFuture);

auto fromWorkerId = rpcWithAutograd.fromWorkerId();
// The original future needs to be marked as completed when the wrapped
// one completes, with the autograd context information wrapped.
// Uses weak_ptr so we can std::move the value.
wrappedRpcResponseFuture->addCallback(
    [responseFuture,
     messageId,
     fromWorkerId,
     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
      auto wrappedRpcResponseFuture = weak.lock();
      TORCH_INTERNAL_ASSERT(wrappedRpcResponseFuture);
      if (wrappedRpcResponseFuture->hasError()) {
        // Propagate error to responseFuture if we had one.
        responseFuture->setError(
            wrappedRpcResponseFuture->error()->what());
      } else {
        auto msg = getMessageWithAutograd(
            fromWorkerId,
            std::move(*wrappedRpcResponseFuture).moveValue(),
            MessageType::FORWARD_AUTOGRAD_RESP);
        msg.setId(messageId);
        responseFuture->markCompleted(std::move(msg));
      }
    });
return;
```
This could be the cause of the recent flakiness in the distributed autograd tests.
cc @ezyang @gchanan @zou3519 @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @gqchen @aazzolini @rohan-varma @xush6528 @jjlilley @osalpekar