Replace Variable.volatile with torch.no_grad() #3970

Merged
colesbury merged 11 commits into pytorch:master from colesbury:backprop_mode on Dec 18, 2017

Conversation

@colesbury (Member) commented Dec 1, 2017

This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_grad_enabled() and the context manager torch.no_grad().

In C++, the flag is exposed through GradMode::is_enabled() and GradMode::set_enabled().

Fixes #3627
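
For reference, a minimal sketch of the new API in use (the shapes and values here are illustrative, not from the PR):

import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():          # no history is recorded inside this block
    y = x * 2
print(y.requires_grad)         # False

torch.set_grad_enabled(False)  # the same thread-local flag, toggled globally
z = x * 2
print(z.requires_grad)         # False
torch.set_grad_enabled(True)   # restore the default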

@pytorchbot (Collaborator) commented

@colesbury, thanks for your PR! We identified @zdevito to be a potential reviewer.

@colesbury (Member, Author) commented

I placed the context manager in the torch package instead of torch.autograd in anticipation of merging Tensor and Variable.

@apaszke (Contributor) commented Dec 2, 2017

Haven't started the review, but can we please call it no_grad or something like that? Backprop is very nn-specific and is a special case of the more general reverse-mode AD, which is the level on which autograd operates.

colesbury added a commit to colesbury/pytorch that referenced this pull request Dec 2, 2017
Gradients were becoming non-volatile because at::zeros_like returned a
Variable with volatile always set to false. The non-volatile gradients
accumulated history in the model which results in continuously
increasing memory usage.

See pytorch#3983, pytorch#3835, pytorch#3824

In v0.4 this will be more robustly solved by pytorch#3970
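
A rough modern analogue of the failure mode (volatile no longer exists, so this is only an illustration, not the original v0.3 code): accumulating into a tensor that stays attached to the autograd graph keeps every iteration's history alive.

import torch

x = torch.ones(3, requires_grad=True)
acc = torch.zeros_like(x)      # in the buggy case, the accumulator tracked history
for _ in range(100):
    acc = acc + x * 2          # each iteration appends new nodes to acc's graph
# memory grows with the loop; detaching the accumulator breaks the chain:
#   acc = (acc + x * 2).detach()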
soumith pushed a commit that referenced this pull request Dec 2, 2017
(same commit message as above)

peterjc123 pushed a commit to peterjc123/pytorch that referenced this pull request Dec 4, 2017
(same commit message as above)

soumith pushed a commit that referenced this pull request Dec 4, 2017
(same commit message as above)
@Randl (Contributor) commented Dec 13, 2017

ping

@apaszke (Contributor) left a comment

Looks great. I have a ton of questions, because I want to make sure that everything works fine, and we haven't missed anything. Also, I think there are a few things that should be fixed before we merge this.

Review threads on tools/autograd/gen_variable_type.py, torch/csrc/autograd/python_variable.cpp, torch/csrc/autograd/variable.h (three threads), and torch/csrc/autograd/variable.cpp (all outdated; comments marked as off-topic).

@colesbury force-pushed the backprop_mode branch 2 times, most recently from bc4bd95 to 4968993 on December 13, 2017 at 23:36
@colesbury changed the title from "Replace Variable.volatile with torch.no_backprop()" to "Replace Variable.volatile with torch.no_grad()" on Dec 13, 2017
Review threads on test/test_autograd.py (two threads) and tools/autograd/gen_variable_type.py (all outdated; comments marked as off-topic).

@apaszke (Contributor) left a comment

Can you please add a test that checks that views of a detached base still require grad and have grad functions?

Review threads on torch/csrc/autograd/variable.h and torch/autograd/function.py (both outdated; comments marked as off-topic).

This removes volatile from Variable. The functionality is mostly
replaced by a global (thread-local) flag, which is controlled by
torch.set_backprop_enabled() and the context manager
torch.no_backprop().

Fixes pytorch#3627
rebase_history also now immediately updates view._grad_fn.
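
A small sketch of what "immediately updates view._grad_fn" means in practice, in current-API spelling (the variable names are illustrative):

import torch

base = torch.zeros(4, requires_grad=True).clone()  # non-leaf, so in-place ops are allowed
view = base[:2]                                    # a view into base's storage
base.add_(1)                                       # in-place op rewrites base's history
print(view.grad_fn)                                # the view's grad_fn reflects the update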
@colesbury merged commit d605058 into pytorch:master on Dec 18, 2017
@colesbury deleted the backprop_mode branch on December 18, 2017 at 20:46
@Randl mentioned this pull request on Dec 19, 2017
base.output_nr() = 0;
base.get()->_grad_fn = std::make_shared<CopySlices>(
    base, TensorGeometry(data), std::move(grad_fn));
get_grad_fn(); // trigger an update to the view's grad_fn


@rawalkhirodkar commented

Can we update the documentation to highlight the removal of volatile in favor of torch.no_grad()?
Thank you

@apaszke (Contributor) commented Jan 21, 2018

@rawalkhirodkar yes, that's definitely needed. Can you please send a PR?

@Randl (Contributor) commented Jan 21, 2018

Also, I'm not sure what to do if only some of the variables don't require grads.

@apaszke (Contributor) commented Jan 21, 2018

@Randl requires_grad still works as it used to (when grad mode is enabled)
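
For example, a minimal sketch of mixing tracked and untracked inputs (the names are illustrative):

import torch

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)             # requires_grad defaults to False
loss = (a * b).sum()
loss.backward()
print(a.grad)                  # populated
print(b.grad)                  # None: no gradient is tracked for b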

@Randl (Contributor) commented Jan 22, 2018

@apaszke I probably didn't explain myself well.

As far as I understand, with torch.no_grad() is equivalent to running with all variables volatile. What if I want only one variable to be volatile?

@apaszke (Contributor) commented Jan 22, 2018

Then you need to separate the volatile codepaths from the other ones. I haven't yet seen a case where keeping only some Variables volatile is useful.

@fmassa (Member) commented Jan 22, 2018

@Randl also note that volatile variables propagate in the graph. So if you have 2 variables, a and b, where only a is volatile, the result of an operation between a and b will be volatile as well.
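
For contrast, a rough sketch of how the replacement behaves: requires_grad also propagates through operations, but in the enabling direction, so one grad-requiring input is enough.

import torch

a = torch.randn(3, requires_grad=True)  # plays the role of the non-volatile input
b = torch.randn(3)                      # untracked input
c = a + b
print(c.requires_grad)                  # True: requires_grad propagates where volatile suppressed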

@my-hello-world commented

Hello, I want to know whether PyTorch 0.4 can be built with CUDA 7.5 support. When I install CUDA 8.0, it shows: "Found GPU1 Quadro K2000 which is of cuda capability 3.0. PyTorch no longer supports this GPU because it is too old." I'm sorry, but I can't afford a new Nvidia GPU.
Now I have installed CUDA 7.5, and when I run

export CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" # [anaconda root directory]
conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
conda install -c mingfeima mkldnn
conda install -c pytorch magma-cuda75
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch
python setup.py install

the resulting PyTorch version is 0.3.0. What should I do?
My environment: Ubuntu 14.04, Anaconda, Python 3.6.5.
Thanks!

wuhuikx pushed a commit to wuhuikx/pytorch that referenced this pull request Jan 30, 2020
(same commit message as above)
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
(same message as the pull request description above)