@@ -467,25 +467,16 @@ def test_cholesky_errors_and_warnings(self, device, dtype):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float64, torch.complex128)
-    def test_cholesky_autograd(self, device, dtype):
-        def func(root):
-            x = 0.5 * (root + root.transpose(-1, -2).conj())
-            return torch.linalg.cholesky(x)
-
+    def test_cholesky_hermitian_grad(self, device, dtype):
+        # Check that the gradient is Hermitian (or symmetric)
         def run_test(shape):
-            root = torch.rand(*shape, dtype=dtype, device=device, requires_grad=True)
-            root = root + torch.eye(shape[-1], dtype=dtype, device=device)
-
-            gradcheck(func, root)
-            gradgradcheck(func, root)
-
             root = torch.rand(*shape, dtype=dtype, device=device)
             root = torch.matmul(root, root.transpose(-1, -2).conj())
             root.requires_grad_()
             chol = torch.linalg.cholesky(root).sum().backward()
-            self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())  # Check the gradient is hermitian
+            self.assertEqual(root.grad, root.grad.transpose(-1, -2).conj())

-        shapes = ((3, 3), (4, 3, 2, 2))
+        shapes = ((3, 3), (1, 1, 3, 3))
         for shape in shapes:
             run_test(shape)

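Note on the hunk above: the test asserts that cholesky's backward produces a Hermitian gradient when the leaf input is Hermitian positive-definite. A minimal standalone sketch of the same check, outside the test harness (the 3x3 shape is illustrative, and a real-valued scalar loss is used so backward() needs no explicit grad_output):

    import torch

    root = torch.rand(3, 3, dtype=torch.complex128)
    root = root @ root.transpose(-1, -2).conj()  # Hermitian positive-definite input
    root.requires_grad_()
    torch.linalg.cholesky(root).real.sum().backward()
    # the accumulated gradient should equal its own conjugate transpose
    assert torch.allclose(root.grad, root.grad.transpose(-1, -2).conj())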
@@ -909,35 +900,16 @@ def run_test_skipped_elements(shape, batch, uplo):
     @skipCUDAIfNoMagma
     @skipCPUIfNoLapack
     @dtypes(torch.float64, torch.complex128)
-    def test_eigh_autograd(self, device, dtype):
+    def test_eigh_hermitian_grad(self, device, dtype):
         from torch.testing._internal.common_utils import random_hermitian_matrix

-        def func(x, uplo):
-            x = 0.5 * (x + x.conj().transpose(-2, -1))
-            return torch.linalg.eigh(x, UPLO=uplo)
-
-        def func_grad_w(x, uplo):
-            return func(x, uplo)[0]
-
-        def func_grad_v(x, uplo):
-            # gauge invariant loss function
-            return abs(func(x, uplo)[1])
-
         def run_test(dims, uplo):
-            x = torch.randn(*dims, dtype=dtype, device=device, requires_grad=True)
-
-            gradcheck(func_grad_w, [x, uplo])
-            gradgradcheck(func_grad_w, [x, uplo])
-
-            gradcheck(func_grad_v, [x, uplo])
-            gradgradcheck(func_grad_v, [x, uplo])
-
             x = random_hermitian_matrix(dims[-1], *dims[:-2]).requires_grad_()
             w, v = torch.linalg.eigh(x)
             (w.sum() + abs(v).sum()).backward()
             self.assertEqual(x.grad, x.grad.conj().transpose(-1, -2))  # Check the gradient is Hermitian

-        for dims, uplo in itertools.product([(3, 3), (2, 3, 3)], ["L", "U"]):
+        for dims, uplo in itertools.product([(3, 3), (1, 1, 3, 3)], ["L", "U"]):
             run_test(dims, uplo)

     @skipCUDAIfNoMagma
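The eigh hunk above follows the same pattern, with one extra wrinkle: eigenvectors are only defined up to a phase, so the loss uses abs(v) to stay gauge invariant. A standalone sketch (random_hermitian_matrix is replaced by an explicit construction so the snippet is self-contained; the 3x3 shape is illustrative):

    import torch

    a = torch.randn(3, 3, dtype=torch.complex128)
    x = 0.5 * (a + a.conj().transpose(-1, -2))  # Hermitian input
    x.requires_grad_()
    w, v = torch.linalg.eigh(x)
    (w.sum() + abs(v).sum()).backward()  # gauge-invariant, real-valued loss
    # eigh's backward should likewise produce a Hermitian gradient
    assert torch.allclose(x.grad, x.grad.conj().transpose(-1, -2))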
@@ -2630,6 +2602,7 @@ def test_inverse_errors_large(self, device, dtype):

     @precisionOverride({torch.float32: 1e-3, torch.complex64: 1e-3, torch.float64: 1e-7, torch.complex128: 1e-7})
     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_pinv(self, device, dtype):
@@ -5602,6 +5575,7 @@ def test_solve_methods_arg_device(self, device):

     @precisionOverride({torch.float32: 5e-3, torch.complex64: 1e-3})
     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
     def test_pinverse(self, device, dtype):
@@ -7021,6 +6995,7 @@ def check_norm(a, b, expected_norm, gels_result):
         self.assertEqual((torch.mm(a, tb) - b).norm(), expectedNorm, atol=1e-8, rtol=0)

     @skipCUDAIfNoMagma
+    @skipCUDAIfRocm
     @skipCPUIfNoLapack
     def test_lapack_empty(self, device):
         # FIXME: these are just a selection of LAPACK functions -- we need a general strategy here.