Rewrite kron with broadcasting at::mul #50927

Closed

zasdfgbnm wants to merge 3 commits into master from kron

Conversation

@zasdfgbnm (Collaborator) commented Jan 22, 2021

Because it is shorter, faster, and does not have the TF32 issue.

Benchmark: https://github.com/zasdfgbnm/things/blob/master/2021Q1/kron.ipynb
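For context, a minimal Python sketch of the broadcasting trick the C++ change implements (the helper name and the code are illustrative, not the PR's actual code). The TF32 accuracy concern applies to matmul-backed kernels, which a plain broadcasting mul avoids:

import torch

def kron_via_mul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Pad the lower-rank input with leading 1s, then interleave dimensions:
    # a becomes (a0, 1, a1, 1, ...), b becomes (1, b0, 1, b1, ...).
    maxdim = max(a.dim(), b.dim())
    pad_a, pad_b = maxdim - a.dim(), maxdim - b.dim()
    a_shape = [1] * (2 * maxdim)
    b_shape = [1] * (2 * maxdim)
    out_shape = []
    for i in range(maxdim):
        a_i = a.size(i - pad_a) if i >= pad_a else 1
        b_i = b.size(i - pad_b) if i >= pad_b else 1
        a_shape[2 * i] = a_i
        b_shape[2 * i + 1] = b_i
        out_shape.append(a_i * b_i)
    # The broadcasting mul materializes the outer product of each dimension
    # pair; collapsing each (a_i, b_i) pair yields the Kronecker layout.
    return (a.reshape(a_shape) * b.reshape(b_shape)).reshape(out_shape)

a, b = torch.randn(3, 4), torch.randn(5, 6)
assert torch.allclose(kron_via_mul(a, b), torch.kron(a, b))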

@facebook-github-bot (Contributor) commented Jan 22, 2021

💊 CI failures summary and remediations

As of commit f1aa841 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)


@zasdfgbnm changed the title from "Rewrite kron with elementwise op" to "Rewrite kron with broadcasting at::mul" Jan 22, 2021
@codecov (bot) commented Jan 22, 2021

Codecov Report

Merging #50927 (f1aa841) into master (8e9ed27) will decrease coverage by 0.01%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #50927      +/-   ##
==========================================
- Coverage   81.00%   80.99%   -0.01%     
==========================================
  Files        1916     1916              
  Lines      209481   209473       -8     
==========================================
- Hits       169690   169671      -19     
- Misses      39791    39802      +11     

@IvanYashchuk (Collaborator) left a comment:

Tests are passing and it is faster now, so I'm happy with these changes.
Thank you!

Comment thread on aten/src/ATen/native/LinearAlgebra.cpp (outdated):
  return result;

// functional variant: delegates to the out= overload with an undefined result tensor
Tensor kron(const Tensor& self, const Tensor& other) {
  at::Tensor result;
  return at::kron_out(result, self, other);
Collaborator:

we silently enabled type promotion for kron :-)
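For illustration, this is what routing kron through at::mul implies (a minimal sketch; the result dtype follows mul's standard type promotion rules):

>>> import torch
>>> a = torch.ones(2, 2, dtype=torch.float32)
>>> b = torch.ones(2, 2, dtype=torch.float64)
>>> torch.kron(a, b).dtype  # promoted by the underlying broadcasting mul
torch.float64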

Comment thread on aten/src/ATen/native/LinearAlgebra.cpp (outdated):
  // pad the lower-rank input with leading singleton dimensions
  auto maxdim = std::max(self.dim(), other.dim());
  auto pad_self = maxdim - self.dim();
  auto pad_other = maxdim - other.dim();
  DimVector a_reshape(2 * maxdim);
Collaborator:

This will be worse than std::vector for 3+d tensors: DimVector is a SmallVector with a 5-element inline buffer, so 2 * maxdim >= 6 forces it to grow onto the heap and copy the original inline elements.

@zasdfgbnm (Author):

Changed to c10::SmallVector<int64_t, 10>.

@facebook-github-bot (Contributor) left a comment:

@ngimel has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator) commented Jan 22, 2021

If you have them handy, can you post final benchmarking numbers for this implementation instead of the Python one?

@zasdfgbnm (Author) commented Jan 22, 2021

@ngimel
CPU: 64.4 µs
CUDA: 25.1 µs

@facebook-github-bot (Contributor):

@ngimel merged this pull request in ab331da.

  // interleave: even slots take self's (padded) sizes, odd slots take other's
  a_reshape[2 * i] = i >= pad_self ? self.sizes()[i - pad_self] : 1;
  a_reshape[2 * i + 1] = 1;
  b_reshape[2 * i] = 1;
  b_reshape[2 * i + 1] = i >= pad_other ? other.sizes()[i - pad_other] : 1;
@mruberry (Collaborator) commented Jan 25, 2021:

style nit: parentheses around the condition after the equals sign, for readability, i.e. a_reshape[2 * i] = (i >= pad_self) ? self.sizes()[i - pad_self] : 1;

c10::SmallVector<int64_t, 10> a_reshape(2 * maxdim);
c10::SmallVector<int64_t, 10> b_reshape(2 * maxdim);
c10::SmallVector<int64_t, 10> result_reshape(maxdim);
for (int i = 0; i < maxdim; i++) {
Collaborator:

nit: declare i with maxdim's type (e.g., decltype(maxdim)) to stay consistent with the auto usage above.

    // functional path: view the mul result directly into the kron shape
    result = at::_unsafe_view(at::mul(self_view, other_view), result_reshape);
  } else {
    // out= path: write into the caller's tensor, then fix up its shape
    at::mul_out(result, self_view, other_view);
    result.resize_(result_reshape);
Collaborator:

Does this need to use resize_output? Actually, in this case things seem trickier than that.

What kind of tensor can a user pass to out= here without this complaining? If a user passes a tensor with shape result_reshape, can mul_out yell that it received a tensor of the incorrect size? And if a user passes a tensor of the shape that mul_out likes, can it be resized without warning? That'd be bad if out= were a view, for example.
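For reference, the kind of complaining in question (a sketch; the exact warning text may vary by version):

>>> import torch
>>> out = torch.empty(3)                 # wrong size for the (2, 2) result
>>> torch.mul(torch.ones(2, 2), 2., out=out).shape
UserWarning: An output with one or more elements was resized since it had shape [3], which does not match the required output shape [2, 2]. ...
torch.Size([2, 2])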

@zasdfgbnm (Author):

Is passing a view as out= defined behavior?

Collaborator:

Yes. Typically it uses resize_ on out, so if out is the correct size it is used directly and the view relationship persists; if it is not, anything could happen, so in that sense it's not defined. Here, though, preserving this behavior would require first re-viewing out with the expected size of mul if possible (for reductions with similar problems I inserted unsqueezes and didn't care about the actual sizes; if they are wrong, resize_ handles it), and then re-viewing it back.
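A minimal sketch of the "correct size, so the view relationship persists" case (assumes the standard out= resize_ behavior described above):

>>> import torch
>>> base = torch.zeros(8)
>>> out = base[:4]                      # out is a view into base
>>> torch.add(torch.ones(4), torch.ones(4), out=out)
tensor([2., 2., 2., 2.])
>>> out.data_ptr() == base.data_ptr()   # no reallocation happened
True
>>> base
tensor([2., 2., 2., 2., 0., 0., 0., 0.])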

@zasdfgbnm (Author) commented Jan 25, 2021:

Looks like the current approach preserves storage, but erases discontiguity:

>>> import torch

>>> a = torch.randn(4,6,8,10).contiguous(memory_format=torch.channels_last)
>>> a.stride()
(480, 1, 60, 6)
>>> a.view(2,2,2,3,2,4,2,5).stride()
(960, 480, 3, 1, 240, 60, 30, 6)
>>> a.resize_(2,2,2,3,2,4,2,5).stride()
(960, 480, 240, 80, 40, 10, 5, 1)

>>> a = torch.randn(4,6,8,10).contiguous(memory_format=torch.channels_last)
>>> a.view(2,2,2,3,2,4,2,5).resize_(2,2,2,3,2,4,2,5).stride()
(960, 480, 3, 1, 240, 60, 30, 6)

>>> a = torch.randn(4,6,8,10).contiguous(memory_format=torch.channels_last)
>>> a.storage().data_ptr()
94728832336256
>>> a.resize_(2,2,2,3,2,4,2,5).storage().data_ptr()
94728832336256

@zasdfgbnm (Author):

>>> a = torch.randn(2,3,4,5).contiguous(memory_format=torch.channels_last)
>>> b = torch.kron(a, a)
>>> b.stride()
(3600, 1, 225, 9)
>>> torch.kron(a, a, out=b).stride()
(3600, 400, 25, 1)

@zasdfgbnm (Author) commented Jan 25, 2021:

Should we do

result.resize_(result_reshape, self.suggest_memory_format());

instead of resize_output?

@mruberry (Collaborator) commented Jan 25, 2021:

Sorry, why? We want to be consistent with PyTorch, where the out= semantics are that the operation is performed and then its output is "safe copied" to the tensor passed to out=. If the tensor has to be altered, a warning is thrown, unless the tensor had no elements.

Collaborator:

No, we should preserve any permutation, as we automatically do now, not just channels-last.

For a warning, we can do result.view(...with inserted 1 dimensions...) and rely on mul_out to throw a warning if the sizes are wrong, or, as Xiang suggested a few comments up, use resize_output there instead of resize_; it doesn't really matter.
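Roughly like this (a hedged Python sketch of the re-viewing idea; shapes are illustrative):

>>> import torch
>>> m, n, p, q = 2, 3, 4, 5
>>> A, B = torch.randn(m, n), torch.randn(p, q)
>>> out = torch.empty(m * p, n * q)
>>> # Re-view out with the interleaved structure mul produces, so mul_out
>>> # writes straight into the caller's storage and out stays a plain view.
>>> out_view = out.view(m, p, n, q)
>>> torch.mul(A.reshape(m, 1, n, 1), B.reshape(1, p, 1, q), out=out_view).shape
torch.Size([2, 4, 3, 5])
>>> torch.allclose(out, torch.kron(A, B))
True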

@zasdfgbnm (Author) commented Jan 25, 2021:

Or maybe add an argument c10::optional<MemoryFormat> optional_memory_format to resize_output, and inside resize_output, check for both shape and memory format.

Collaborator:

Please no; using the memory_format argument is wrong in most cases outside of torch.nn.

  result.copy_(result_tmp);
  return result;

Tensor kron(const Tensor& self, const Tensor& other) {
  at::Tensor result;
Collaborator:

We should use this pattern more often.

Collaborator:

No, Sebastian wants to disallow it.

@mruberry (Collaborator):

@zasdfgbnm Would you check a broader set of tensor sizes in your benchmark, too? For example, what are the smallest vs. the biggest tensors we can reasonably benchmark? Are odd shapes an issue?
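For example, a size sweep along these lines (a sketch using torch.utils.benchmark; the sizes are illustrative):

import torch
from torch.utils import benchmark

for n in (4, 16, 64, 256):           # small through reasonably big
    a = torch.randn(n, n)
    t = benchmark.Timer("torch.kron(a, a)", globals={"a": a})
    print(n, t.timeit(100).median)   # median seconds per call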

@zasdfgbnm mentioned this pull request Jan 25, 2021
@zasdfgbnm (Author):

Let's move the discussion to #51045.

facebook-github-bot pushed a commit that referenced this pull request Jan 27, 2021
Summary:
Followup of #50927

Pull Request resolved: #51045

Reviewed By: mruberry

Differential Revision: D26089204

Pulled By: ngimel

fbshipit-source-id: 77291dd83fba32d6f80a8540910b112a1d85a892