Port CPU torch.ormqr to ATen #57315
IvanYashchuk wants to merge 4 commits into gh/ivanyashchuk/27/base
Conversation
This PR ports `torch.ormqr` from TH to ATen. The CUDA path will be implemented in a follow-up PR. With the ATen port, support for complex and batched inputs is added, the tests are rewritten, and an OpInfo entry is added. We can implement the least squares solver with geqrf + ormqr + triangular_solve, so it is useful to have this function renewed, at least for internal code. Resolves #24748
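As a rough illustration of that last point (a sketch, not code from this PR; the shapes and float64 dtype are arbitrary choices), a least-squares solve for a tall full-rank matrix can be assembled from exactly these three primitives:

```python
import torch

# Sketch: solve min ||A @ x - b|| via geqrf + ormqr + triangular_solve.
m, n = 5, 3
A = torch.randn(m, n, dtype=torch.float64)
b = torch.randn(m, 2, dtype=torch.float64)

h, tau = torch.geqrf(A)                       # A = Q R, with Q kept in implicit form
qtb = torch.ormqr(h, tau, b, transpose=True)  # Q^T @ b, without materializing Q
R = h[:n, :].triu()                           # R is the upper triangle of h
x = torch.triangular_solve(qtb[:n], R).solution
assert torch.allclose(A.t() @ (A @ x), A.t() @ b)  # the normal equations hold
```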
💊 CI failures summary and remediations: as of commit c9a690d, 2 failures were not recognized by known patterns (more details on the Dr. CI page).
lezcano
left a comment
Left a few comments, none of them too important.
other_matrix_shape = (m, n) if left else (n, m)
other = make_tensor((*batch, *other_matrix_shape), device, dtype, requires_grad=requires_grad)
kwargs = {"left": left, "transpose": transpose}
sample_inputs.append(SampleInput(reflectors, args=(tau, other,), kwargs=kwargs))
Prefer writing it as a generator, so that when OpInfos accept generators it's easier to port. See pytorch/torch/testing/_internal/common_methods_invocations.py, lines 1626 to 1653 at 51fc406.
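Something like this, perhaps (a sketch only; it assumes make_tensor and SampleInput are in scope, as they are in that file):

```python
from itertools import product

def sample_inputs_ormqr(op_info, device, dtype, requires_grad):
    # Generator form: yield each SampleInput instead of appending to a list.
    for batch, (m, n) in product(((), (2,)), ((4, 4), (5, 3))):
        for left, transpose in product((True, False), repeat=2):
            reflectors = make_tensor((*batch, m, n), device, dtype,
                                     requires_grad=requires_grad)
            tau = make_tensor((*batch, min(m, n)), device, dtype,
                              requires_grad=requires_grad)
            other_shape = (m, n) if left else (n, m)
            other = make_tensor((*batch, *other_shape), device, dtype,
                                requires_grad=requires_grad)
            yield SampleInput(reflectors, args=(tau, other),
                              kwargs={"left": left, "transpose": transpose})
```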
add_docstr(torch.ormqr,
           r"""
-ormqr(input, input2, input3, left=True, transpose=False) -> Tensor
+ormqr(input, tau, other, left=True, transpose=False, *, out=None) -> Tensor
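For reference, the flags in the new signature behave as follows (a small illustrative check with real inputs; not part of the docs diff itself):

```python
import torch

a = torch.randn(4, 4, dtype=torch.float64)
c = torch.randn(4, 4, dtype=torch.float64)
h, tau = torch.geqrf(a)
q, _ = torch.linalg.qr(a)  # explicit Q, for comparison only

assert torch.allclose(torch.ormqr(h, tau, c), q @ c)                      # Q @ c
assert torch.allclose(torch.ormqr(h, tau, c, left=False), c @ q)          # c @ Q
assert torch.allclose(torch.ormqr(h, tau, c, transpose=True), q.t() @ c)  # Q^T @ c
```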
Thank you for correcting and improving the docs!!
    return samples

def sample_inputs_ormqr(op_info, device, dtype, requires_grad):
Consider defining a helper function of the form:
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
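i.e., roughly (a sketch; make_tensor and SampleInput come from the same test utilities):

```python
from functools import partial

def sample_inputs_ormqr(op_info, device, dtype, requires_grad):
    # Bind the arguments that never change across samples.
    make_arg = partial(make_tensor, dtype=dtype, device=device,
                       requires_grad=requires_grad)
    reflectors = make_arg((5, 3))
    tau = make_arg((3,))
    other = make_arg((5, 4))
    return [SampleInput(reflectors, args=(tau, other))]
```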
.. seealso::

        :func:`torch.geqrf` can be used to form the Householder representation of matrix `Q`
"the Householder representation of matrix Q" -> "a Householder representation (input, tau) of the matrix Q
TORCH_CHECK(other.dim() >= 2, "torch.ormqr: other must have at least 2 dimensions.");

int64_t left_size_condition = left ? -2 : -1;
TORCH_CHECK(
Missing a check that input.size(-2) >= input.size(-1)?
No, as you can see from the passing tests, this function works both for m >= n and m < n matrices.
In the m < n case, LAPACK uses only the first m columns, which represent m Householder vectors.
In [1]: import torch
In [2]: a = torch.randn(3, 5)
In [3]: h, tau = torch.geqrf(a)
In [4]: h.shape
Out[4]: torch.Size([3, 5])
In [5]: c = torch.randn(3, 7)
In [6]: res = torch.ormqr(h, tau, c)
In [7]: q, _ = torch.linalg.qr(a)
In [8]: torch.allclose(q @ c, res)
Out[8]: True
Then we should add the m < n case to the docs.
Why do you think the case m < n is special?
Because in householder_product it is not even supported:
https://pytorch.org/docs/master/generated/torch.linalg.householder_product.html
Well, householder_product is related, but it's a different function. It has the same constraints on the input as the original orgqr implementation had, and LAPACK's orgqr has this m >= n requirement. But n there is the number of columns of Q to be computed, not n = input.shape[-1]. A side effect is that the output of torch.geqrf can't be used directly with torch.orgqr, while it can be used with ormqr in the current implementation.
In [1]: import torch
In [2]: a = torch.randn(3, 5)
In [3]: h, tau = torch.geqrf(a)
In [4]: c = torch.eye(3)
In [5]: torch.ormqr(h, tau, c) # narrow of `h` is not required
Out[5]:
tensor([[-0.1468, 0.8266, -0.5433],
[-0.0029, -0.5496, -0.8354],
[-0.9892, -0.1211, 0.0831]])
In [6]: torch.linalg.qr(a)[0]
Out[6]:
tensor([[-0.1468, 0.8266, -0.5433],
[-0.0029, -0.5496, -0.8354],
[-0.9892, -0.1211, 0.0831]])
In [7]: torch.linalg.householder_product(h, tau)
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-7-70a58677659d> in <module>
----> 1 torch.linalg.householder_product(h, tau)
RuntimeError: torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]
In [8]: torch.linalg.householder_product(h.narrow(-1, 0, 3), tau) # narrow is required here
Out[8]:
tensor([[-0.1468, 0.8266, -0.5433],
[-0.0029, -0.5496, -0.8354],
[-0.9892, -0.1211, 0.0831]])
I do not have any strong opinions on how to handle this, as I really think these functions are too low-level to form part of the Python API.
That being said, it's quite annoying that we have an (h, tau) decomposition in several functions and that each of them has slightly different requirements.
Even more, the seealso section of householder_product should be updated to reflect this dissonance.
In every place where we multiply Q by some other matrix, using this function should be more efficient. In the PyTorch Python code there is, for example, one place it could be used: in the implementation of lobpcg, the Q matrix is explicitly formed and then used in a multiplication (line 878 in c371542; pytorch/torch/_linalg_utils.py, lines 80 to 88 in c371542).
To make this function more accessible, the documentation should be improved, the name should be changed, and autograd support added. If we do that, we can revisit the constraints on the input and make them consistent.
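To illustrate the kind of rewrite I mean (a sketch with made-up shapes, not the actual lobpcg code):

```python
import torch

A = torch.randn(6, 3, dtype=torch.float64)
B = torch.randn(6, 4, dtype=torch.float64)

# Current pattern: materialize Q explicitly, then multiply.
q, _ = torch.linalg.qr(A)
out_explicit = q.t() @ B

# With ormqr: apply Q^T directly from the compact (h, tau) form,
# keeping only the rows that correspond to the reduced Q.
h, tau = torch.geqrf(A)
out_implicit = torch.ormqr(h, tau, B, transpose=True)[:3]

assert torch.allclose(out_explicit, out_implicit)
```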
That's spot on. Let's do that in a follow-up after the branch-cut.
@lezcano, thank you for your feedback! I've updated this PR according to your suggestions, expanded the documentation on input sizes a bit, and added a "Raises:" section. Could you please take a look?
lezcano
left a comment
I just left a couple of comments, just stylistic points.
In any case, this PR is ready to be merged.
@mruberry, I think this stack is ready to be merged. Could you please take another look? I also moved the "cuBLAS path for lstsq" PR to this stack, after the cuSOLVER PR.
@mruberry has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Pull Request resolved: pytorch#57315. This PR ports `torch.ormqr` from TH to ATen. The CUDA path will be implemented in a follow-up PR. With the ATen port, support for complex and batched inputs is added, the tests are rewritten, and an OpInfo entry is added. We can implement the least squares solver with geqrf + ormqr + triangular_solve, so it is useful to have this function renewed, at least for internal code. Resolves pytorch#24748. Test Plan: Imported from OSS. Reviewed By: ngimel. Differential Revision: D28242070. Pulled By: mruberry. fbshipit-source-id: f070bb6ac2f5a3269b163b22f7354e9089ed3061