
Port CPU torch.ormqr to ATen #57315

Closed
IvanYashchuk wants to merge 4 commits into gh/ivanyashchuk/27/base from gh/ivanyashchuk/27/head

Conversation

@IvanYashchuk
Collaborator

IvanYashchuk commented Apr 29, 2021

Stack from ghstack:

This PR ports `torch.ormqr` from TH to ATen.
The CUDA path will be implemented in a follow-up PR.
With the ATen port, support for complex and batched inputs is added.
The tests are rewritten and an OpInfo entry is added.

The least squares solver can be implemented with geqrf + ormqr + triangular_solve, so it's useful to have this function modernized, at least for the internal code.

Resolves #24748

Differential Revision: D28242070
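
The point about composing a least squares solver can be made concrete. Below is a minimal editorial sketch (not code from this PR) of solving min ||Ax - b|| for a tall matrix A via geqrf + ormqr + triangular_solve; the shapes and variable names are chosen purely for illustration:

import torch

# Solve min ||A x - b|| for A of shape (m, n), m >= n,
# without ever materializing Q.
m, n, k = 5, 3, 2
A = torch.randn(m, n)
b = torch.randn(m, k)

h, tau = torch.geqrf(A)                       # Householder form of A = QR
qtb = torch.ormqr(h, tau, b, transpose=True)  # Q^T b, applied implicitly
R = h[:n, :n].triu()                          # upper-triangular factor R
x = torch.triangular_solve(qtb[:n], R).solution

# Cross-check against the pseudoinverse solution
print(torch.allclose(x, torch.pinverse(A) @ b, atol=1e-5))  # True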

@facebook-github-bot
Contributor

facebook-github-bot commented Apr 29, 2021

💊 CI failures summary and remediations

As of commit c9a690d (more details on the Dr. CI page):


  • 2/2 failures introduced in this PR

2 failures not recognized by patterns:

Job                       | Step    | Action
GitHub Actions flake8-py3 | Unknown | 🔁 rerun
GitHub Actions clang-tidy | Unknown | 🔁 rerun

This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request Apr 29, 2021
ghstack-source-id: fdc9cbd
Pull Request resolved: pytorch#57315
IvanYashchuk added the module: linear algebra and module: porting labels Apr 29, 2021
IvanYashchuk requested review from mruberry and removed the request for ezyang April 29, 2021 22:24
@mruberry
Collaborator

mruberry left a comment
Awesome! @lezcano would you like to sanity check this, too?

mruberry requested a review from lezcano May 1, 2021 23:36
@lezcano
Collaborator

lezcano left a comment
Left a few comments, none of them too important.

other_matrix_shape = (m, n) if left else (n, m)
other = make_tensor((*batch, *other_matrix_shape), device, dtype, requires_grad=requires_grad)
kwargs = {"left": left, "transpose": transpose}
sample_inputs.append(SampleInput(reflectors, args=(tau, other,), kwargs=kwargs))
Collaborator

Prefer writing it as a generator, so that when OpInfos accept generators it's easier to port. See:

def gen_inputs():
    # Generic inputs
    tgt_gen = (make_arg((S, S), noncontiguous=not ctg) for ctg in (True, False))
    src_gen = (make_arg((S,), noncontiguous=not ctg) for ctg in (True, False))
    idx = torch.randperm(S * S, device=device, dtype=torch.int64)[:S]
    idx_nonctg = torch.repeat_interleave(idx, 2, dim=-1)[::2]
    idx_neg = -idx - 1
    idx_list = [idx, idx_nonctg, idx_neg]
    for tgt, idx, src, acc in product(tgt_gen, idx_list, src_gen, (True, False)):
        yield SampleInput(input=tgt, args=(idx, src, acc))
    # Scalar cases
    scalar_sizes = [(), (1,)]
    tgt_gen = (make_arg(size) for size in scalar_sizes)
    idx_gen = (make_idx(size, high=1) for size in scalar_sizes)
    src_gen = (make_arg(size) for size in scalar_sizes)
    for tgt, idx, src, acc in product(tgt_gen, idx_gen, src_gen, (True, False)):
        yield SampleInput(input=tgt, args=(idx, src, acc))
    # Empty cases
    tgt_sizes = [(0,), (), (1,), (3, 2)]
    tgt_gen = (make_arg(size) for size in tgt_sizes)
    idx = make_idx((0,), high=1)
    src = make_arg((0,))
    for tgt, acc in product(tgt_gen, (True, False)):
        yield SampleInput(input=tgt, args=(idx, src, acc))

return list(gen_inputs())

Comment thread torch/_torch_docs.py
add_docstr(torch.ormqr,
r"""
-ormqr(input, input2, input3, left=True, transpose=False) -> Tensor
+ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor
Collaborator

Thank you for correcting and improving the docs!!
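
For readers skimming the thread, a minimal call under the corrected signature looks like the sketch below (an editorial illustration; the shapes and variable names are arbitrary):

import torch

a = torch.randn(4, 4)
h, tau = torch.geqrf(a)            # "input" (the reflectors) and "tau"
other = torch.randn(4, 2)
res = torch.ormqr(h, tau, other)                      # Q @ other
res_t = torch.ormqr(h, tau, other, transpose=True)    # Q^T @ other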


return samples

def sample_inputs_ormqr(op_info, device, dtype, requires_grad):
Collaborator

Consider defining a helper function of the form:
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
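
As a hypothetical illustration of this suggestion (assuming device, dtype, and requires_grad are in scope inside sample_inputs_ormqr, and that make_tensor is importable from torch.testing in this version), the helper shortens each call site:

from functools import partial
from torch.testing import make_tensor

make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)

# Before: make_tensor((*batch, *other_matrix_shape), device, dtype, requires_grad=requires_grad)
# After:
other = make_arg((*batch, *other_matrix_shape))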

Comment thread torch/testing/_internal/common_methods_invocations.py Outdated
Comment thread torch/_torch_docs.py Outdated

.. seealso::

:func:`torch.geqrf` can be used to form the Householder representation of matrix `Q`
Collaborator

"the Householder representation of matrix Q" -> "a Householder representation (input, tau) of the matrix Q

Comment thread test/test_linalg.py
TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions.");

int64_t left_size_condition = left ? -2 : -1;
TORCH_CHECK(
Collaborator

missing input.size(-2) >= input.size(-1)?

Collaborator Author

No; as the passing tests show, this function works for both m >= n and m < n matrices.
In the m < n case, LAPACK uses only the first m columns, which represent the m Householder vectors.

In [1]: import torch
In [2]: a = torch.randn(3, 5)
In [3]: h, tau = torch.geqrf(a)
In [4]: h.shape
Out[4]: torch.Size([3, 5])
In [5]: c = torch.randn(3, 7)
In [6]: res = torch.ormqr(h, tau, c)
In [7]: q, _ = torch.linalg.qr(a)
In [8]: torch.allclose(q @ c, res)
Out[8]: True

Collaborator

Then we should add the m < n case to the docs.

Collaborator Author

Why do you think the case m < n is special?

Collaborator

Because in householder_product it is not even considered:
https://pytorch.org/docs/master/generated/torch.linalg.householder_product.html

Collaborator Author

Well, householder_product is related but it's a different function. It has the same constraints on the input as the original orgqr implementation had:

THArgCheck(m >= n, 1, "input.size(0) must be greater than or equal to input.size(1)");

and LAPACK's orgqr has this m >= n requirement, but n there is the number of columns of Q to be computed, not n = input.shape[-1]. A side effect is that the output of torch.geqrf can't be used directly with torch.orgqr, but it can be used with ormqr in the current implementation.

In [1]: import torch
In [2]: a = torch.randn(3, 5)
In [3]: h, tau = torch.geqrf(a)
In [4]: c = torch.eye(3)
In [5]: torch.ormqr(h, tau, c) # narrow of `h` is not required
Out[5]: 
tensor([[-0.1468,  0.8266, -0.5433],
        [-0.0029, -0.5496, -0.8354],
        [-0.9892, -0.1211,  0.0831]])
In [6]: torch.linalg.qr(a)[0]
Out[6]: 
tensor([[-0.1468,  0.8266, -0.5433],
        [-0.0029, -0.5496, -0.8354],
        [-0.9892, -0.1211,  0.0831]])
In [7]: torch.linalg.householder_product(h, tau)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-7-70a58677659d> in <module>
----> 1 torch.linalg.householder_product(h, tau)

RuntimeError: torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]
In [8]: torch.linalg.householder_product(h.narrow(-1, 0, 3), tau) # narrow is required here
Out[8]: 
tensor([[-0.1468,  0.8266, -0.5433],
        [-0.0029, -0.5496, -0.8354],
        [-0.9892, -0.1211,  0.0831]])

Collaborator

I do not have any strong opinions on how to handle this, as I really think that these functions are too low level to form part of the Python API.

That being said, it's quite annoying that we have an h, tau decomposition in several functions, and that each of them has slightly different requirements.

Collaborator

What's more, the seealso section of householder_product should be updated to reflect this dissonance.

Collaborator Author

In every place where we multiply Q by some other matrix, using this function should be more efficient. In PyTorch's Python code there is, for example, one place it could be used: in the implementation of lobpcg, the Q matrix is explicitly formed and then used in a multiplication:

P = mm(S_, mm(Z[:, n - nc:], _utils.basis(_utils.transpose(Z[:n - nc, n - nc:]))))

def basis(A):
    """Return orthogonal basis of A columns.
    """
    if A.is_cuda:
        # torch.orgqr is not available in CUDA
        Q, _ = torch.qr(A, some=True)
    else:
        Q = torch.orgqr(*torch.geqrf(A))
    return Q

To make this function more accessible, the documentation should be improved, the name should be changed, and autograd support should be added. If we do that, we can revisit the constraints on the input and make them consistent.
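
To make the efficiency point concrete, here is a minimal editorial sketch (not from this PR) of replacing an explicit Q with an implicit application via ormqr:

import torch

A = torch.randn(5, 5)
X = torch.randn(5, 3)
h, tau = torch.geqrf(A)

# Explicit: materialize Q, then multiply.
Q, _ = torch.linalg.qr(A)
y_explicit = Q @ X

# Implicit: apply Q from its Householder form; Q is never formed.
y_implicit = torch.ormqr(h, tau, X)

print(torch.allclose(y_explicit, y_implicit, atol=1e-6))  # True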

Collaborator

That's spot on. Let's do that in a follow-up after the branch-cut.

Comment thread aten/src/ATen/native/BatchLinearAlgebra.cpp
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request May 4, 2021
ghstack-source-id: 4bf58ae
Pull Request resolved: pytorch#57315
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request May 4, 2021
ghstack-source-id: e9cd751
Pull Request resolved: pytorch#57315
@IvanYashchuk
Collaborator Author

@lezcano, thank you for your feedback! I've updated this PR according to your suggestions, expanded the documentation on input sizes a bit, and added a "Raises:" section. Could you please take a look?

@lezcano
Collaborator

lezcano left a comment

I just left a couple of comments, purely stylistic points.
In any case, this PR is ready to be merged.

Comment thread torch/_torch_docs.py
IvanYashchuk added a commit to IvanYashchuk/pytorch that referenced this pull request May 4, 2021
ghstack-source-id: 6df2e90
Pull Request resolved: pytorch#57315
@IvanYashchuk
Collaborator Author

@mruberry, I think this stack is ready to be merged. Could you please take another look?

I also moved the "cuBLAS path for lstsq" PR to this stack, after the cuSOLVER PR.

@mruberry
Collaborator

mruberry commented May 6, 2021

@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

ngimel mentioned this pull request May 6, 2021

@facebook-github-bot
Contributor

@mruberry merged this pull request in 59d794b.

mrshenli pushed a commit to mrshenli/pytorch that referenced this pull request May 8, 2021
Summary:
Pull Request resolved: pytorch#57315


Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D28242070

Pulled By: mruberry

fbshipit-source-id: f070bb6ac2f5a3269b163b22f7354e9089ed3061
facebook-github-bot deleted the gh/ivanyashchuk/27/head branch May 9, 2021 14:17
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026

Labels

cla signed, Merged, module: linear algebra, module: porting, open source

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants