
Restore thread_local states in continuation thread on RPC servers #38512

Closed
mrshenli wants to merge 9 commits into gh/mrshenli/179/base from gh/mrshenli/179/head

Conversation

mrshenli (Contributor) commented May 14, 2020

Stack from ghstack:

As we gradually make RPC processing non-blocking on the server side, the
processing of a single request can yield and resume on different threads.
Hence, we need to restore thread_local states (e.g., the dist autograd
context id) in the continuation thread.

Fixes #38439

Differential Revision: D21583642
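
To make the problem concrete, here is a minimal, self-contained sketch of the pattern this PR applies (the names `ctx_id` and `make_continuation` are hypothetical stand-ins, not the real torch classes): snapshot the thread_local value when the continuation is created, then re-install it on whichever thread later runs it.

```cpp
#include <functional>
#include <iostream>
#include <thread>

thread_local int ctx_id = -1;  // stand-in for the dist autograd context id

std::function<void()> make_continuation() {
  // Runs on the request thread: snapshot the current thread_local value
  // into the lambda via an init-capture.
  return [captured_ctx = ctx_id]() {
    int prev = ctx_id;
    ctx_id = captured_ctx;  // restore the snapshot on the continuation thread
    std::cout << "continuation sees ctx_id = " << ctx_id << "\n";
    ctx_id = prev;          // leave the thread as we found it
  };
}

int main() {
  ctx_id = 42;                      // the request thread sets its context id
  auto cont = make_continuation();  // the continuation captures it
  std::thread t(cont);              // ...and may run on a different thread
  t.join();
  return 0;
}
```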

Comment thread torch/testing/_internal/distributed/rpc/dist_autograd_test.py
        return t3

    @dist_init
    def test_thread_local_context_id(self):
Contributor Author

I confirm that this test fails without this fix.

mrshenli pushed a commit that referenced this pull request May 14, 2020

ghstack-source-id: 1d3307f
Pull Request resolved: #38512
-        weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
+        weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture),
+        threadLocalState = ThreadLocalState(),
+        ctxId = autogradContext->contextId()]() {
Contributor

This is currently the only place on the server side where we have such a continuation, i.e., where processing continues via addCallback, correct?

Contributor Author

There are more in other types of requests. But IIUC, this is the only place where we need to propagate the context id.

         fromWorkerId,
-        weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture)]() {
+        weak = std::weak_ptr<FutureMessage>(wrappedRpcResponseFuture),
+        threadLocalState = ThreadLocalState(),
Contributor

Do we need TLS state or just the dist autograd context id for now? (I am planning to eventually use TLS state for distributed profiler work, but curious if this is already needed now)

rohan-varma (Contributor) commented May 15, 2020

Also, could we potentially reuse the approach taken in record_function_ops.cpp?
It is basically the same as here, except that there we declare the tls_state outside of the lambda capture and std::move it into the capture. Not sure if there's a difference perf-wise.

also cc @ilia-cher, if you have any comments on the usage here.
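
For reference, the two capture styles being compared look roughly like this (a sketch; `ThreadLocalStateSketch` is a hypothetical stand-in for the real state type, and the callback bodies are placeholders):

```cpp
#include <utility>

struct ThreadLocalStateSketch {};  // stand-in for the real TLS snapshot type

void compare_captures() {
  // Style used in this PR: construct the state directly in the init-capture.
  auto cb1 = [state = ThreadLocalStateSketch()]() { /* use state */ };

  // Style used in record_function_ops.cpp: construct outside, move it in.
  ThreadLocalStateSketch tls_state;
  auto cb2 = [state = std::move(tls_state)]() { /* use state */ };

  cb1();
  cb2();
}
```

Both are init-captures that construct the captured member in place (one from a temporary, one from an xvalue), so the cost should be essentially the same, as the reply below suggests.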

Contributor Author

> Do we need TLS state or just the dist autograd context id for now? (I am planning to eventually use TLS state for distributed profiler work, but curious if this is already needed now)

For now we only need the context id, I think. @xush6528 also pointed out that we will need this for the profiler later, so I added ThreadLocalState as well. I think we could also remove it in this PR and leave it to the profiler-related PR?

> It is basically the same as here, except that there we declare the tls_state outside of the lambda capture and std::move it into the capture. Not sure if there's a difference perf-wise.

Should be the same, I guess, as both are rvalue references?

Contributor

You can remove it. I think the profiler state is not needed, because no more tensor operations happen in this callback.
The autograd context id is needed because getMessageWithAutograd(..) needs it.

We will need to restore both the profiler state and the autograd context id for the continuation of async Python user-function calls.

Comment thread torch/testing/_internal/distributed/rpc/dist_autograd_test.py Outdated
mrshenli pushed a commit that referenced this pull request May 15, 2020

ghstack-source-id: 5cd17f0
Pull Request resolved: #38512
dr-ci Bot commented May 15, 2020

💊 CI failures summary and remediations

As of commit 95d287f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

This comment has been revised 11 times.

mrshenli pushed a commit that referenced this pull request May 15, 2020

ghstack-source-id: 8787a49
Pull Request resolved: #38512
Comment thread torch/testing/_internal/distributed/rpc/dist_autograd_test.py Outdated
rref = rpc.remote(dst, DistAutogradTest._slow_add, args=(t1, t2))

with dist_autograd.context() as context_id:
    loss = rref.to_here().sum()
xush6528 (Contributor) commented May 15, 2020

Let me write down a note for whoever, like me, is not super clear on FORWARD_AUTOGRAD_REQ.

The rref.to_here() here is a PYTHON_RREF_FETCH_CALL, which is always wrapped in a FORWARD_AUTOGRAD_REQ as long as it's called within an autograd context, because of forceGradRecording on this line:

auto futureResponse = autograd::sendMessageWithAutograd(
    *agent,
    agent->getWorkerInfo(ownerId_),
    std::move(msgToSend),
    true /* forceGradRecording */);

When the PYTHON_RREF_FETCH_RESPONSE is ready, as in this line,

responseFuture->markCompleted(std::move(m));

the response message is wrapped into FORWARD_AUTOGRAD_RESP by this line,

auto msg = getMessageWithAutograd(
    fromWorkerId,
    std::move(*wrappedRpcResponseFuture).moveValue(),
    MessageType::FORWARD_AUTOGRAD_RESP);

where the key call, getMessageWithAutograd(...), adds a SendFunction to the autograd context, and the client, on receiving this FORWARD_AUTOGRAD_RESP, will add a corresponding RecvFunction to the dist autograd context.

If the autograd context is not restored here, the SendFunction will be added to the wrong autograd context, or the call will crash because there is no active context. The subsequent sum() backward will then be propagated back to the server, but in the wrong autograd context.

Why can this test reproduce the thread switch?

By making the rpc.remote(..., _slow_add, ...) request slow, the rref.to_here() here is processed before the RRef value is set.

So the to_here() request callback will run on another thread. The thread that runs it is exactly the thread handling rpc.remote(..), because that thread marks the value as ready and is thus responsible for running the added callbacks.

Since rpc.remote(..) is called outside of a dist autograd context, getMessageWithAutograd(...) does not wrap the PYTHON_REMOTE_CALL into a FORWARD_AUTOGRAD_REQ, so no autograd context is available on that thread.
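
A hedged sketch of the scheduling behavior described above, using a toy future rather than the real torch FutureMessage: whichever thread calls markCompleted() runs the pending callbacks inline, so the continuation inherits that thread's thread_local state (which, in the scenario above, has no active autograd context).

```cpp
#include <functional>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

class ToyFuture {
 public:
  void addCallback(std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    if (completed_) {
      cb();  // already done: run inline on the registering thread
      return;
    }
    callbacks_.push_back(std::move(cb));
  }
  void markCompleted() {
    std::lock_guard<std::mutex> lock(mu_);
    completed_ = true;
    for (auto& cb : callbacks_) {
      cb();  // runs inline on the completing thread
    }
  }

 private:
  std::mutex mu_;
  bool completed_ = false;
  std::vector<std::function<void()>> callbacks_;
};

int main() {
  ToyFuture fut;
  std::cout << "request thread:  " << std::this_thread::get_id() << "\n";
  fut.addCallback([] {
    // In the scenario above, this is the to_here() continuation; it runs on
    // whichever thread set the RRef value (the slow rpc.remote thread).
    std::cout << "callback thread: " << std::this_thread::get_id() << "\n";
  });
  std::thread completer([&fut] { fut.markCompleted(); });
  completer.join();
  return 0;
}
```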

mrshenli pushed a commit that referenced this pull request May 15, 2020

ghstack-source-id: afb9d07
Pull Request resolved: #38512
// thread_local states there.
// TODO: Land on a general solution for RPC ThreadLocalState. See
// https://github.com/pytorch/pytorch/issues/38510
DistAutogradContextGuard ctxGuard(ctxId);
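
For readers unfamiliar with the guard, here is a rough sketch of what an RAII guard like this does (the real DistAutogradContextGuard lives in the dist autograd code; the body below is illustrative, not the actual implementation):

```cpp
#include <cstdint>

// Hypothetical stand-in for the thread-local "current context id".
thread_local int64_t current_ctx_id = -1;

class ContextGuardSketch {
 public:
  explicit ContextGuardSketch(int64_t ctxId) : prevCtxId_(current_ctx_id) {
    current_ctx_id = ctxId;  // make the captured context current on this thread
  }
  ~ContextGuardSketch() {
    current_ctx_id = prevCtxId_;  // restore the previous id on scope exit
  }

 private:
  int64_t prevCtxId_;
};
```

Constructing the guard at the top of the callback, as in the diff above, makes the captured ctxId current for the duration of the callback and restores whatever was current before once the callback returns.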
Contributor

Shouldn't we do this in addCallback itself instead of fixing each callsite of addCallback?

Contributor Author

Yes, eventually we should do that, I think. But for now we still have two Future types (utils and ivalue), and not every callback needs to capture ThreadLocalState, so it might be better to first fix master. Let's revisit this when we reach a consensus on how we should implement RPC ThreadLocalState.
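
To illustrate the alternative being discussed, a sketch of what "doing it in addCallback itself" could look like, reusing the toy future idea from the earlier comment (again hypothetical, not the real torch Future):

```cpp
#include <functional>
#include <utility>
#include <vector>

thread_local int ctx_id = -1;  // stand-in for the dist autograd context id

class AutoRestoringFuture {
 public:
  // Wrap every callback so it runs under the registering thread's context
  // id, no matter which thread eventually completes the future.
  void addCallback(std::function<void()> cb) {
    callbacks_.push_back([captured = ctx_id, cb = std::move(cb)]() {
      int prev = ctx_id;
      ctx_id = captured;  // restore the registering thread's context id
      cb();
      ctx_id = prev;      // put the completing thread back as it was
    });
  }
  void markCompleted() {  // runs callbacks inline on the completing thread
    for (auto& cb : callbacks_) {
      cb();
    }
  }

 private:
  std::vector<std::function<void()>> callbacks_;
};

int main() {
  AutoRestoringFuture fut;
  ctx_id = 7;  // registering thread's context id
  fut.addCallback([] { /* sees ctx_id == 7 regardless of the completer */ });
  ctx_id = -1;  // the completing thread has no active context
  fut.markCompleted();
  return 0;
}
```

The trade-off the reply above points out is that this adds capture cost to every callback, including the ones that don't need any thread_local state.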

mrshenli pushed a commit that referenced this pull request May 15, 2020

ghstack-source-id: da83b3a
Pull Request resolved: #38512
facebook-github-bot (Contributor)

@mrshenli merged this pull request in f39222a.

mrshenli (Contributor Author)

If we are going to release v1.5.1, we should include this fix.

mrshenli pushed a commit to mrshenli/pytorch that referenced this pull request May 28, 2020
…torch#38512)

Summary:
Pull Request resolved: pytorch#38512

Fixes pytorch#38439

Test Plan: Imported from OSS

Differential Revision: D21583642

Pulled By: mrshenli

fbshipit-source-id: a79bce1cb207fd11f1fa02b08465e49badda65fc
gchanan pushed a commit that referenced this pull request Jun 1, 2020
…8512)

Summary:
Pull Request resolved: #38512

Fixes #38439

Test Plan: Imported from OSS

Differential Revision: D21583642

Pulled By: mrshenli

fbshipit-source-id: a79bce1cb207fd11f1fa02b08465e49badda65fc
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
…torch#38512)

Summary:
Pull Request resolved: pytorch#38512

Fixes pytorch#38439

Test Plan: Imported from OSS

Differential Revision: D21583642

Pulled By: mrshenli

fbshipit-source-id: a79bce1cb207fd11f1fa02b08465e49badda65fc
Development

Successfully merging this pull request may close these issues:

Future callbacks in RPC should capture and restore autograd context id

6 participants