@@ -2088,15 +2088,15 @@ AT_ERROR("symeig: MAGMA library not found in "
20882088
20892089 scalar_t * work;
20902090 magma_int_t * iwork;
2091- lwork = magma_int_cast (real_impl<scalar_t , value_t >(wkopt), " work_size" );
2092- liwork = magma_int_cast (iwkopt, " iwork_size" );
2091+ lwork = magma_int_cast (std::max< int64_t >( 1 , real_impl<scalar_t , value_t >(wkopt) ), " work_size" );
2092+ liwork = magma_int_cast (std::max< int64_t >( 1 , iwkopt) , " iwork_size" );
20932093 ALLOCATE_ARRAY (work, scalar_t , lwork);
20942094 ALLOCATE_ARRAY (iwork, magma_int_t , liwork);
20952095
20962096 value_t * rwork = nullptr ;
20972097 c10::Storage storage_rwork;
20982098 if (isComplexType (at::typeMetaToScalarType (self.dtype ()))) {
2099- lrwork = magma_int_cast (rwkopt, " rwork_size" );
2099+ lrwork = magma_int_cast (std::max< int64_t >( 1 , rwkopt) , " rwork_size" );
21002100 storage_rwork = pin_memory<value_t >(lrwork);
21012101 rwork = static_cast <value_t *>(storage_rwork.data ());
21022102 }
@@ -2288,9 +2288,9 @@ AT_ERROR("svd: MAGMA library not found in "
22882288 value_t * rwork = nullptr ;
22892289
22902290 magma_int_t * iwork;
2291- ALLOCATE_ARRAY (iwork, magma_int_t , 8 * mn);
2291+ ALLOCATE_ARRAY (iwork, magma_int_t , std::max< magma_int_t >( 1 , 8 * mn) );
22922292 if (isComplexType (at::typeMetaToScalarType (self.dtype ()))) {
2293- auto lrwork = computeLRWorkDim (jobchar, m, n);
2293+ auto lrwork = std::max< int64_t >( 1 , computeLRWorkDim (jobchar, m, n) );
22942294 storage_rwork = pin_memory<value_t >(lrwork);
22952295 rwork = static_cast <value_t *>(storage_rwork.data ());
22962296 }
@@ -2303,7 +2303,7 @@ AT_ERROR("svd: MAGMA library not found in "
23032303 magma_int_t lwork = -1 ;
23042304 scalar_t wkopt;
23052305 magmaSvd<scalar_t , value_t >(jobz, m, n, self_data, lda, S_data, U_data, lda, VT_data, ldvt, &wkopt, lwork, rwork, iwork, &info);
2306- lwork = magma_int_cast (real_impl<scalar_t , value_t >(wkopt), " work_size" );
2306+ lwork = magma_int_cast (std::max< int64_t >( 1 , real_impl<scalar_t , value_t >(wkopt) ), " work_size" );
23072307 scalar_t * work;
23082308 ALLOCATE_ARRAY (work, scalar_t , lwork);
23092309
@@ -2475,9 +2475,9 @@ Tensor _lu_solve_helper_cuda(const Tensor& self, const Tensor& LU_data, const Te
24752475 TORCH_CHECK (info == 0 , " MAGMA lu_solve : invalid argument: " , -info);
24762476 return self_working_copy;
24772477}
2478- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
24792478
24802479// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ lstsq ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
2480+
24812481std::tuple<Tensor, Tensor, Tensor> _lstsq_helper_cuda (
24822482 const Tensor& a, const Tensor& b, double cond, c10::optional<std::string> driver_name) {
24832483#ifndef USE_MAGMA
@@ -2492,8 +2492,8 @@ AT_ERROR("torch.linalg.lstsq: MAGMA library not found in "
24922492 auto ldda = std::max<magma_int_t >(1 , m);
24932493 auto lddb = std::max<magma_int_t >(1 , std::max (m, n));
24942494 auto nb = magmaGeqrfOptimalBlocksize<scalar_t >(m, n);
2495- auto lwork = ( m - n + nb) * (nrhs + nb) + nrhs * nb;
2496- Tensor hwork = at::empty ({static_cast < int64_t >( lwork) }, a.scalar_type ());
2495+ magma_int_t lwork = magma_int_cast (std::max< int64_t >( 1 , ( m - n + nb) * (nrhs + nb) + nrhs * nb), " work_size " ) ;
2496+ Tensor hwork = at::empty ({lwork}, a.scalar_type ());
24972497 auto * hwork_ptr = hwork.data_ptr <scalar_t >();
24982498 magma_int_t info;
24992499
@@ -2512,7 +2512,6 @@ AT_ERROR("torch.linalg.lstsq: MAGMA library not found in "
25122512 return std::make_tuple (b, rank, singular_values);
25132513#endif
25142514}
2515- // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
25162515
25172516}} // namespace at::native
25182517
0 commit comments