
Deprecate ctx.saved_variables via python warning.#5923

Merged
ezyang merged 3 commits into pytorch:master from zou3519:deprecate-saved-vars on Mar 26, 2018

Conversation

@zou3519
Contributor

@zou3519 zou3519 commented Mar 21, 2018

As suggested in #5907. I'm trying to keep that PR to mostly doc changes, so I'm opening this one for the deprecation.

Advises replacing saved_variables with saved_tensors.
Also replaces all instances of ctx.saved_variables with ctx.saved_tensors in the
codebase.

cc @apaszke @colesbury

Test by running:
```
import torch
from torch.autograd import Function

class MyFunction(Function):
    @staticmethod
    def forward(ctx, tensor1, tensor2):
        ctx.save_for_backward(tensor1, tensor2)
        return tensor1 + tensor2

    @staticmethod
    def backward(ctx, grad_output):
        # Deliberately uses the deprecated accessor to trigger the warning.
        var1, var2 = ctx.saved_variables
        return (grad_output, grad_output)

x = torch.randn((3, 3), requires_grad=True)
y = torch.randn((3, 3), requires_grad=True)
# apply is called on the Function class, not on an instance
MyFunction.apply(x, y).sum().backward()
```
and assert the warning shows up.
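One way to assert the warning shows up programmatically is `warnings.catch_warnings(record=True)`. Sketched here with a stand-in function rather than the PyTorch accessor (the stand-in and its message are illustrative, not the exact text this PR emits):

```python
import warnings

def saved_variables_stand_in():
    # Stand-in for the deprecated accessor; emits a warning of the
    # same kind that ctx.saved_variables would after this PR.
    warnings.warn(
        "'saved_variables' is deprecated; use 'saved_tensors' instead",
        DeprecationWarning,
    )
    return ()

with warnings.catch_warnings(record=True) as caught:
    # Ensure the warning is recorded even if it was already triggered
    # once (Python dedupes warnings by default).
    warnings.simplefilter("always")
    saved_variables_stand_in()

assert any(issubclass(w.category, DeprecationWarning) for w in caught)
assert "saved_tensors" in str(caught[0].message)
```

The `simplefilter("always")` call matters: without it, a warning raised at the same call site a second time is silently suppressed, and the assertion can fail spuriously.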

@zou3519 zou3519 force-pushed the deprecate-saved-vars branch from 58eb1b1 to cb6d2c2 Compare March 21, 2018 21:12
@ezyang
Contributor

ezyang commented Mar 23, 2018

It might be good to leave in one test for the deprecated pattern so we don't accidentally delete it before the deprecation cycle is over.

@ezyang
Contributor

ezyang commented Mar 23, 2018

@pytorchbot retest this please

@ezyang
Contributor

ezyang commented Mar 24, 2018

@pytorchbot retest this please

zou3519 added 3 commits March 26, 2018 06:51
@zou3519 zou3519 force-pushed the deprecate-saved-vars branch from 66b3f66 to a97394b Compare March 26, 2018 14:00
@ezyang ezyang merged commit 5d628db into pytorch:master Mar 26, 2018
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
* Deprecate ctx.saved_variables via python warning.


* Address comments

* Add deprecation test for saved_variables