as_strided support for functionalization; introduce as_strided_scatter #77128
bdhirsh wants to merge 16 commits into gh/bdhirsh/226/base from
Conversation
✅ No Failures (0 Pending) as of commit 7fc1936 (more details on the Dr. CI page). 💚 Looks good so far! There are no failures yet. This comment was automatically generated by Dr. CI; please report bugs/suggestions to the (internal) Dr. CI Users group.
```yaml
- name: as_strided_scatter(Tensor self, Tensor src, int[] size, int[] stride, int? storage_offset=None) -> Tensor
  self: as_strided_scatter(grad, zeros_like(src), size, stride, storage_offset)
  src: grad.as_strided(size, stride, storage_offset)
  result: auto_linear
```
I wasn't 100% sure whether I needed to involve as_strided_backward() somewhere in the derivative formula here. I see some interesting OpInfo failures... but for the most part they're the same failures that as_strided already has (plus two new ones for conjugate views that I didn't try too hard to debug).
The problem is that you assume that doing "as_strided" on the grad with the parameters from the forward will do the right thing (it won't if any of these have different strides/storage_offset to begin with).
Hence the madness of as_strided_backward() :D
I think the simplest way to do this is to create a new function in FunctionsManual that calls the backward of each elementary function in order (copy_backward, as_strided_backward), as if you were doing AD on the CompositeExplicit implementation that you have. (nit: it's ok to replicate the copy_backward impl from tensor.cpp instead of making it a helper function, I think).
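The composite-backward idea can be sketched in Python: differentiate through a reference implementation that clones the base and copies `src` into a strided window of the clone, so autograd chains the backwards of `copy_` and `as_strided` in order. The helper name below is hypothetical; this is a sketch of the intended semantics, not the actual FunctionsManual code.

```python
import torch

def as_strided_scatter_ref(base, src, size, stride, storage_offset=0):
    # Hypothetical reference (composite) implementation: clone the base,
    # then copy `src` into the strided window of the clone. Autograd
    # through this composite applies copy_backward and as_strided_backward
    # in sequence, which the manual derivative should replicate.
    out = base.clone()
    out.as_strided(size, stride, storage_offset).copy_(src)
    return out

base = torch.zeros(4, requires_grad=True)
src = torch.ones(2, requires_grad=True)
out = as_strided_scatter_ref(base, src, (2,), (2,))
out.sum().backward()
# Storage elements 0 and 2 of `base` were overwritten by `src`, so they
# receive zero gradient; the untouched elements pass the gradient through.
print(base.grad)  # tensor([0., 1., 0., 1.])
print(src.grad)   # tensor([1., 1.])
```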
…ided_scatter" Adds a new `torch.as_strided_scatter` + `tensor.as_strided_scatter` op (with docs + op info tests), which I needed to provide better support for `as_strided()` in functionalization. [ghstack-poisoned]
albanD
left a comment
Autograd formula looks wrong.
If you don't want to solve it because it is too tricky, we should have an assert there that the input and grad metadata match and that there are no 0 strides (and maybe more?), to ensure the simplified formula is only used when it is valid.
```cpp
Tensor FunctionalInverses::as_strided_copy_inverse(const Tensor& base, const Tensor& mutated_view, bool reapply_views, at::IntArrayRef size, at::IntArrayRef stride, c10::optional<int64_t> storage_offset) {
  TORCH_INTERNAL_ASSERT(false, "as_strided has not been implemented in the functionalization pass yet");
  return Tensor();
  // Pessimism: we can't reapply views for as_strided_scatter.
```
I have a similar comment for all of the other "slicing" ops like slice/select/diagonal.
For most operators, the "inverse" can be a view (e.g. the inverse of permute is another permute call).
functionalize() will also, by default, re-apply views instead of using view_copy operators, so in most of the functions in this file you'll see code like `if (reapply_views) return view(); else return view_copy();`. But for the ops listed above we can't do that: running the "inverse" forces us to allocate a new tensor.
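The distinction can be illustrated with a small sketch (assuming a PyTorch build recent enough to have the `*_scatter` ops): the inverse of `permute` is another `permute` call aliasing the same storage, while the "inverse" of a slicing op has to scatter the mutated view back into a fresh copy of the base.

```python
import torch

# The inverse of permute is just another permute: a true view that
# aliases the original storage, so no allocation is needed.
t = torch.arange(6).reshape(2, 3)
v = t.permute(1, 0)
inv = v.permute(1, 0)
print(inv.data_ptr() == t.data_ptr())  # True: same storage

# The "inverse" of a slicing op cannot be a view: the mutated slice only
# covers part of the base, so recovering the full tensor means scattering
# it back into a freshly allocated copy of the base.
base = torch.zeros(4)
mutated = base[1:3] + 1.0  # functionalized mutation of a slice
recovered = torch.slice_scatter(base, mutated, dim=0, start=1, end=3)
print(recovered.data_ptr() != base.data_ptr())  # True: new tensor
print(recovered)  # tensor([0., 1., 1., 0.])
```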
```python
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_grad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_gradgrad'),
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),)),
```
There are quite a few skips here. I updated the derivative formula to handle the contiguous case properly, and locally was able to run this:
```python
a = torch.tensor([[3, 3], [3, 3]], requires_grad=True, dtype=torch.float32)
b = torch.ones(2, requires_grad=True, dtype=torch.float32)
c = torch.as_strided_scatter(a, b, (2,), (2,))
c.sum().backward()
print(a.grad)
```
as_strided_scatter doesn't work in more complicated cases, but this PR adds enough support to pass the basic LazyTensor test case, so I'd rather land it sooner to get everything working E2E, and handle the harder cases later.
@pytorchbot merge

Hey @bdhirsh.
@pytorchbot revert -m "This broke rocm tests on master https://hud.pytorch.org/pytorch/pytorch/commit/3a921f2d267430f292a111e8bcd40c76022cfd47. rocm tests are no longer run on PRs, you should add a `ciflow/trunk` label if you want to run them"
…d_scatter" This reverts commit 3a921f2. Reverted #77128 on behalf of https://github.com/suo due to This broke rocm tests on master https://hud.pytorch.org/pytorch/pytorch/commit/3a921f2d267430f292a111e8bcd40c76022cfd47. rocm tests are no longer run on PRs, you should add a `ciflow/trunk` label if you want to run them
Welp, thanks! Looks like the bot was able to revert this PR without having to revert the other PRs in the stack. If that's true, then I'll try to re-land this PR by itself.
Adds a new `torch.as_strided_scatter` + `tensor.as_strided_scatter` op (with docs + op info tests), which I needed to provide better support for `as_strided()` in functionalization.

Stack from ghstack: