@@ -695,7 +695,7 @@ static void apply_inverse(Tensor& self, Tensor& infos_lu, Tensor& infos_getri) {
695695 int lwork = -1 ;
696696 scalar_t wkopt;
697697 lapackGetri<scalar_t >(n, self_data, lda, ipiv_data, &wkopt, lwork, &info);
698- lwork = static_cast <int >(real_impl<scalar_t , value_t >(wkopt));
698+ lwork = std::max <int >(1 , real_impl<scalar_t , value_t >(wkopt));
699699 Tensor work = at::empty ({lwork}, self.options ());
700700 auto work_data = work.data_ptr <scalar_t >();
701701
@@ -1211,7 +1211,7 @@ static void apply_geqrf(Tensor& self, Tensor& tau, int64_t m, int64_t n,
12111211 int lwork = -1 ;
12121212 scalar_t wkopt;
12131213 lapackGeqrf<scalar_t >(m, n, self_data, m, tau_data, &wkopt, lwork, &info);
1214- lwork = static_cast <int >(real_impl<scalar_t , value_t >(wkopt));
1214+ lwork = std::max <int >(1 , real_impl<scalar_t , value_t >(wkopt));
12151215 Tensor work = at::empty ({lwork}, self.options ());
12161216
12171217 for (const auto i : c10::irange (batch_size)) {
@@ -1626,7 +1626,7 @@ static void apply_symeig(Tensor& self, Tensor& eigvals, bool eigenvectors, bool
16261626 }
16271627
16281628 lapackSymeig<scalar_t , value_t >(jobz, uplo, n, self_data, n, eigvals_data, &wkopt, lwork, rwork_data, &info);
1629- lwork = static_cast <int >(real_impl<scalar_t , value_t >(wkopt));
1629+ lwork = std::max <int >(1 , real_impl<scalar_t , value_t >(wkopt));
16301630 Tensor work = at::empty ({lwork}, self.options ());
16311631
16321632 for (const auto i : c10::irange (batch_size)) {
@@ -1782,7 +1782,7 @@ static void apply_svd(Tensor& self, Tensor& U, Tensor& S, Tensor& VT,
17821782 int lwork = -1 ;
17831783 scalar_t wkopt;
17841784 lapackSvd<scalar_t , value_t >(jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt, &wkopt, lwork, rwork_data, iwork_data, &info);
1785- lwork = static_cast <int >(real_impl<scalar_t , value_t >(wkopt));
1785+ lwork = std::max <int >(1 , real_impl<scalar_t , value_t >(wkopt));
17861786 Tensor work = at::empty ({lwork}, self.options ());
17871787 auto work_data = work.data_ptr <scalar_t >();
17881788
0 commit comments