Restore thread_local states in continuation thread on RPC servers #38512
mrshenli wants to merge 9 commits into gh/mrshenli/179/base from
Conversation
As we gradually make the RPC processing non-blocking on the server side, the processing of the same request can yield and resume on different threads. Hence, we need to populate thread_local states (e.g., the dist autograd context id) in the continuation thread. Fixes #38439
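To make the failure mode concrete, here is a minimal, self-contained model of the pattern this PR applies (`CtxGuard` and `currentCtxId` are illustrative stand-ins, not the actual PyTorch types): snapshot the thread_local value into the lambda capture when the continuation is created, then restore it via an RAII guard on whichever thread eventually runs the continuation.

```cpp
#include <cstdint>
#include <iostream>
#include <thread>

// Illustrative stand-in for the dist autograd context id TLS.
thread_local std::int64_t currentCtxId = -1;

// RAII guard modeled after DistAutogradContextGuard: sets the id and
// restores the previous value when the continuation finishes.
struct CtxGuard {
  explicit CtxGuard(std::int64_t id) : prev_(currentCtxId) {
    currentCtxId = id;
  }
  ~CtxGuard() { currentCtxId = prev_; }
  std::int64_t prev_;
};

int main() {
  currentCtxId = 42; // state set while processing the original request

  // Snapshot the TLS value into the lambda capture, as the PR does with
  // `ctxId = autogradContext->contextId()`.
  auto continuation = [ctxId = currentCtxId]() {
    CtxGuard guard(ctxId); // restore on whatever thread runs this
    std::cout << "continuation sees ctx id " << currentCtxId << "\n";
  };

  // Without the capture + guard, a continuation running on another
  // thread would see the default value (-1) instead of 42.
  std::thread t(continuation);
  t.join();
}
```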
```python
        return t3

    @dist_init
    def test_thread_local_context_id(self):
```
I confirm that this test fails without this fix.
```diff
-     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
+     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture),
+     threadLocalState = ThreadLocalState(),
+     ctxId = autogradContext->contextId()]() {
```
This is currently the only place on the server side where we have such a continuation, i.e., continuing processing with addCallback, correct?
There are more in other types of requests. But IIUC, this is the only place where we need to propagate the context id.
- In BACKWARD_AUTOGRAD_REQ, it does not need the context id when creating the PropagateGradientsResp obj.
- In SCRIPT_CALL and SCRIPT_REMOTE_CALL, the context id is fixed by #36395 ([DistAutograd x JIT] Capture global state, dist autograd current context id, before thread switching triggered by JIT future.wait()).
- PYTHON_CALL and PYTHON_REMOTE_CALL currently do not yield. But when we add async user functions, we will also need to propagate the TLS there.
- SCRIPT_RREF_FETCH_CALL and PYTHON_RREF_FETCH_CALL do not need the context id either.
```diff
      fromWorkerId,
-     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
+     weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture),
+     threadLocalState = ThreadLocalState(),
```
Do we need TLS state or just the dist autograd context id for now? (I am planning to eventually use TLS state for distributed profiler work, but curious if this is already needed now)
Also, could we potentially reuse the approach taken in record_function_ops.cpp?
This is basically the same here, but there we declare the tls_state outside of the lambda capture and std::move it into the capture. Not sure if there's a difference perf wise.
also cc @ilia-cher, if you have any comments on the usage here.
> Do we need TLS state or just the dist autograd context id for now? (I am planning to eventually use TLS state for distributed profiler work, but curious if this is already needed now)
For now we only need the context id, I think. @xush6528 also pointed out that we will need this for the profiler later, so I added ThreadLocalState as well. I think we could also remove it in this PR and leave it to the profiler-related PR?
> This is basically the same here, but there we declare the tls_state outside of the lambda capture and std::move it into the capture. Not sure if there's a difference perf wise.
Should be the same I guess, as both are rvalue references?
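For the record, a self-contained comparison of the two capture styles (the `State` type here is an illustrative stand-in for ThreadLocalState, not the real class): init-capture from a prvalue can skip the move entirely under C++17 guaranteed elision, while declare-then-std::move costs exactly one move construction.

```cpp
#include <utility>

// Illustrative stand-in for ThreadLocalState: a move-constructible
// snapshot type.
struct State {
  State() = default;
  State(State&&) = default;
};

int main() {
  // Style used in this PR: init-capture from a prvalue. Since C++17,
  // guaranteed elision constructs the capture member in place, so no
  // move happens at all.
  auto cb1 = [state = State()]() { (void)state; };

  // Style used in record_function_ops.cpp: declare outside, then
  // std::move into the capture. This costs exactly one move
  // construction.
  State tls;
  auto cb2 = [state = std::move(tls)]() { (void)state; };

  cb1();
  cb2();
}
```

So the two forms should indeed perform the same within one move construction, consistent with the intuition above.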
You can remove it. I think the profiler state is not needed because no more tensor operations happen in this callback.
The autograd context id is needed because getMessageWithAutograd(..) needs it.
We will need to restore both the profiler state and the autograd context id for the Python async user function continuation.
```python
rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))

with dist_autograd.context() as context_id:
    loss = rref.to_here().sum()
```
Let me write down a note for whoever, like me, is not super clear about FORWARD_AUTOGRAD_REQ.
The rref.to_here() here is a PYTHON_RREF_FETCH_CALL, which is always wrapped in a FORWARD_AUTOGRAD_REQ as long as it's called within an autograd context, because of forceGradRecording on this line:
pytorch/torch/csrc/distributed/rpc/rref_impl.cpp, lines 145 to 149 in dfcea82.
When the PYTHON_RREF_FETCH_RESPONSE is ready (as in this line), the response message is wrapped into a FORWARD_AUTOGRAD_RESP by this line:
pytorch/torch/csrc/distributed/rpc/request_callback_impl.cpp, lines 496 to 499 in dfcea82,
where the key call getMessageWithAutograd(...) adds a SendFunction to the current autograd context; the client, on receiving this FORWARD_AUTOGRAD_RESP, will then add a corresponding RecvFunction to its dist autograd context.
If the autograd context is not restored here, the SendFunction will be added to the wrong autograd context, or the call crashes because there is no active context. The subsequent sum() backward will then be propagated back to the server, but in the wrong autograd context.
Why can this test reproduce the thread switch?
By making the rpc.remote(..., _slow_add, ...) request slow, the rref.to_here() request is processed before the RRef value is set.
So the to_here() request callback will run on another thread. The thread that runs it is exactly the thread handling rpc.remote(...), because that thread marks the value as ready and is thus responsible for running the added callbacks.
Since rpc.remote(...) is called outside of a dist autograd context, getMessageWithAutograd(...) does not wrap the PYTHON_REMOTE_CALL into a FORWARD_AUTOGRAD_REQ, so no autograd context is available on this thread.
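To make the thread switch concrete, here is a minimal, self-contained model of the future semantics described above (a toy class, not the actual FutureMessage; the assumption, matching the explanation, is that callbacks queued before completion run inline on whichever thread calls markCompleted):

```cpp
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Toy future: callbacks added before completion are queued and later
// run on whichever thread calls markCompleted().
struct ToyFuture {
  void addCallback(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    if (completed_) { cb(); return; }
    cbs_.push_back(std::move(cb));
  }
  void markCompleted() {
    std::vector<std::function<void()>> cbs;
    {
      std::lock_guard<std::mutex> lock(mu_);
      completed_ = true;
      cbs.swap(cbs_);
    }
    for (auto& cb : cbs) cb(); // runs on the completing thread
  }
  std::mutex mu_;
  bool completed_ = false;
  std::vector<std::function<void()>> cbs_;
};

int main() {
  ToyFuture rrefValue;

  // "to_here()" arrives before the value is set, so its continuation is
  // queued...
  rrefValue.addCallback([] {
    std::cout << "to_here continuation ran on thread "
              << std::this_thread::get_id() << "\n";
  });

  // ...and later runs inline on the slow rpc.remote() worker thread,
  // which never set up a dist autograd context.
  std::thread remoteWorker([&] { rrefValue.markCompleted(); });
  remoteWorker.join();
}
```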
```cpp
  // thread_local states there.
  // TODO: Land on a general solution for RPC ThreadLocalState. See
  // https://github.com/pytorch/pytorch/issues/38510
  DistAutogradContextGuard ctxGuard(ctxId);
```
Shouldn't we do this in addCallback itself instead of fixing each callsite of addCallback?
Yes, eventually we should do that, I think. But for now we still have two Futures (utils and ivalue), and not every callback needs to capture ThreadLocalState, so it might be better to fix master first. Let's revisit this when we reach a consensus on how we should implement RPC ThreadLocalState.
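For reference, a hedged sketch of the eventual direction discussed here, folding TLS capture into addCallback itself so each callsite no longer has to do it by hand (the RpcThreadLocalState type and this API are hypothetical; the real design is tracked in #38510):

```cpp
#include <functional>
#include <utility>

// Hypothetical snapshot of whatever thread_local state RPC decides to
// propagate (dist autograd context id, profiler state, ...).
struct RpcThreadLocalState {
  RpcThreadLocalState() { /* capture TLS of the current thread here */ }
  void restoreAndRun(std::function<void()> cb) {
    // Install RAII guards for the captured state, then run the callback.
    cb();
  }
};

struct Future {
  // Every callsite gets TLS propagation for free: the state is captured
  // once, at registration time, and restored wherever the callback runs.
  void addCallback(std::function<void()> cb) {
    addCallbackRaw(
        [state = RpcThreadLocalState(), cb = std::move(cb)]() mutable {
          state.restoreAndRun(std::move(cb));
        });
  }
  void addCallbackRaw(std::function<void()> cb) {
    // Existing queueing logic elided in this sketch.
    (void)cb;
  }
};

int main() {
  Future fut;
  fut.addCallback([] { /* continuation body */ });
}
```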
If we are going to release v1.5.1, we should include this fix.
Restore thread_local states in continuation thread on RPC servers (#38512)
Summary: Pull Request resolved: #38512. As we gradually make the RPC processing non-blocking on the server side, the processing of the same request can yield and resume on different threads. Hence, we need to populate thread_local states (e.g., the dist autograd context id) in the continuation thread. Fixes #38439
Test Plan: Imported from OSS
Differential Revision: D21583642
Pulled By: mrshenli
fbshipit-source-id: a79bce1cb207fd11f1fa02b08465e49badda65fc
Stack from ghstack:
As we gradually make the RPC processing non-blocking on the server side,
the processing of the same request can yield and resume on different threads.
Hence, we need to populate thread_local states (e.g., the dist autograd
context id) in the continuation thread.
Fixes #38439
Differential Revision: D21583642