
functionalization: introduce a "zero()" aten op#75913

Closed
bdhirsh wants to merge 12 commits into gh/bdhirsh/208/base from gh/bdhirsh/208/head

Conversation

@bdhirsh
Collaborator

@bdhirsh bdhirsh commented Apr 15, 2022

Fixes pytorch/functorch#705

This adds support for zero_() in the functionalization pass by introducing a new at::zero().

It's identical to at::zeros_like(t), but adding it directly into native_functions.yaml allows the functionalization pass to automatically figure out how to undo a mutation from zero_().

We probably don't want users to actually use the operator, so I didn't give it a tensor method or a python binding.

From conversation with @ezyang, we should probably just do the same with at::_copy() (even though at::copy() will be a pretty unintuitive op).

This also fixes one of the torch dynamo integration issues mentioned in pytorch/torchdynamo#88
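The core idea, sketched as a torch-free toy (hypothetical helper names, not PyTorch's actual implementation): functionalization replaces each in-place mutation like `zero_()` with its out-of-place counterpart `zero()` and rebinds the variable, so the resulting trace contains no mutations.

```python
# Toy sketch of functionalization (hypothetical, not PyTorch internals):
# an in-place op like zero_() mutates its input, so a purely functional
# trace instead records the out-of-place zero() and rebinds the variable.

def zero(t):
    # Functional counterpart of zero_(): returns a fresh all-zeros value
    # the same size as the input, leaving the input untouched.
    return [0.0] * len(t)

def zero_(t):
    # In-place version: mutates its argument and returns it.
    t[:] = [0.0] * len(t)
    return t

x = [1.0, 2.0, 3.0]
y = x                      # y aliases x

x_f = [1.0, 2.0, 3.0]
x_f = zero(x_f)            # functionalized: no mutation, x_f is rebound

zero_(x)                   # original: mutates x (and the alias y) in place

assert x_f == [0.0, 0.0, 0.0]
assert x == y == [0.0, 0.0, 0.0]  # aliasing is why undoing mutations is subtle
```

Having `zero()` registered as a real op is what lets the pass emit the functional form mechanically instead of special-casing `zero_()`.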

Stack from ghstack:

Differential Revision: D35705378

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 15, 2022


💊 CI failures summary and remediations

As of commit 5add523 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚



@bdhirsh bdhirsh changed the title functionalization: add support for zero_() functionalization: introduce a "zero()" aten op Apr 15, 2022
```yaml
- func: zero(Tensor self) -> Tensor
  variants: function
  dispatch:
    CompositeExplicitAutograd: zero
```
Collaborator Author


I don't want this to be CompositeImplicitAutograd because I want you to be able to trace out a zero call.

@ezyang this should probably go under the new CompositeNonFunctional alias key once we have one (name tbd), so functional backends don't actually decompose it.

Contributor


TBH I am confused, what is wrong with this getting traced as zeros_like?

@bdhirsh bdhirsh requested a review from ezyang April 15, 2022 22:03
@bdhirsh
Collaborator Author

bdhirsh commented Apr 17, 2022

@bdhirsh has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ezyang
Contributor

ezyang commented Apr 17, 2022

I know I suggested you do this, but seeing the PR brings another thing to mind: you don't want this to be exposed as a user visible concept so you've suppressed Python bindings. But this is still exposed to the user in another way: as an operator that can show up in a trace. This seems counterproductive: now you have to know to define BOTH aten::zero and aten::zeros_like but actually they're the same thing. You only need the completion for functionalization mapping zero <-> zero_, but it seems to me that you probably want it to evaporate after the mapping into zeros_like to avoid adding a duped operator to the set of operators that need to be created.

@bdhirsh
Collaborator Author

bdhirsh commented Apr 18, 2022

it seems to me that you probably want it to evaporate after the mapping into zeros_like to avoid adding a duped operator to the set of operators that need to be created.

Ah yep, this makes sense to me.

One problem with that today is that at::zeros_like is a CompositeImplicitAutograd op, which means that it will actually decompose further in the trace (it calls empty_like() and zero_(), so mutations show up in the trace again!).

Somehow... we want zeros_like() not to decompose before hitting the Python key, when functionalization is involved.

I have a question: would it be too heavy-handed to require that all functional-op CompositeImplicitAutograd kernels that decompose into mutations get their own autograd formulas? (And eventually get switched over to some `DecomposeWithMutations` alias key that runs underneath autograd.)

If we have a derivative formula for zeros_like, then we can:

  • keep zero() as CompositeImplicitAutograd
  • mark zeros_like() as CompositeExplicitAutograd so it shows up in traces
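The problem described above can be sketched with a toy trace recorder (hypothetical names, not PyTorch's real dispatcher): if `zeros_like` is allowed to decompose while tracing, its decomposition puts a mutation back into the trace that functionalization was trying to remove.

```python
# Toy sketch: a CompositeImplicitAutograd-style decomposition of zeros_like
# calls the in-place zero_, so mutations reappear in the recorded trace,
# whereas a non-decomposing zeros_like shows up as a single functional op.

TRACE = []

def record(op):
    TRACE.append(op)

def zero_(t):
    # In-place op: mutates its argument.
    record("zero_")
    for i in range(len(t)):
        t[i] = 0.0
    return t

def empty_like(t):
    record("empty_like")
    return [None] * len(t)

def zeros_like_decomposed(t):
    # Decomposing version: calls a mutating op under the hood.
    return zero_(empty_like(t))

def zeros_like_opaque(t):
    # Non-decomposing version: a single functional op in the trace.
    record("zeros_like")
    return [0.0] * len(t)

zeros_like_decomposed([1.0, 2.0])
assert TRACE == ["empty_like", "zero_"]   # mutation re-enters the trace

TRACE.clear()
zeros_like_opaque([1.0, 2.0])
assert TRACE == ["zeros_like"]            # clean functional trace
```

Giving `zeros_like` its own derivative formula is what would let it stop decomposing (the opaque version above) without losing autograd support.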

@ezyang
Contributor

ezyang commented Apr 18, 2022

Weren't you going to distinguish between mutating and non-mutating composites? It seems like that would help.

@bdhirsh
Collaborator Author

bdhirsh commented Apr 18, 2022

Weren't you going to distinguish between mutating and non-mutating composites? It seems like that would help

Yep - that doesn't help with CompositeImplicitAutograd ops though, because we automatically lose autograd support if we decide not to decompose. That's why I'm thinking we'd need to give zeros_like an autograd formula as part of this PR.

At first I thought that we'd need to do the same for all other CompositeImplicitAutograd ops that call into mutation ops, but I guess that's not true - we only need to worry about CompositeImplicitAutograd ops that get called underneath the functionalization pass, like zeros_like.

@ezyang
Contributor

ezyang commented Apr 19, 2022

I have a question: would it be too heavy-handed to require that all functional-op CompositeImplicitAutograd kernels that decompose into mutations get their own autograd formulas? (And eventually get switched over to some `DecomposeWithMutations` alias key that runs underneath autograd.)

It's a bit finely balanced, but this plan seems reasonable to me.



```yaml
- name: zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  self: zeros_like(grad)
  result: auto_linear
```
Collaborator Author


cc @albanD / @soulitzer, I'm giving a derivative formula to zeros_like (see the comment above)

bdhirsh added 2 commits April 21, 2022 16:36
```python
# Fails due to a limitation of gradgradcheck
# https://github.com/pytorch/pytorch/issues/59137
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_inplace_gradgrad'),
```
Collaborator Author


cc @soulitzer, these tests started passing after I added a derivative formula for zeros_like. Just double-checking: does that sound reasonable? (If so, I can close the linked issue.)

@bdhirsh bdhirsh requested review from mruberry and ngimel as code owners April 22, 2022 13:57
@bdhirsh
Collaborator Author

bdhirsh commented Apr 25, 2022

@pytorchbot merge this please

@datumbox
Contributor

@bdhirsh We have started seeing runtime errors related to torch.zeros_like in TorchVision after this PR landed. It's not clear whether this PR is the cause, but could you please take a look at pytorch/vision#5881?

@vfdev-5
Contributor

vfdev-5 commented Apr 26, 2022

Repro code:

```python
import torch        # '1.12.0.dev20220426+cu113' and '1.12.0a0+gitb17b2b1'
import torchvision  # 0.13.0a0+01b0a00

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None, weights_backbone=None)
model.eval()

smodel = torch.jit.script(model)
smodel.eval()
smodel([torch.rand(3, 224, 224)])
```

which fails with:

```
# Based on Detectron2 implementation, just manually call nms() on each class independently
    keep_mask = torch.zeros_like(scores, dtype=torch.bool)
                ~~~~~~~~~~~~~~~~ <--- HERE
    for class_id in torch.unique(idxs):
        curr_indices = torch.where(idxs == class_id)[0]
RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "../torch/csrc/autograd/functions/utils.h":65, please report a bug to PyTorch.
```

bdhirsh added a commit that referenced this pull request Apr 26, 2022
Reverting #75913 as it broke torchvision (see comment at #75913 (comment))

Diagnosis: the following code now fails but used to work:
```
import torch

def foo(a):
    b = torch.zeros_like(a, dtype=torch.bool)
    return b

a = torch.ones(2, requires_grad=True)
sfoo = torch.jit.script(foo)
sfoo(a)
```

Why? The reverted PR added an autograd formula for `zeros_like()`, which used to be a `CompositeImplicitAutograd` op.

Unfortunately, that changed the behavior of zeros_like as follows.  The `*_like` ops "work" with autograd, but they sever the autograd graph.
```
>>> a = torch.ones(2, requires_grad=True)
>>> b = torch.zeros_like(a)
>>> b.requires_grad
False
>>> b.is_leaf
True
```

That makes code like `torch.zeros_like(a, dtype=torch.bool)` valid even if `a` requires_grad: if the requires_grad-ness were propagated, autograd would throw an error that you can't use autograd with `bool` tensors.


This reverts commit 7d44b36.

[ghstack-poisoned]
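The graph-severing behavior the revert message describes can be checked directly (a sketch assuming a stock PyTorch build, i.e. with `zeros_like` back to its pre-PR `CompositeImplicitAutograd` behavior):

```python
import torch

# *_like factory ops participate in autograd dispatch but sever the graph:
# the result is a fresh leaf tensor that does not require grad. That is what
# makes changing the dtype to bool legal even when the input requires grad;
# propagating requires_grad-ness would make autograd reject the bool tensor.
a = torch.ones(2, requires_grad=True)
b = torch.zeros_like(a, dtype=torch.bool)

assert b.dtype == torch.bool
assert not b.requires_grad
assert b.is_leaf
```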
facebook-github-bot pushed a commit that referenced this pull request Apr 26, 2022
Summary:
Pull Request resolved: #75913

Approved by: https://github.com/albanD

Test Plan: contbuild & OSS CI, see https://hud.pytorch.org/commit/pytorch/pytorch/7d44b3675bafdfbd59e6c81734ca3febd771dd7b

Reviewed By: albanD

Differential Revision: D35705378

Pulled By: bdhirsh

fbshipit-source-id: 7aebc1bbe8fdc7aca461920a2ac1f3b4a1afbe28
@facebook-github-bot facebook-github-bot deleted the gh/bdhirsh/208/head branch April 29, 2022 14:17