Commit 7df176b

IvanYashchuk authored and facebook-github-bot committed
Added OpInfo-based testing of some linalg functions (#51107)
Summary: Added OpInfo-based testing of the following linear algebra functions:

* cholesky, linalg.cholesky
* linalg.eigh
* inverse, linalg.inv
* qr, linalg.qr
* solve

The output of `torch.linalg.pinv` for empty inputs was not differentiable; this is now fixed. In some cases, batched grad checks are disabled because they do not work well with 0x0 matrices (see #50743 (comment)).

Ref. #50006

Pull Request resolved: #51107
Reviewed By: albanD
Differential Revision: D27006115
Pulled By: mruberry
fbshipit-source-id: 3c1d00e3d506948da25d612fb114e6d4a478c5b1
1 parent: d46978c

5 files changed: 228 additions & 150 deletions
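
Only four of the five changed files appear below; the new OpInfo entries themselves land in torch/testing/_internal/common_methods_invocations.py, presumably the fifth file, not shown in this view. As a rough sketch of what such an entry looks like (the sample-input function, shapes, and keyword arguments here are assumed from the general OpInfo pattern of this era, not copied from the commit):

import torch
from torch.testing._internal.common_methods_invocations import OpInfo, SampleInput

def sample_inputs_linalg_inv(op_info, device, dtype, requires_grad):
    # Well-conditioned square matrices plus the 0x0 edge case the summary
    # mentions; shapes are illustrative only.
    shapes = [(0, 0), (3, 3), (2, 3, 3)]
    samples = []
    for shape in shapes:
        # Adding a scaled identity keeps the random matrix comfortably invertible.
        t = torch.randn(shape, dtype=dtype, device=device)
        t = t + 3 * torch.eye(shape[-1], dtype=dtype, device=device)
        t.requires_grad_(requires_grad)
        samples.append(SampleInput(t))
    return samples

# Hypothetical entry; the real one lives in the op database of
# common_methods_invocations.py and may differ in every detail.
inv_info = OpInfo(
    'linalg.inv',
    op=torch.linalg.inv,
    dtypes=(torch.float32, torch.float64, torch.complex64, torch.complex128),
    sample_inputs_func=sample_inputs_linalg_inv,
    # Assumed mechanism for "batched grad checks are disabled" from the summary.
    check_batched_grad=False,
)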


aten/src/ATen/native/BatchLinearAlgebra.cpp

Lines changed: 2 additions & 1 deletion
@@ -848,7 +848,8 @@ static void apply_solve(Tensor& b, Tensor& A, Tensor& infos) {
 std::tuple<Tensor, Tensor> _solve_helper_cpu(const Tensor& self, const Tensor& A) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  auto infos = at::empty({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
+  // infos might not get filled for empty inputs therefore at::zeros is used instead of at::empty
+  auto infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "solve_cpu", [&]{
     apply_solve<scalar_t>(self_working_copy, A_working_copy, infos);
   });
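
Why this matters: LAPACK uses info == 0 to signal success, and for an empty batch the solver never writes to infos, so an uninitialized buffer from at::empty could spuriously report a failure; zero-initializing makes the untouched entries read as success. The CUDA helper below receives the identical fix. A minimal sketch of the case this guards, shown with torch.linalg.solve (the legacy torch.solve, which called this helper directly, has since been removed):

import torch

# An empty batch of linear systems: zero 3x3 matrices, zero right-hand sides.
A = torch.randn(0, 3, 3, dtype=torch.float64)
b = torch.randn(0, 3, 1, dtype=torch.float64)

# With infos zero-initialized, the post-solve error check sees "success"
# and an empty solution comes back instead of a spurious error.
x = torch.linalg.solve(A, b)
print(x.shape)  # torch.Size([0, 3, 1])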

aten/src/ATen/native/LinearAlgebra.cpp

Lines changed: 4 additions & 3 deletions
@@ -142,9 +142,10 @@ Tensor linalg_pinv(const Tensor& input, const Tensor& rcond, bool hermitian) {
   if (input.numel() == 0) {
     // The implementation below uses operations that do not work for zero numel tensors
     // therefore we need this early return for 'input.numel() == 0' case
-    auto input_sizes = input.sizes().vec();
-    std::swap(input_sizes[input.dim() - 1], input_sizes[input.dim() - 2]);
-    return at::empty(input_sizes, input.options());
+    Tensor U, S, V;
+    // TODO: replace input.svd with linalg_svd when torch/xla can work with at::linalg_svd
+    std::tie(U, S, V) = input.svd();
+    return at::matmul(V * S.reciprocal().unsqueeze(-2), U.conj().transpose(-2, -1));
   }
 
   // If not Hermitian use singular value decomposition, else use eigenvalue decomposition
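
Computing the empty-input result through svd, a pointwise multiply, and a matmul keeps the output attached to the autograd graph; the old at::empty return had no grad_fn, which is exactly why `torch.linalg.pinv` was not differentiable for empty inputs. A small usage sketch of the now-working case:

import torch

a = torch.randn(3, 0, dtype=torch.float64, requires_grad=True)
p = torch.linalg.pinv(a)
print(p.shape)       # torch.Size([0, 3]): the pseudoinverse swaps the last two dims
p.sum().backward()   # previously failed: the empty result was detached from the graph
print(a.grad.shape)  # torch.Size([3, 0])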

aten/src/ATen/native/cuda/BatchLinearAlgebra.cu

Lines changed: 2 additions & 1 deletion
@@ -1272,7 +1272,8 @@ AT_ERROR("solve: MAGMA library not found in "
 std::tuple<Tensor, Tensor> _solve_helper_cuda(const Tensor& self, const Tensor& A) {
   auto self_working_copy = cloneBatchedColumnMajor(self);
   auto A_working_copy = cloneBatchedColumnMajor(A);
-  auto infos = at::empty({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
+  // infos might not get filled for empty inputs therefore at::zeros is used instead of at::empty
+  auto infos = at::zeros({std::max<int64_t>(1, batchCount(self))}, self.options().dtype(kInt));
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "solve_cuda", [&]{
     apply_solve<scalar_t>(self_working_copy, A_working_copy, infos);
   });

test/test_linalg.py

Lines changed: 9 additions & 34 deletions
@@ -467,25 +467,16 @@ def test_cholesky_errors_and_warnings(self, device, dtype):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float64, torch.complex128)
-    def test_cholesky_autograd(self, device, dtype):
-        def func(root):
-            x = 0.5 * (root + root.transpose(-1, -2).conj())
-            return torch.linalg.cholesky(x)
-
+    def test_cholesky_hermitian_grad(self, device, dtype):
+        # Check that the gradient is Hermitian (or symmetric)
         def run_test(shape):
-            root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True)
-            root = root + torch.eye(shape[-1], dtype=dtype, device=device)
-
-            gradcheck(func, root)
-            gradgradcheck(func, root)
-
             root = torch.rand(*shape, dtype=dtype, device=device)
             root = torch.matmul(root, root.transpose(-1, -2).conj())
             root.requires_grad_()
             chol = torch.linalg.cholesky(root).sum().backward()
-            self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian
+            self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())
 
-        shapes = ((3, 3), (4, 3, 2, 2))
+        shapes = ((3, 3), (1, 1, 3, 3))
         for shape in shapes:
             run_test(shape)
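
The gradcheck/gradgradcheck coverage dropped here moves to the OpInfo-driven tests (see the sketch after the commit summary above); what this test keeps is the one property the generic machinery does not assert: cholesky's backward formula produces a Hermitian (in the real case, symmetric) gradient. A standalone real-valued sketch of that property:

import torch

a = torch.rand(3, 3, dtype=torch.float64)
a = a @ a.T + torch.eye(3, dtype=torch.float64)  # symmetric positive definite
a.requires_grad_()
torch.linalg.cholesky(a).sum().backward()
print(torch.allclose(a.grad, a.grad.T))  # True: the gradient is symmetric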

@@ -909,35 +900,16 @@ def run_test_skipped_elements(shape, batch, uplo):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float64, torch.complex128)
-    def test_eigh_autograd(self, device, dtype):
+    def test_eigh_hermitian_grad(self, device, dtype):
         from torch.testing._internal.common_utils import random_hermitian_matrix
 
-        def func(x, uplo):
-            x = 0.5 * (x + x.conj().transpose(-2, -1))
-            return torch.linalg.eigh(x, UPLO=uplo)
-
-        def func_grad_w(x, uplo):
-            return func(x, uplo)[0]
-
-        def func_grad_v(x, uplo):
-            # gauge invariant loss function
-            return abs(func(x, uplo)[1])
-
         def run_test(dims, uplo):
-            x = torch.randn(*dims, dtype=dtype, device=device, requires_grad=True)
-
-            gradcheck(func_grad_w, [x, uplo])
-            gradgradcheck(func_grad_w, [x, uplo])
-
-            gradcheck(func_grad_v, [x, uplo])
-            gradgradcheck(func_grad_v, [x, uplo])
-
             x = random_hermitian_matrix(dims[-1], *dims[:-2]).requires_grad_()
             w, v = torch.linalg.eigh(x)
             (w.sum() + abs(v).sum()).backward()
             self.assertEqual(x.grad, x.grad.conj().transpose(-1, -2))  # Check the gradient is Hermitian
 
-        for dims, uplo in itertools.product([(3, 3), (2, 3, 3)], ["L", "U"]):
+        for dims, uplo in itertools.product([(3, 3), (1, 1, 3, 3)], ["L", "U"]):
             run_test(dims, uplo)
 
     @skipCUDAIfNoMagma
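
The deleted func_grad_v wrapped the eigenvectors in abs() because, as its comment noted, gradcheck needs a gauge-invariant loss: each eigenvector is only defined up to a unit complex phase, so any loss that can see that phase is ill-posed. A standalone sketch of the ambiguity (values illustrative):

import torch

a = torch.tensor([[2.0, 1.0], [1.0, 2.0]], dtype=torch.complex128)
w, v = torch.linalg.eigh(a)

# Multiply each eigenvector column by an arbitrary unit phase.
phases = torch.exp(1j * torch.rand(2, dtype=torch.float64))
v_alt = v * phases

# Both eigenvector matrices reconstruct `a` equally well...
recon = (v_alt * w) @ v_alt.conj().transpose(-2, -1)
print(torch.allclose(recon, a))              # True

# ...and abs() erases the phase, making the loss well-defined.
print(torch.allclose(v.abs(), v_alt.abs()))  # True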
@@ -2630,6 +2602,7 @@ def test_inverse_errors_large(self, device, dtype):
 
     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_pinv(self, device, dtype):
@@ -5602,6 +5575,7 @@ def test_solve_methods_arg_device(self, device):
 
     @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_pinverse(self, device, dtype):
@@ -7021,6 +6995,7 @@ def check_norm(a, b, expected_norm, gels_result):
         self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0)
 
     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     def test_lapack_empty(self, device):
         # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.
