Grad clip for parameters on different devices #9302
Closed
Stonesjtu wants to merge 1 commit into pytorch:master from
Conversation
soumith approved these changes on Jul 10, 2018
Contributor
facebook-github-bot left a comment
@soumith is landing this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Collaborator
thank you, this looks good!
Contributor
Wouldn't it be much faster to just use a defaultdict to accumulate the norms on each device and only then transfer them all to the CPU? That would at least cover the common case of all params being on a single GPU.
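(For context, a minimal sketch of what that suggestion could look like; `clip_grad_norm_per_device` is a hypothetical helper written for illustration, not part of this PR or of `torch.nn.utils`, and it ignores the inf-norm special case.)

```python
from collections import defaultdict

import torch


def clip_grad_norm_per_device(parameters, max_norm, norm_type=2.0):
    """Hypothetical variant: accumulate norms per device, transfer one scalar per device."""
    parameters = [p for p in parameters if p.grad is not None]
    per_device = defaultdict(float)  # device -> sum of ||grad||^norm_type on that device
    for p in parameters:
        per_device[p.grad.device] = per_device[p.grad.device] + p.grad.norm(norm_type) ** norm_type
    # one device-to-host transfer per device instead of one per parameter
    total_norm = sum(float(n) for n in per_device.values()) ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.mul_(clip_coef)
    return total_norm
```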
Contributor
Author
@apaszke I've thought about that, but counter-intuitively, in my environment the scalar addition on a single device does not run as fast as expected. My environment:
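(A rough way to check that claim, purely as an illustration: time accumulating many small CUDA tensors on the GPU versus pulling each scalar to the CPU and summing there. Every on-GPU addition launches a tiny kernel, so it is not obviously faster; the numbers below will vary by setup.)

```python
import time

import torch

norms = [torch.rand(1, device='cuda') for _ in range(200)]
torch.cuda.synchronize()

start = time.time()
gpu_total = sum(norms)  # each addition is a small CUDA kernel launch
torch.cuda.synchronize()
gpu_time = time.time() - start

start = time.time()
cpu_total = sum(float(n) for n in norms)  # one small transfer per scalar, summed on the CPU
cpu_time = time.time() - start

print('GPU accumulation: {:.6f}s, CPU accumulation: {:.6f}s'.format(gpu_time, cpu_time))
```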
goodlux pushed a commit to goodlux/pytorch that referenced this pull request on Aug 15, 2018
Summary:
I'm trying to write a multi-GPU network by pipelining some layers onto different GPUs. However, the current gradient clipping requires all the parameters to be on the same device.
With this change the scalar calculation is performed on the CPU, which reduces CUDA launch overhead but introduces extra device-to-host data transfers.
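(A minimal sketch of the approach described above, not the exact diff in this PR: compute each gradient's norm on its own device, copy only the scalar result to the host, and combine the scalars there, so parameters may live on different devices.)

```python
import torch


def clip_grad_norm_cpu_scalars(parameters, max_norm, norm_type=2.0):
    """Sketch only; the real implementation lives in torch.nn.utils.clip_grad_norm_."""
    parameters = [p for p in parameters if p.grad is not None]
    total_norm = 0.0
    for p in parameters:
        # the norm reduction runs on whatever device the gradient lives on;
        # only a single scalar is copied to the host per parameter
        total_norm += float(p.grad.norm(norm_type)) ** norm_type
    total_norm = total_norm ** (1.0 / norm_type)
    clip_coef = max_norm / (total_norm + 1e-6)
    if clip_coef < 1:
        for p in parameters:
            p.grad.mul_(clip_coef)
    return total_norm
```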
No performance regression is observed when running the following snippet:
```python
import time
import torch
module = torch.nn.Sequential(
    torch.nn.LSTM(1024, 1024),
    torch.nn.LSTM(256, 256),
    torch.nn.Linear(100, 10000),
).cuda()

# warm-up call, then time 1000 clip calls
torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
start = time.time()
for _ in range(1000):
    torch.nn.utils.clip_grad_norm_(module.parameters(), 1)
torch.cuda.synchronize()
time_elapse = time.time() - start
# total seconds for 1000 iterations equals milliseconds per call
print('{} ms per clip'.format(time_elapse))
```
Pull Request resolved: pytorch#9302
Differential Revision: D8781551
Pulled By: soumith
fbshipit-source-id: 9d76d01fe0531927f770a16b9523872a7e08e927