as_strided support for functionalization; introduce as_strided_scatter #77128

Closed
bdhirsh wants to merge 16 commits into gh/bdhirsh/226/base from gh/bdhirsh/226/head

Conversation

facebook-github-bot (Contributor) commented May 10, 2022

🔗 Helpful links

✅ No Failures (0 Pending)

As of commit 7fc1936 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

- name: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
self: as_strided_scatter(grad, zeros_like(src), size, stride, storage_offset)
src: grad.as_strided(size, stride, storage_offset)
result: auto_linear
bdhirsh (Collaborator, Author) commented:

I wasn't 100% sure whether or not I needed to involve as_strided_backward() somewhere in the derivative formula here. I do see some interesting OpInfo failures... but for the most part they're the same failures that as_strided already has (plus two new ones for conjugate views that I didn't try too hard to debug).

Collaborator replied:

The problem is that you assume that doing as_strided on the grad with the parameters from the forward will do the right thing (it won't if any of these have different strides/storage_offset to begin with).
Hence the madness of as_strided_backward() :D

I think the simplest way to do this is to create a new function in FunctionsManual that calls the backward of each elementary function in order (copy_backward, as_strided_backward), as if you were doing AD on the CompositeExplicit implementation that you have. (nit: it's ok to replicate the copy_backward impl from tensor.cpp instead of making it a helper function, I think).
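
For context, a minimal Python sketch of the composite that such a manual backward would differentiate through (the real implementation is C++; the function name here is made up, and this only illustrates the clone-view-copy semantics):

    import torch

    def as_strided_scatter_ref(self, src, size, stride, storage_offset=None):
        # Clone the input, take an as_strided view into the clone, and
        # copy src into that view. Differentiating this step by step is
        # what chains copy_backward and as_strided_backward together.
        out = self.clone()
        out.as_strided(size, stride, storage_offset).copy_(src)
        return out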

bdhirsh requested a review from ezyang May 11, 2022 14:40
…ided_scatter"

Adds a new `torch.as_strided_scatter` + `tensor.as_strided_scatter` op (with docs + op info tests), which I needed to provide better support for `as_strided()` in functionalization.

[ghstack-poisoned]
albanD (Collaborator) left a comment

Autograd formula looks wrong.
If you don't want to solve it because it is too tricky, we should have an assert there that the input and grad metadata match and that there are no 0 strides (and maybe more?), to ensure that the simplified formula is only used when it is valid.
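
A rough Python sketch of the kind of validity guard being suggested (illustrative only: the helper name is made up, and a real check would sit next to the C++ derivative):

    def simplified_formula_ok(self, grad, stride):
        # The simplified backward assumes grad is laid out exactly like
        # the forward input, and that the strided view never aliases an
        # element (which a 0 stride would do).
        return (
            grad.size() == self.size()
            and grad.stride() == self.stride()
            and grad.storage_offset() == self.storage_offset()
            and all(s != 0 for s in stride)
        )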

Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
    TORCH_INTERNAL_ASSERT(false, "as_strided has not been implemented in the functionalization pass yet");
    return Tensor();
// Pessimism: we can't reapply views for as_strided_scatter.
Collaborator commented:

Not sure what that means?

bdhirsh (Collaborator, Author) replied:

I have a similar comment for all of the other "slicing" ops like slice/select/diagonal.

For most operators, the "inverse" can be a view (e.g. the inverse of permute is another permute call).

functionalize() will also, by default, re-apply views instead of using view_copy operators, so in most of the functions in this file you'll see code like "if (reapply_views) return view() else view_copy()". But for the ops listed above we can't do that: running the "inverse" forces us to allocate a new tensor.
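
A small illustration of the distinction (assuming a build where torch.as_strided_scatter exists, i.e. with this PR applied; the shapes are arbitrary):

    import torch

    base = torch.arange(6.).reshape(2, 3)

    # For most view ops, the inverse is itself a view: permuting back
    # recovers the base without allocating.
    recovered = base.permute(1, 0).permute(1, 0)
    assert recovered.data_ptr() == base.data_ptr()

    # For as_strided, the "inverse" has to scatter the (possibly mutated)
    # view back into the base, which necessarily allocates a new tensor.
    view = base.as_strided((2,), (3,))  # picks out base[0, 0] and base[1, 0]
    recovered = torch.as_strided_scatter(base, view + 1, (2,), (3,))
    assert recovered.data_ptr() != base.data_ptr()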


bdhirsh added 8 commits May 17, 2022 19:58
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),)),
bdhirsh (Collaborator, Author) commented:

There are quite a few skips here. I updated the derivative formula to handle the contiguous case properly, and locally I was able to run this:

  a = torch.tensor([[3, 3], [3, 3]], requires_grad=True, dtype=torch.float32)
  b = torch.ones(2, requires_grad=True, dtype=torch.float32)

  c = torch.as_strided_scatter(a, b, (2,), (2,))
  c.sum().backward()
  print(a.grad)
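  # expected: a.grad is all ones, except zeros at the two positions the
  # strided view covers (the gradient there flows to b instead)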

as_strided_scatter doesn't work in more complicated cases, but this PR adds enough support to pass the basic LazyTensor test case, so I'd rather land it sooner to get everything working E2E, and handle the harder cases later.

bdhirsh added 3 commits May 23, 2022 14:40
bdhirsh (Collaborator, Author) commented May 24, 2022

@pytorchbot merge

github-actions (Contributor) commented:

Hey @bdhirsh.
You've committed this PR, but it does not have both a 'release notes: ...' and 'topics: ...' label. Please add one of each to the PR. The 'release notes: ...' label should represent the part of PyTorch that this PR changes (fx, autograd, distributed, etc) and the 'topics: ...' label should represent the kind of PR it is (not user facing, new feature, bug fix, perf improvement, etc). The list of valid labels can be found here for the 'release notes: ...' and here for the 'topics: ...'.
For changes that are 'topic: not user facing' there is no need for a release notes label.

suo (Member) commented May 24, 2022

@pytorchbot revert -m "This broke rocm tests on master https://hud.pytorch.org/pytorch/pytorch/commit/3a921f2d267430f292a111e8bcd40c76022cfd47. rocm tests are no longer run on PRs, you should add a ciflow/trunk label if you want to run them" -c nosignal

pytorchmergebot added a commit that referenced this pull request May 24, 2022
…d_scatter"

This reverts commit 3a921f2.

Reverted #77128 on behalf of https://github.com/suo due to This broke rocm tests on master https://hud.pytorch.org/pytorch/pytorch/commit/3a921f2d267430f292a111e8bcd40c76022cfd47. rocm tests are no longer run on PRs, you should add a `ciflow/trunk` label if you want to run them
bdhirsh (Collaborator, Author) commented May 24, 2022

Welp, thanks!

Looks like the bot was able to revert this PR without having to revert the other PRs in the stack. If that's true, then I'll try to re-land this PR by itself.
