Implement ravel #46098
Conversation
[ghstack-poisoned]
Doc:  [ghstack-poisoned]
💊 CI failures summary and remediations — As of commit dcf8140 (more details on the Dr. CI page): 💚 Looks good so far! There are no failures yet. 💚 This comment was automatically generated by Dr. CI.
// Same as flatten() with start_dim as 0 and end_dim as -1.
return native::flatten(self, 0, -1);
nit: self.reshape(-1) avoids some of the computation that goes on inside flatten.
Calling native::flatten instead of at::flatten inside of here means that we won't actually get autograd support. Using at::flatten (or reshape as pointed out above) fixes that problem. You can test this locally with something like:
x = torch.randn(3, 3, requires_grad=True)
out = x.ravel().sum()
out.backward() # I think this throws an error with the current implementation
You are right. It doesn't really need to go into flatten.
But it's interesting that it doesn't throw any error and the gradient for the tensor is correct when I used native::flatten.
Oh, I see what happened now. native::flatten calls at::reshape, and at::reshape is supported by autograd.
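For reference, a minimal end-to-end check of the autograd path (a sketch, assuming the ravel method from this PR is available) could look like:

import torch

# Gradient of ravel().sum() w.r.t. x should be all ones with x's shape.
x = torch.randn(3, 3, requires_grad=True)
out = x.ravel().sum()
out.backward()
assert x.grad.shape == x.shape
assert torch.equal(x.grad, torch.ones_like(x))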
| r""" | ||
| ravel(input) -> Tensor | ||
|
|
||
| Return a contiguous flattened tensor. |
It's important to note here that "A copy is made only if needed." That is part of the numpy semantics for ravel (https://numpy.org/doc/stable/reference/generated/numpy.ravel.html) and is also what the current implementation (via input.flatten()) does.
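A quick way to see the "copy only if needed" behavior (a sketch, assuming the ravel method from this PR is available; the data_ptr comparison is just an illustrative check for storage sharing):

import torch

x = torch.arange(6).reshape(2, 3)              # contiguous
assert x.ravel().data_ptr() == x.data_ptr()    # view, no copy made

y = x.t()                                      # non-contiguous
assert y.ravel().data_ptr() != y.data_ptr()    # a copy is made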
zou3519 left a comment
This generally looks good to me. We should add some more test cases (a rough sketch of both follows the list):
- we should add an autograd test to make sure the operator works with autograd
- we should add some more cases to the test_torch test to check the behavior of ravel() on a non-contiguous Tensor
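The following is only a sketch of what such tests could check, not the actual test_torch/test_autograd hooks:

import torch
from torch.autograd import gradcheck

# Autograd test: gradcheck on a small double-precision input.
x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
assert gradcheck(lambda t: t.ravel(), (x,))

# Non-contiguous input: ravel should return the flattened values in the
# logical (row-major) order of the non-contiguous view.
y = torch.randn(3, 4).t()
assert not y.is_contiguous()
assert torch.equal(y.ravel(), y.contiguous().view(-1))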
Doc:  [ghstack-poisoned]
Doc:  [ghstack-poisoned]
self.assertEqual(flat0.shape, flat1.shape)

# Test both float tensor and quantized tensor
tensors = [torch.randn(5, 5),
nit: We should probably test more shapes. The implementation code for ravel is simple enough that I can believe it works for everything, but there's no harm in testing more
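A possible sketch of covering a few more shapes, including 0-d and zero-sized tensors (hypothetical test body, not the actual test_torch code):

import torch

for shape in [(), (0,), (5,), (2, 3), (2, 0, 3), (2, 3, 4, 5)]:
    t = torch.randn(shape)
    flat = t.ravel()
    assert flat.dim() == 1                      # always returns a 1-d tensor
    assert flat.numel() == t.numel()            # same number of elements
    assert torch.equal(flat, t.reshape(-1))     # matches reshape(-1)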
One edge case we should test is a tensor with a zero-sized dimension. For example, the following should be true:
x = torch.randn([0, 2, 3]) # tensor of size [0, 2, 3]
y = x.ravel()
assert y.shape == (0,)
zou3519 left a comment
LGTM. Some minor nits on testing.
Also, we should try to test a tensor with a zero size in some dimension. I think the implementation should support it.
Doc:  [ghstack-poisoned]
Codecov Report
@@            Coverage Diff             @@
##    gh/ejguan/2/base    #46098   +/-   ##
===========================================
  Coverage       68.20%    68.20%
===========================================
  Files             410       410
  Lines           53453     53455       +2
===========================================
+ Hits            36458     36461       +3
+ Misses          16995     16994       -1

Continue to review full report at Codecov.
Summary: Pull Request resolved: pytorch#46098
Doc: ![image](https://user-images.githubusercontent.com/68879799/95793183-0fd24800-0cb2-11eb-9b14-b2e5dbb10b5f.png)
Test Plan: Imported from OSS
Reviewed By: glaringlee
Differential Revision: D24253213
Pulled By: ejguan
fbshipit-source-id: 42a866c902272cbe3743a9d0cb3afb9165d51c0b

Related #38349
Stack from ghstack:
Doc:

Differential Revision: D24253213