[1.6 cherry pick] Add guard for non-default stream in DDP's autograd engine callback #41151
Merged
malfet merged 1 commit into release/1.6 on Jul 9, 2020
Conversation
mrshenli approved these changes on Jul 8, 2020
Port of #40115, to merge into release/1.6.
Original commit description:
Summary:
Pull Request resolved: #40115
Closes #37790
Closes #37944
A user may wish to run DDP's forward and backward steps under a non-default CUDA stream, such as one created by `with torch.cuda.stream(stream)`. In this case, the user is responsible for synchronizing events on this stream with the other streams used in the program (per the documentation at https://pytorch.org/docs/stable/notes/cuda.html#cuda-semantics), but currently DDP has a bug which causes DDP under non-default streams to fail.

If a user does the following:

```
model = DDP(...)
loss = model(input).sum()
loss.backward()
grad = model.module.weight.grad
average = dist.all_reduce(grad)
```

there is a chance that `average` and `grad` will not be equal. This is because the CUDA kernels corresponding to the `all_reduce` call may run before `loss.backward()`'s kernels have finished. Specifically, in DDP we copy the allreduced gradients back to the model parameter gradients in an autograd engine callback, but this callback runs on the default stream. Note that this could also be fixed by the application synchronizing on the current stream, although this should not be expected, since the application is not using the current stream at all.

This PR fixes the issue by passing the current stream into DDP's callback.
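As a minimal illustration of the guard pattern the fix applies (the actual change lives in DDP's C++ reducer; `queue_callback` and `copy_grads_back` below are hypothetical stand-ins, not PyTorch APIs):

```
import torch

# Sketch of the fix's idea: capture the stream that is current when
# backward() is launched, and make it current again inside the autograd
# engine callback, so the gradient copy-back is queued behind backward's
# kernels instead of racing ahead on the default stream.
def install_finalize_callback(queue_callback, copy_grads_back):
    user_stream = torch.cuda.current_stream()  # stream active at backward time

    def callback():
        # Run the copy-back on the caller's stream, ordering it after
        # the backward kernels already queued there.
        with torch.cuda.stream(user_stream):
            copy_grads_back()

    queue_callback(callback)
```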
Tested by adding a UT `test_DistributedDataParallel_non_default_stream` that fails without this PR.
ghstack-source-id: 106481208
Differential Revision: D22073353
fbshipit-source-id: 70da9b44e5f546ff8b6d8c42022ecc846dff033e
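For context, the scenario the new UT exercises looks roughly like this (an illustrative sketch, not the test's actual code; it assumes an initialized process group with one GPU per rank):

```
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def check_non_default_stream(rank, world_size):
    dev = torch.device(f"cuda:{rank}")
    model = DDP(torch.nn.Linear(10, 10).to(dev), device_ids=[rank])
    stream = torch.cuda.Stream(device=dev)
    with torch.cuda.stream(stream):
        # Forward + backward run entirely on a non-default stream.
        loss = model(torch.randn(20, 10, device=dev)).sum()
        loss.backward()
    # Per the CUDA semantics notes, the user synchronizes the stream
    # before consuming the gradients elsewhere.
    stream.synchronize()
    # DDP has already averaged the gradients, so every rank holds the
    # same value and a manual allreduce-average must be a no-op.
    grad = model.module.weight.grad.clone()
    dist.all_reduce(grad)
    grad /= world_size
    assert torch.allclose(model.module.weight.grad, grad)
```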