Skip to content

Commit fe4f90c

Browse files
xwang233 authored and facebook-github-bot committed
Cusolver inverse check info (#46625)
Summary: Fixes #46557 Pull Request resolved: #46625 Reviewed By: zou3519 Differential Revision: D24438577 Pulled By: ngimel fbshipit-source-id: d00e6eb2eae4aa39ca6ecf5914fe9cf37c24b906
1 parent adffd8e commit fe4f90c

4 files changed

Lines changed: 39 additions & 24 deletions

File tree

aten/src/ATen/native/LinearAlgebraUtils.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -110,16 +110,16 @@ static inline void batchCheckErrors(std::vector<int64_t>& infos, const char* nam
110110
/*
111111
* This is an overloaded case of the previous function for a tensor of infos.
112112
*/
113-
static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false) {
113+
static inline void batchCheckErrors(const Tensor& infos, const char* name, bool allow_singular=false, int info_per_batch=1) {
114114
auto batch_size = infos.numel();
115115
auto infos_cpu = infos.to(at::kCPU);
116116
auto infos_data = infos_cpu.data_ptr<int>();
117117
for (int64_t i = 0; i < batch_size; i++) {
118118
auto info = infos_data[i];
119119
if (info < 0) {
120-
AT_ERROR(name, ": For batch ", i, ": Argument ", -info, " has illegal value");
120+
AT_ERROR(name, ": For batch ", i/info_per_batch, ": Argument ", -info, " has illegal value");
121121
} else if (!allow_singular && info > 0) {
122-
AT_ERROR(name, ": For batch ", i, ": U(", info, ",", info, ") is zero, singular U.");
122+
AT_ERROR(name, ": For batch ", i/info_per_batch, ": U(", info, ",", info, ") is zero, singular U.");
123123
}
124124
}
125125
}

aten/src/ATen/native/cuda/BatchLinearAlgebraLib.cu

Lines changed: 9 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ inline static void _apply_single_inverse_helper(scalar_t* self_ptr, scalar_t* se
3131

3232
auto handle = at::cuda::getCurrentCUDASolverDnHandle();
3333
at::cuda::solver::getrf<scalar_t>(handle, n, n, self_ptr, n, ipiv_ptr, info_ptr);
34-
at::cuda::solver::getrs<scalar_t>(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr);
34+
at::cuda::solver::getrs<scalar_t>(handle, n, n, self_ptr, n, ipiv_ptr, self_inv_ptr, n, info_ptr + 1);
3535
}
3636

3737
template <typename scalar_t>
@@ -60,7 +60,7 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in
6060

6161
int* pivot = reinterpret_cast<int*>(allocator.allocate(sizeof(int) * n).get());
6262
_apply_single_inverse_helper<scalar_t>(
63-
&self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i, n);
63+
&self_data[i * self_mat_stride], &self_inv_data[i * self_inv_mat_stride], pivot, p_infos + i * 2, n);
6464

6565
at::cuda::CUDAEvent finished;
6666
finished.record(stream);
@@ -88,16 +88,13 @@ static void apply_batched_inverse_lib(Tensor& self, Tensor& self_inv, Tensor& in
8888
}
8989

9090
template <typename scalar_t>
91-
static void apply_single_inverse_lib(const Tensor& self, Tensor& self_inv, int64_t& info) {
91+
static void apply_single_inverse_lib(const Tensor& self, Tensor& self_inv, Tensor& info) {
9292
int n = cuda_int_cast(self.size(-2), "self.size(-2)");
9393

9494
Tensor ipiv = at::empty({n}, self.options().dtype(at::kInt));
95-
Tensor info_tmp = at::zeros({1}, self.options().dtype(at::kInt));
9695

9796
_apply_single_inverse_helper<scalar_t>(
98-
self.data_ptr<scalar_t>(), self_inv.data_ptr<scalar_t>(), ipiv.data_ptr<int>(), info_tmp.data_ptr<int>(), n);
99-
100-
info = info_tmp.item<int>();
97+
self.data_ptr<scalar_t>(), self_inv.data_ptr<scalar_t>(), ipiv.data_ptr<int>(), info.data_ptr<int>(), n);
10198
}
10299

103100
Tensor _inverse_helper_cuda_lib(const Tensor& self) {
@@ -106,18 +103,17 @@ Tensor _inverse_helper_cuda_lib(const Tensor& self) {
106103
const int batch_size = cuda_int_cast(batchCount(self), "batchCount");
107104

108105
if (self.dim() > 2 && batch_size > 1) {
109-
Tensor infos = at::zeros({batchCount(self)}, self.options().dtype(kInt));
106+
Tensor infos = at::zeros({batchCount(self) * 2}, self.options().dtype(kInt));
110107
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
111-
apply_batched_inverse_lib<scalar_t>(
112-
self_working_copy, self_inv_working_copy, infos);
108+
apply_batched_inverse_lib<scalar_t>(self_working_copy, self_inv_working_copy, infos);
113109
});
114-
batchCheckErrors(infos, "inverse_cuda");
110+
batchCheckErrors(infos, "inverse_cuda", false, 2);
115111
} else {
116-
int64_t info = 0;
112+
Tensor info = at::zeros({2}, self.options().dtype(at::kInt));
117113
AT_DISPATCH_FLOATING_TYPES(self.scalar_type(), "inverse_cuda", [&]{
118114
apply_single_inverse_lib<scalar_t>(self_working_copy, self_inv_working_copy, info);
119115
});
120-
singleCheckErrors(info, "inverse_cuda");
116+
batchCheckErrors(info, "inverse_cuda", false, 2);
121117
}
122118

123119
return self_inv_working_copy;

test/test_torch.py

Lines changed: 19 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -34,14 +34,14 @@
3434
wrapDeterministicFlagAPITest, make_tensor)
3535
from multiprocessing.reduction import ForkingPickler
3636
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
37-
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfRocm, skipCUDAIfNotRocm, onlyCUDA, onlyCPU, \
37+
skipCPUIfNoLapack, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCUDAIfNotRocm, \
38+
onlyCUDA, onlyCPU, \
3839
dtypes, dtypesIfCUDA, dtypesIfCPU, deviceCountAtLeast, skipCUDAIf, precisionOverride, \
3940
PYTORCH_CUDA_MEMCHECK, largeCUDATensorTest, largeTensorTest, onlyOnCPUAndCUDA, expectedAlertNondeterministic
4041
from typing import Dict, List, Tuple, Union
4142
import torch.backends.quantized
4243
import torch.testing._internal.data
43-
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off, \
44-
_get_torch_cuda_version, TEST_MAGMA
44+
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, with_tf32_off
4545

