@@ -31,7 +31,7 @@ inline static void _apply_single_inverse_helper(scalar_t* self_ptr, scalar_t* se
3131
3232 auto handle = at::cuda::getCurrentCUDASolverDnHandle ();
3333 at::cuda::solver::getrf<scalar_t >(handle, n, n, self_ptr, n, ipiv_ptr, info_ptr);
34- at::cuda::solver::getrs<scalar_t >(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr);
34+ at::cuda::solver::getrs<scalar_t >(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr + 1 );
3535}
3636
3737template <typename scalar_t >
@@ -60,7 +60,7 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in
6060
6161 int * pivot = reinterpret_cast <int *>(allocator.allocate (sizeof (int ) * n).get ());
6262 _apply_single_inverse_helper<scalar_t >(
63- &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i, n);
63+ &self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i * 2 , n);
6464
6565 at::cuda::CUDAEvent finished;
6666 finished.record (stream);
@@ -88,16 +88,13 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in
8888}
8989
9090template <typename scalar_t >
91- static void apply_single_inverse_lib (const Tensor& self, Tensor& self_inv, int64_t & info) {
91+ static void apply_single_inverse_lib (const Tensor& self, Tensor& self_inv, Tensor & info) {
9292 int n = cuda_int_cast (self.size (-2 ), " self.size(-2)" );
9393
9494 Tensor ipiv = at::empty ({n}, self.options ().dtype (at::kInt ));
95- Tensor info_tmp = at::zeros ({1 }, self.options ().dtype (at::kInt ));
9695
9796 _apply_single_inverse_helper<scalar_t >(
98- self.data_ptr <scalar_t >(), self_inv.data_ptr <scalar_t >(), ipiv.data_ptr <int >(), info_tmp.data_ptr <int >(), n);
99-
100- info = info_tmp.item <int >();
97+ self.data_ptr <scalar_t >(), self_inv.data_ptr <scalar_t >(), ipiv.data_ptr <int >(), info.data_ptr <int >(), n);
10198}
10299
103100Tensor _inverse_helper_cuda_lib (const Tensor& self) {
@@ -106,18 +103,17 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) {
106103 const int batch_size = cuda_int_cast (batchCount (self), " batchCount" );
107104
108105 if (self.dim () > 2 && batch_size > 1 ) {
109- Tensor infos = at::zeros ({batchCount (self)}, self.options ().dtype (kInt ));
106+ Tensor infos = at::zeros ({batchCount (self) * 2 }, self.options ().dtype (kInt ));
110107 AT_DISPATCH_FLOATING_TYPES (self.scalar_type (), " inverse_cuda" , [&]{
111- apply_batched_inverse_lib<scalar_t >(
112- self_working_copy, self_inv_working_copy, infos);
108+ apply_batched_inverse_lib<scalar_t >(self_working_copy, self_inv_working_copy, infos);
113109 });
114- batchCheckErrors (infos, " inverse_cuda" );
110+ batchCheckErrors (infos, " inverse_cuda" , false , 2 );
115111 } else {
116- int64_t info = 0 ;
112+ Tensor info = at::zeros ({ 2 }, self. options (). dtype (at:: kInt )) ;
117113 AT_DISPATCH_FLOATING_TYPES (self.scalar_type (), " inverse_cuda" , [&]{
118114 apply_single_inverse_lib<scalar_t >(self_working_copy, self_inv_working_copy, info);
119115 });
120- singleCheckErrors (info, " inverse_cuda" );
116+ batchCheckErrors (info, " inverse_cuda" , false , 2 );
121117 }
122118
123119 return self_inv_working_copy;