Deprecate ctx.saved_variables via python warning.#5923
Merged
ezyang merged 3 commits intopytorch:masterfrom Mar 26, 2018
Merged
Deprecate ctx.saved_variables via python warning.#5923ezyang merged 3 commits intopytorch:masterfrom
ezyang merged 3 commits intopytorch:masterfrom
Conversation
apaszke
reviewed
Mar 21, 2018
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
58eb1b1 to
cb6d2c2
Compare
Contributor
|
It might be good to leave in one test for the deprecated pattern so we don't accidentally delete it before the deprecation cycle is over. |
Contributor
|
@pytorchbot retest this please |
1 similar comment
Contributor
|
@pytorchbot retest this please |
Advises replacing saved_variables with saved_tensors.
Also replaces all instances of ctx.saved_variables with ctx.saved_tensors in the
codebase.
Test by running:
```
import torch
from torch.autograd import Function
class MyFunction(Function):
@staticmethod
def forward(ctx, tensor1, tensor2):
ctx.save_for_backward(tensor1, tensor2)
return tensor1 + tensor2
@staticmethod
def backward(ctx, grad_output):
var1, var2 = ctx.saved_variables
return (grad_output, grad_output)
x = torch.randn((3, 3), requires_grad=True)
y = torch.randn((3, 3), requires_grad=True)
model = MyFunction()
model.apply(x, y).sum().backward()
```
and assert the warning shows up.
66b3f66 to
a97394b
Compare
laurentdupin
pushed a commit
to laurentdupin/pytorch
that referenced
this pull request
Apr 24, 2026
* Deprecate ctx.saved_variables via python warning.
Advises replacing saved_variables with saved_tensors.
Also replaces all instances of ctx.saved_variables with ctx.saved_tensors in the
codebase.
Test by running:
```
import torch
from torch.autograd import Function
class MyFunction(Function):
@staticmethod
def forward(ctx, tensor1, tensor2):
ctx.save_for_backward(tensor1, tensor2)
return tensor1 + tensor2
@staticmethod
def backward(ctx, grad_output):
var1, var2 = ctx.saved_variables
return (grad_output, grad_output)
x = torch.randn((3, 3), requires_grad=True)
y = torch.randn((3, 3), requires_grad=True)
model = MyFunction()
model.apply(x, y).sum().backward()
```
and assert the warning shows up.
* Address comments
* Add deprecation test for saved_variables
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
As suggested in #5907. I'm trying to keep that PR to mostly doc changes so opening this one for this deprecation.
Advises replacing saved_variables with saved_tensors.
Also replaces all instances of ctx.saved_variables with ctx.saved_tensors in the
codebase.
cc @apaszke @colesbury
Test by running:
and assert the warning shows up.