-
Notifications
You must be signed in to change notification settings - Fork 27.7k
Port CPU torch.orgqr to ATen #50502
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Port CPU torch.orgqr to ATen #50502
Changes from all commits
1b47d8f
496018b
e06f813
b05d1c6
3111406
b3df30a
1dd97ab
b63cf51
0d6fc9f
c52e7c3
1165b22
fe17c6d
19afe91
20b0654
60f3787
a4cb0d6
99f3197
d903b31
5f67bc9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -127,9 +127,6 @@ void lapackTriangularSolve(char uplo, char trans, char diag, int n, int nrhs, sc | |
| template<class scalar_t> | ||
| void lapackGeqrf(int m, int n, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); | ||
|
|
||
| template<class scalar_t> | ||
| void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); | ||
|
|
||
| template<class scalar_t, class value_t=scalar_t> | ||
| void lapackSymeig(char jobz, char uplo, int n, scalar_t *a, int lda, value_t *w, scalar_t *work, int lwork, value_t *rwork, int *info); | ||
|
|
||
|
|
@@ -982,44 +979,6 @@ static void apply_geqrf(Tensor& self, Tensor& tau, int64_t m, int64_t n, | |
| #endif | ||
| } | ||
|
|
||
| template<typename scalar_t> | ||
| static void apply_orgqr(Tensor& self, const Tensor& tau, int64_t m, int64_t n_columns, | ||
| int64_t k, std::vector<int64_t>& infos) { | ||
| #ifndef USE_LAPACK | ||
| AT_ERROR("qr: LAPACK library not found in compilation"); | ||
| #else | ||
| using value_t = typename c10::scalar_value_type<scalar_t>::type; | ||
| auto self_data = self.data_ptr<scalar_t>(); | ||
| auto tau_data = tau.data_ptr<scalar_t>(); | ||
| auto self_matrix_stride = matrixStride(self); | ||
| auto tau_stride = tau.size(-1); | ||
| auto batch_size = batchCount(self); | ||
|
|
||
| int info; | ||
| // Run once, first to get the optimum work size. | ||
| // Since we deal with batches of matrices with the same dimensions, doing this outside | ||
| // the loop saves (batch_size - 1) workspace queries which would provide the same result | ||
| // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() | ||
| int lwork = -1; | ||
| scalar_t wkopt; | ||
| lapackOrgqr<scalar_t>(m, n_columns, k, self_data, m, tau_data, &wkopt, lwork, &info); | ||
| lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt)); | ||
| Tensor work = at::empty({lwork}, self.options()); | ||
|
|
||
| for (int64_t i = 0; i < batch_size; i++) { | ||
| scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; | ||
| scalar_t* tau_working_ptr = &tau_data[i * tau_stride]; | ||
|
|
||
| // now compute the actual Q | ||
| lapackOrgqr<scalar_t>(m, n_columns, k, self_working_ptr, m, tau_working_ptr, work.data_ptr<scalar_t>(), lwork, &info); | ||
| infos[i] = info; | ||
| if (info != 0) { | ||
| return; | ||
| } | ||
| } | ||
| #endif | ||
| } | ||
|
|
||
| std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& self, std::string mode) { | ||
| bool compute_q, reduced; | ||
| std::tie(compute_q, reduced) = _parse_qr_mode(mode); | ||
|
|
@@ -1074,13 +1033,14 @@ std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& self, std::string | |
| } | ||
|
|
||
| // Next perform ORGQR for Q using the results (both raw R and TAU) from GEQRF | ||
| auto infos_orgqr = at::empty({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt)); | ||
| AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "qr_cpu", [&]{ | ||
| apply_orgqr<scalar_t>(q_working_copy, tau_working_copy, m, n_columns_q, std::min(m, n), infos); | ||
| apply_orgqr<scalar_t>(q_working_copy, tau_working_copy, infos_orgqr, n_columns_q); | ||
| }); | ||
| if (self.dim() > 2) { | ||
| batchCheckErrors(infos, "qr_cpu"); | ||
| batchCheckErrors(infos_orgqr, "qr_cpu"); | ||
| } else { | ||
| singleCheckErrors(infos[0], "qr_cpu"); | ||
| singleCheckErrors(infos_orgqr.item().toInt(), "qr_cpu"); | ||
| } | ||
| return std::make_tuple(q_working_copy.narrow(-1, 0, n_columns_q), R); | ||
| } | ||
|
|
@@ -1113,6 +1073,114 @@ std::tuple<Tensor&,Tensor&> qr_out(Tensor& Q, Tensor& R, const Tensor& self, boo | |
| return at::linalg_qr_out(Q, R, self, mode); | ||
| } | ||
|
|
||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ orgqr ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
|
||
| DEFINE_DISPATCH(orgqr_stub); | ||
|
|
||
| /* | ||
| The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q, | ||
| from a sequence of elementary reflectors, such as is produced by the geqrf function. | ||
|
|
||
| Args: | ||
| * `input` - Tensor with the directions of the elementary reflectors below the diagonal. | ||
| * `tau` - Tensor containing the magnitudes of the elementary reflectors. | ||
| * `result` - result Tensor, which will contain the orthogonal (or unitary) matrix Q. | ||
| * `infos` - Tensor to store LAPACK/MAGMA error codes | ||
|
|
||
| For further details, please see the LAPACK/MAGMA documentation. | ||
| */ | ||
| Tensor& orgqr_out_info(const Tensor& input, const Tensor& tau, Tensor& result, Tensor& infos) { | ||
| TORCH_INTERNAL_ASSERT(input.dim() >= 2); | ||
| TORCH_INTERNAL_ASSERT(input.size(-2) >= input.size(-1)); | ||
| TORCH_INTERNAL_ASSERT(input.size(-1) >= tau.size(-1)); | ||
|
|
||
| TORCH_INTERNAL_ASSERT(input.scalar_type() == tau.scalar_type()); | ||
| TORCH_INTERNAL_ASSERT(input.device() == tau.device()); | ||
|
|
||
| TORCH_INTERNAL_ASSERT(result.scalar_type() == input.scalar_type()); | ||
| TORCH_INTERNAL_ASSERT(result.device() == input.device()); | ||
|
|
||
| TORCH_INTERNAL_ASSERT(infos.scalar_type() == at::kInt); | ||
| TORCH_INTERNAL_ASSERT(infos.device() == at::kCPU); | ||
| TORCH_INTERNAL_ASSERT(infos.numel() == std::max<int64_t>(1, batchCount(input))); | ||
|
|
||
| // if result has no elements we can modify it | ||
| if (result.numel() == 0) { | ||
| at::native::resize_as_(result, input.transpose(-2, -1), MemoryFormat::Contiguous); | ||
| result.transpose_(-2, -1); | ||
| } | ||
|
|
||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. adding early return here fixes internal error. I now wonder how OSS tests pass, because for empty matrix lwork is returned as 0, and that's an illegal value (it should be at least 1)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reference LAPACK for empty matrices returns lwork as 1 |
||
| // result tensor must be in batched column major order (Fortran contiguous) | ||
| TORCH_INTERNAL_ASSERT(result.transpose(-2, -1).is_contiguous()); | ||
| TORCH_INTERNAL_ASSERT(result.sizes().equals(input.sizes())); | ||
|
|
||
| // orgqr_stub (apply_orgqr) performs calculations in-place and result must be a copy of input | ||
| result.copy_(input); | ||
|
|
||
| // infos must be contiguous | ||
| TORCH_INTERNAL_ASSERT(infos.is_contiguous()); | ||
| infos.fill_(0); | ||
|
|
||
| auto n = input.size(-1); | ||
| result = orgqr_stub(result.device().type(), result, tau, infos, n); | ||
| return result; | ||
| } | ||
|
|
||
| Tensor& orgqr_out(const Tensor& input, const Tensor& tau, Tensor& result) { | ||
| TORCH_CHECK(input.dim() >= 2, "orgqr: input must have at least 2 dimensions."); | ||
| TORCH_CHECK(input.size(-2) >= input.size(-1), "orgqr: input.shape[-2] must be greater than or equal to input.shape[-1]"); | ||
| TORCH_CHECK(input.size(-1) >= tau.size(-1), "orgqr: input.shape[-1] must be greater than or equal to tau.shape[-1]"); | ||
|
|
||
| TORCH_CHECK(tau.scalar_type() == input.scalar_type(), | ||
| "orgqr: tau dtype ", tau.scalar_type(), " does not match input dtype ", input.scalar_type()); | ||
| TORCH_CHECK(input.device() == tau.device(), | ||
| "orgqr: Expected input and tau to be on the same device, but found input on ", | ||
| input.device(), " and tau on ", tau.device(), " instead."); | ||
|
|
||
| TORCH_CHECK(result.scalar_type() == input.scalar_type(), | ||
| "orgqr: result dtype ", result.scalar_type(), " does not match the expected dtype ", input.scalar_type()); | ||
| TORCH_CHECK(result.device() == input.device(), | ||
| "orgqr: Expected result and input to be on the same device, but found result on ", | ||
| result.device(), " and input on ", input.device(), " instead."); | ||
|
|
||
| // TODO: uncomment the following when passing incorrectly sized 'result' is not allowed | ||
| // if (result.numel() != 0) { | ||
| // // Resize messes up the strides, so let's not use at::native::resize_output | ||
| // TORCH_CHECK(result.sizes().equals(input.sizes()), | ||
| // "result shape ", result.sizes(), " does not match input shape ", input.sizes()); | ||
| // } | ||
|
|
||
| // Single matrix MAGMA routine requires 'infos' to reside in CPU memory, | ||
| // therefore we create 'infos' only on CPU for now. | ||
| // This should be changed if cuSOLVER would be used | ||
| auto infos = at::empty({std::max<int64_t>(1, batchCount(input))}, input.options().dtype(kInt).device(kCPU)); | ||
|
|
||
| // if result is not empty and not in batched column major format we have to allocate a temporary tensor | ||
| if (result.numel() != 0 && !result.transpose(-2, -1).is_contiguous()) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is really nice. |
||
| Tensor result_tmp = at::empty({0}, input.options()); | ||
| result_tmp = orgqr_out_info(input, tau, result_tmp, infos); | ||
| at::native::resize_output(result, result_tmp.sizes()); | ||
| result.copy_(result_tmp); | ||
| } else { | ||
| // use result's storage directly | ||
| result = orgqr_out_info(input, tau, result, infos); | ||
| } | ||
|
|
||
| // Now check LAPACK/MAGMA error codes | ||
| if (result.dim() > 2) { | ||
| batchCheckErrors(infos, "orgqr"); | ||
| } else { | ||
| singleCheckErrors(infos.item().toInt(), "orgqr"); | ||
| } | ||
| return result; | ||
| } | ||
|
|
||
| Tensor orgqr(const Tensor& input, const Tensor& tau) { | ||
| Tensor result = at::empty({0}, input.options()); | ||
| result = at::orgqr_outf(input, tau, result); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, that's the actual function now. |
||
| return result; | ||
| } | ||
|
|
||
| // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ syevd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
|
|
||
| // This function computes eigenvalues 'w' and eigenvectors 'v' of the input that is stored initially in 'v' | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,8 @@ | |
|
|
||
| #include <ATen/ATen.h> | ||
| #include <ATen/native/DispatchStub.h> | ||
| #include <ATen/native/LinearAlgebraUtils.h> | ||
| #include <ATen/native/cpu/zmath.h> | ||
|
|
||
| #include <TH/TH.h> // for USE_LAPACK | ||
|
|
||
|
|
@@ -15,10 +17,81 @@ namespace at { namespace native { | |
| template<class scalar_t> | ||
| void lapackEig(char jobvl, char jobvr, int n, scalar_t *a, int lda, scalar_t *wr, scalar_t *wi, scalar_t* vl, int ldvl, scalar_t *vr, int ldvr, scalar_t *work, int lwork, int *info); | ||
|
|
||
| template<class scalar_t> | ||
| void lapackOrgqr(int m, int n, int k, scalar_t *a, int lda, scalar_t *tau, scalar_t *work, int lwork, int *info); | ||
|
|
||
| #endif | ||
|
|
||
| using eig_fn = std::tuple<Tensor, Tensor> (*)(const Tensor&, bool&); | ||
|
|
||
| DECLARE_DISPATCH(eig_fn, eig_stub); | ||
|
|
||
| /* | ||
| The orgqr function allows reconstruction of an orthogonal (or unitary) matrix Q, | ||
| from a sequence of elementary reflectors, such as produced by the geqrf function. | ||
|
|
||
| Args: | ||
| * `self` - Tensor with the directions of the elementary reflectors below the diagonal, | ||
| it will be overwritten with the result | ||
| * `tau` - Tensor containing the magnitudes of the elementary reflectors | ||
| * `infos` - Tensor to store LAPACK's error codes | ||
| * `n_columns` - The number of columns of Q to be computed | ||
|
|
||
| For further details, please see the LAPACK documentation for ORGQR and UNGQR. | ||
| */ | ||
| template <typename scalar_t> | ||
| inline void apply_orgqr(Tensor& self, const Tensor& tau, Tensor& infos, int64_t n_columns) { | ||
| #ifndef USE_LAPACK | ||
| TORCH_CHECK(false, "Calling torch.orgqr on a CPU tensor requires compiling ", | ||
| "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); | ||
| #else | ||
| // Some LAPACK implementations might not work well with empty matrices: | ||
| // workspace query might return lwork as 0, which is not allowed (requirement is lwork >= 1) | ||
| // We don't need to do any calculations in this case, so let's return early | ||
| if (self.numel() == 0) { | ||
| infos.fill_(0); | ||
| return; | ||
| } | ||
|
|
||
| using value_t = typename c10::scalar_value_type<scalar_t>::type; | ||
| auto self_data = self.data_ptr<scalar_t>(); | ||
| auto tau_data = tau.data_ptr<scalar_t>(); | ||
| auto infos_data = infos.data_ptr<int>(); | ||
| auto self_matrix_stride = matrixStride(self); | ||
| auto tau_stride = tau.size(-1); | ||
| auto batch_size = batchCount(self); | ||
| auto m = self.size(-2); | ||
| auto k = tau.size(-1); | ||
| auto lda = std::max<int64_t>(1, m); | ||
|
|
||
| // LAPACK's requirement | ||
| TORCH_INTERNAL_ASSERT(m >= n_columns); | ||
| TORCH_INTERNAL_ASSERT(n_columns >= k); | ||
|
|
||
| // Run once, first to get the optimum work size. | ||
| // Since we deal with batches of matrices with the same dimensions, doing this outside | ||
| // the loop saves (batch_size - 1) workspace queries which would provide the same result | ||
| // and (batch_size - 1) calls to allocate and deallocate workspace using at::empty() | ||
| int lwork = -1; | ||
| scalar_t wkopt; | ||
| lapackOrgqr<scalar_t>(m, n_columns, k, self_data, lda, tau_data, &wkopt, lwork, &infos_data[0]); | ||
| lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt)); | ||
|
Comment on lines
+75
to
+78
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't like this error
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The versioned used internally seems to be much older than the version in CI. Could it be the change in how the function is called, using m instead of lda? I'll try to debug internally, too, to get a better sense for what's going on. @ngimel points out that torch.linalg.qr must be relying on this same function, so it's surprising we haven't seen this issue previously. |
||
| Tensor work = at::empty({lwork}, self.options()); | ||
|
|
||
| for (int64_t i = 0; i < batch_size; i++) { | ||
| scalar_t* self_working_ptr = &self_data[i * self_matrix_stride]; | ||
| scalar_t* tau_working_ptr = &tau_data[i * tau_stride]; | ||
| int* info_working_ptr = &infos_data[i]; | ||
| // now compute the actual Q | ||
| lapackOrgqr<scalar_t>(m, n_columns, k, self_working_ptr, lda, tau_working_ptr, work.data_ptr<scalar_t>(), lwork, info_working_ptr); | ||
| if (*info_working_ptr != 0) { | ||
| return; | ||
| } | ||
| } | ||
| #endif | ||
| } | ||
|
|
||
| using orgqr_fn = Tensor& (*)(Tensor& /*result*/, const Tensor& /*tau*/, Tensor& /*infos*/, int64_t /*n_columns*/); | ||
| DECLARE_DISPATCH(orgqr_fn, orgqr_stub); | ||
|
|
||
| }} // namespace at::native | ||
Uh oh!
There was an error while loading. Please reload this page.