Skip to content

Commit f4824c2

Browse files
committed
Fixed worksize
1 parent ee4ce8e commit f4824c2

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -695,7 +695,7 @@ static void apply_inverse(Tensor& self, Tensor& infos_lu, Tensor& infos_getri) {
695695
int lwork = -1;
696696
scalar_t wkopt;
697697
lapackGetri<scalar_t>(n, self_data, lda, ipiv_data, &wkopt, lwork, &info);
698-
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
698+
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
699699
Tensor work = at::empty({lwork}, self.options());
700700
auto work_data = work.data_ptr<scalar_t>();
701701

@@ -1211,7 +1211,7 @@ static void apply_geqrf(Tensor& self, Tensor& tau, int64_t m, int64_t n,
12111211
int lwork = -1;
12121212
scalar_t wkopt;
12131213
lapackGeqrf<scalar_t>(m, n, self_data, m, tau_data, &wkopt, lwork, &info);
1214-
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
1214+
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
12151215
Tensor work = at::empty({lwork}, self.options());
12161216

12171217
for (const auto i : c10::irange(batch_size)) {
@@ -1626,7 +1626,7 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool
16261626
}
16271627

16281628
lapackSymeig<scalar_t, value_t>(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, rwork_data, &info);
1629-
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
1629+
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
16301630
Tensor work = at::empty({lwork}, self.options());
16311631

16321632
for (const auto i : c10::irange(batch_size)) {
@@ -1782,7 +1782,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
17821782
int lwork = -1;
17831783
scalar_t wkopt;
17841784
lapackSvd<scalar_t, value_t>(jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt, &wkopt, lwork, rwork_data, iwork_data, &info);
1785-
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
1785+
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
17861786
Tensor work = at::empty({lwork}, self.options());
17871787
auto work_data = work.data_ptr<scalar_t>();
17881788

aten/src/ATen/native/BatchLinearAlgebra.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ inline void apply_orgqr(Tensor& self, const Tensor& tau, Tensor& infos, int64_t
8282
int lwork = -1;
8383
scalar_t wkopt;
8484
lapackOrgqr<scalar_t>(m, n_columns, k, self_data, lda, tau_data, &wkopt, lwork, &infos_data[0]);
85-
lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
85+
lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
8686
Tensor work = at::empty({lwork}, self.options());
8787

8888
for (int64_t i = 0; i < batch_size; i++) {

aten/src/ATen/native/BatchLinearAlgebraKernel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vec
115115
int info;
116116
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
117117
nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, &info);
118-
int lwork = static_cast<int>(real_impl<scalar_t, value_t>(wkopt));
118+
int lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
119119

120120
// call again to do the actual work
121121
Tensor work = at::empty({lwork}, self.dtype());

0 commit comments

Comments
 (0)