
Implement ravel #46098

Closed
ejguan wants to merge 5 commits into gh/ejguan/2/base from gh/ejguan/2/head

Conversation

@ejguan
Contributor

@ejguan ejguan commented Oct 9, 2020

Related #38349
Stack from ghstack:

Doc:
[screenshot: rendered torch.ravel documentation]

Differential Revision: D24253213

[ghstack-poisoned]
@ejguan ejguan marked this pull request as draft October 9, 2020 17:02
@ejguan ejguan added the module: numpy Related to numpy support, and also numpy compatibility of our operators label Oct 9, 2020
@ejguan
Contributor Author

ejguan commented Oct 9, 2020

TODO:

  • Add tests.
python test/test_torch.py AbstractTestCases._TestTorchMixin.test_ravel

[screenshot: test run output]
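
A minimal sketch of what such a test could check (an illustration, not the exact test added in this PR): the raveled result should be one-dimensional and match flatten():

```python
import torch

def test_ravel_matches_flatten():
    x = torch.randn(2, 3, 4)
    y = x.ravel()
    # ravel always produces a 1-D tensor with the same elements
    assert y.shape == (24,)
    assert torch.equal(y, x.flatten())

test_ravel_matches_flatten()
```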

@dr-ci

dr-ci Bot commented Oct 9, 2020

💊 CI failures summary and remediations

As of commit dcf8140 (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions on the GitHub issue tracker or post in the (internal) Dr. CI Users group.

See how this bot performed.

This comment has been revised 18 times.

@ejguan ejguan marked this pull request as ready for review October 9, 2020 18:21
@ejguan ejguan requested a review from zou3519 October 9, 2020 18:24
Comment thread aten/src/ATen/native/TensorShape.cpp Outdated
Comment on lines +1628 to +1629
// Same as flatten() with start_dim as 0 and end_dim as -1.
return native::flatten(self, 0, -1);
Contributor

@zou3519 zou3519 Oct 9, 2020

nit: self.reshape(-1) avoids some of the computation that goes on inside flatten.

Calling native::flatten instead of at::flatten inside of here means that we won't actually get autograd support. Using at::flatten (or reshape as pointed out above) fixes that problem. You can test this locally with something like:

x = torch.randn(3, 3, requires_grad=True)
out = x.ravel().sum()
out.backward()  # I think this throws an error with the current implementation
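
For reference, a fuller version of that check (assuming ravel is exposed on Tensor, as this PR does): since the sum of all elements has gradient one with respect to every input entry, the expected x.grad is a tensor of ones.

```python
import torch

x = torch.randn(3, 3, requires_grad=True)
out = x.ravel().sum()
out.backward()
# d(sum)/dx_ij = 1 for every entry, so the gradient is a 3x3 tensor of ones
assert torch.equal(x.grad, torch.ones(3, 3))
```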

Contributor Author

You are right. It doesn't really need to go through flatten.
But it's interesting that it didn't throw any error, and the gradient for the tensor was correct, when I used native::flatten.

Contributor

Oh, I see what happened now. native::flatten calls at::reshape, and at::reshape is supported by autograd.

Comment thread test/test_torch.py
Comment thread test/test_torch.py
Comment thread test/test_torch.py Outdated
Comment thread torch/_torch_docs.py Outdated
r"""
ravel(input) -> Tensor

Return a contiguous flattened tensor.
Contributor

@zou3519 zou3519 Oct 9, 2020

It's important to note here that "A copy is made only if needed." That is part of the NumPy semantics for ravel (https://numpy.org/doc/stable/reference/generated/numpy.ravel.html) and is also what the current implementation (via input.flatten()) does.
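
NumPy's "copy only if needed" behavior can be observed directly: for a C-contiguous array ravel returns a view, while a transposed (non-contiguous) array forces a copy.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)  # C-contiguous: ravel can return a view
v = a.ravel()
v[0] = 100
assert a[0, 0] == 100           # write through the view reaches the original

b = a.T                         # transpose: non-contiguous, ravel must copy
c = b.ravel()
c[0] = -1
assert b[0, 0] != -1            # the original is unaffected by writes to the copy
```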

Contributor

@zou3519 zou3519 left a comment

This generally looks good to me. We should add some more test cases:

  • we should add an autograd test to make sure the operator works with autograd
  • we should add some more cases to the test_torch test to check the behavior of ravel() on a non-contiguous Tensor
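
A sketch of the non-contiguous case (an illustration, not the exact test added in the PR): transposing breaks contiguity, so ravel has to copy, and writes to the result must not alias the input.

```python
import torch

x = torch.randn(4, 5).t()       # transpose makes the tensor non-contiguous
assert not x.is_contiguous()

y = x.ravel()
assert y.is_contiguous()
assert y.shape == (20,)

# ravel copied here, so mutating y must leave x untouched
original = x[0, 0].item()
y[0] = y[0] + 1.0
assert x[0, 0].item() == original
```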

Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment thread test/test_torch.py Outdated
Comment thread test/test_torch.py Outdated
ejguan added a commit that referenced this pull request Oct 9, 2020
ghstack-source-id: ecd94c9
Pull Request resolved: #46098
Comment thread test/test_torch.py
self.assertEqual(flat0.shape, flat1.shape)

# Test both float tensor and quantized tensor
tensors = [torch.randn(5, 5),
Contributor

nit: We should probably test more shapes. The implementation code for ravel is simple enough that I can believe it works for everything, but there's no harm in testing more

Comment thread test/test_torch.py
self.assertEqual(flat0.shape, flat1.shape)

# Test both float tensor and quantized tensor
tensors = [torch.randn(5, 5),
Contributor

One edge case we should test is a tensor with a zero-sized dimension. For example, the following should hold:

x = torch.randn([0, 2, 3])  # tensor of size [0, 2, 3]
y = x.ravel()
assert y.shape == (0,)

Contributor

@zou3519 zou3519 left a comment

LGTM. Some minor nits on testing.

Also, we should try to test a tensor with a zero size in some dimension. I think the implementation should support it.

ejguan added a commit that referenced this pull request Oct 10, 2020
ghstack-source-id: 6bc7834
Pull Request resolved: #46098
@codecov

codecov Bot commented Oct 11, 2020

Codecov Report

Merging #46098 into gh/ejguan/2/base will increase coverage by 0.00%.
The diff coverage is 100.00%.

Impacted file tree graph

@@                Coverage Diff                @@
##           gh/ejguan/2/base   #46098   +/-   ##
=================================================
  Coverage             68.20%   68.20%           
=================================================
  Files                   410      410           
  Lines                 53453    53455    +2     
=================================================
+ Hits                  36458    36461    +3     
+ Misses                16995    16994    -1     
Impacted Files Coverage Δ
torch/overrides.py 97.08% <ø> (ø)
...ch/testing/_internal/common_methods_invocations.py 91.45% <ø> (ø)
torch/_tensor_docs.py 100.00% <100.00%> (ø)
torch/_torch_docs.py 100.00% <100.00%> (ø)
torch/tensor.py 89.08% <0.00%> (+0.21%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 4c87d33...dcf8140. Read the comment docs.

@facebook-github-bot
Contributor

@ejguan merged this pull request in bed3b40.

@facebook-github-bot facebook-github-bot deleted the gh/ejguan/2/head branch October 16, 2020 14:21
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#46098

Doc:
![image](https://user-images.githubusercontent.com/68879799/95611323-ae5cf380-0a2f-11eb-9b8e-56bf79ce68af.png)

Test Plan: Imported from OSS

Reviewed By: glaringlee

Differential Revision: D24253213

Pulled By: ejguan

fbshipit-source-id: 42a866c902272cbe3743a9d0cb3afb9165d51c0b
Labels

Merged module: numpy Related to numpy support, and also numpy compatibility of our operators

3 participants