4646

4747
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
@@ -6073,10 +6073,7 @@ def test_pow(self, device):
60736073
torch.pow(m1, 1, out=out)
60746074
self.assertEqual(out, m1)
60756075

6076-
@skipCUDAIf(
6077-
_get_torch_cuda_version() < [10, 0] and not TEST_MAGMA,
6078-
"On cuda 9.2, torch.inverse relies on magma"
6079-
)
6076+
@skipCUDAIfNoMagmaAndNoCusolver
60806077
@skipCPUIfNoLapack
60816078
def test_inverse(self, device):
60826079
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value
@@ -6144,6 +6141,20 @@ def test_inverse_helper_zero_size(size):
61446141
expected_inv = torch.as_tensor(inv(matrices.cpu().numpy()))
61456142
self.assertEqual(matrices_inverse, expected_inv.to(device))
61466143

6144+
@skipCUDAIfNoMagmaAndNoCusolver
6145+
@skipCPUIfNoLapack
6146+
@onlyOnCPUAndCUDA # TODO: XLA doesn't raise exception
6147+
def test_inverse_singular(self, device):
6148+
def helper(batch_dim, n):
6149+
x = torch.eye(3, 3, dtype=torch.float, device=device).reshape((1, 3, 3)).repeat(batch_dim, 1, 1)
6150+
x[n, -1, -1] = 0
6151+
6152+
with self.assertRaisesRegex(RuntimeError, rf'For batch {n}: U\(3,3\) is zero'):
6153+
torch.inverse(x)
6154+
6155+
for params in [(1, 0), (2, 0), (2, 1), (4, 0), (4, 2), (10, 2)]:
6156+
helper(*params)
6157+
61476158
@unittest.skipIf(not TEST_NUMPY, 'NumPy not found')
61486159
@onlyOnCPUAndCUDA
61496160
@dtypes(torch.int8, torch.int16, torch.int32, torch.int64)
@@ -6870,7 +6881,7 @@ def test_is_set_to(self, device):
68706881
self.assertFalse(t2.is_set_to(t1))
68716882

68726883
@slowTest
6873-
@skipCUDAIfNoMagma
6884+
@skipCUDAIfNoMagmaAndNoCusolver
68746885
@skipCPUIfNoLapack
68756886
def test_inverse_many_batches(self, device):
68766887
from torch.testing._internal.common_utils import random_fullrank_matrix_distinct_singular_value

torch/testing/_internal/common_device_type.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,7 @@
1010
import torch
1111
from torch.testing._internal.common_utils import TestCase, TEST_WITH_ROCM, TEST_MKL, \
1212
skipCUDANonDefaultStreamIf, TEST_WITH_ASAN, TEST_WITH_UBSAN, TEST_WITH_TSAN
13+
from torch.testing._internal.common_cuda import _get_torch_cuda_version
1314
from torch.testing import \
1415
(get_all_dtypes)
1516

@@ -801,6 +802,13 @@ def skipCPUIfNoMkl(fn):
801802
def skipCUDAIfNoMagma(fn):
802803
return skipCUDAIf('no_magma', "no MAGMA library detected")(skipCUDANonDefaultStreamIf(True)(fn))
803804

805+
def skipCUDAIfNoMagmaAndNoCusolver(fn):
806+
version = _get_torch_cuda_version()
807+
if version >= [10, 2]:
808+
return fn
809+
else:
810+
# cuSolver is disabled on cuda < 10.1.243, tests depend on MAGMA
811+
return skipCUDAIfNoMagma(fn)
804812

805813
# Skips a test on CUDA when using ROCm.
806814
def skipCUDAIfRocm(fn):

0 commit comments

Comments (0)