Skip to content

Commit 5994863

Browse files
xwang233facebook-github-bot
authored andcommitted
Disable TF32 in some linalg tests; Disable TF32 in svd_lowrank forward (#73614)
Summary: Follow up of #73460, #73461 Pull Request resolved: #73614 Reviewed By: malfet Differential Revision: D34772822 Pulled By: ngimel fbshipit-source-id: 4e2bea0173d1b6b01e857ef63ef5c2d8c3802544
1 parent 2c2af72 commit 5994863

1 file changed

Lines changed: 15 additions & 12 deletions

File tree

torch/testing/_internal/common_methods_invocations.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
(onlyCPU, onlyCUDA, onlyNativeDeviceTypes, disablecuDNN, skipCUDAIfNoMagma, skipCUDAIfNoMagmaAndNoCusolver,
2727
skipCUDAIfNoCusolver, skipCPUIfNoLapack, skipCPUIfNoFFT, skipCUDAIfRocm, precisionOverride,
2828
toleranceOverride, tol, has_cusolver)
29-
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater
29+
from torch.testing._internal.common_cuda import CUDA11OrLater, SM53OrLater, SM60OrLater, with_tf32_off
3030
from torch.testing._internal.common_utils import \
3131
(is_iterable_of_tensors,
3232
random_symmetric_matrix, random_symmetric_psd_matrix,
@@ -8956,7 +8956,7 @@ def ref_pairwise_distance(input1, input2):
89568956
check_batched_gradgrad=False,
89578957
sample_inputs_func=sample_inputs_symeig,
89588958
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
8959-
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack]),
8959+
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off]),
89608960
# NOTE: clamp has seperate opinfos for scalar min/max (unary op) vs. tensors
89618961
OpInfo('clamp',
89628962
aliases=('clip',),
@@ -9907,7 +9907,7 @@ def ref_pairwise_distance(input1, input2):
99079907
supports_forward_ad=True,
99089908
supports_fwgrad_bwgrad=True,
99099909
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
9910-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],),
9910+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],),
99119911
OpInfo('linalg.eig',
99129912
aten_name='linalg_eig',
99139913
op=torch.linalg.eig,
@@ -9918,7 +9918,7 @@ def ref_pairwise_distance(input1, input2):
99189918
check_batched_gradgrad=False,
99199919
supports_forward_ad=True,
99209920
supports_fwgrad_bwgrad=True,
9921-
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
9921+
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
99229922
skips=(
99239923
# Forward-over-reverse gradgrad might be incorrect
99249924
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad'),),),
@@ -9947,7 +9947,7 @@ def ref_pairwise_distance(input1, input2):
99479947
check_batched_gradgrad=False,
99489948
supports_forward_ad=True,
99499949
supports_fwgrad_bwgrad=True,
9950-
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
9950+
decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack, with_tf32_off],
99519951
skips=(
99529952
# Forward-over-reverse gradgrad might be incorrect
99539953
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad',
@@ -10021,7 +10021,7 @@ def ref_pairwise_distance(input1, input2):
1002110021
supports_forward_ad=True,
1002210022
supports_fwgrad_bwgrad=True,
1002310023
check_batched_grad=False,
10024-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
10024+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1002510025
sample_inputs_func=sample_inputs_linalg_matrix_power,
1002610026
gradcheck_nondet_tol=GRADCHECK_NONDET_TOL,
1002710027
),
@@ -10067,7 +10067,7 @@ def ref_pairwise_distance(input1, input2):
1006710067
aten_name='linalg_matrix_norm',
1006810068
dtypes=floating_and_complex_types(),
1006910069
check_batched_gradgrad=False,
10070-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
10070+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1007110071
sample_inputs_func=sample_inputs_linalg_matrix_norm,
1007210072
skips=(
1007310073
# Pre-existing condition; Needs to be fixed
@@ -12994,7 +12994,7 @@ def ref_pairwise_distance(input1, input2):
1299412994
# We're using at::allclose, which does not have a batching rule
1299512995
check_batched_grad=False,
1299612996
check_batched_gradgrad=False,
12997-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
12997+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1299812998
skips=(
1299912999
# Fixme, forward over backward gives a numerical error
1300013000
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
@@ -13010,7 +13010,7 @@ def ref_pairwise_distance(input1, input2):
1301013010
check_batched_grad=False,
1301113011
check_batched_gradgrad=False,
1301213012
sample_inputs_func=sample_inputs_svd,
13013-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
13013+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off],
1301413014
skips=(
1301513015
# FIXME forward over backward gives a numerical error
1301613016
DecorateInfo(unittest.expectedFailure, 'TestGradients', 'test_fn_fwgrad_bwgrad', dtypes=(torch.complex128,)),
@@ -13025,7 +13025,7 @@ def ref_pairwise_distance(input1, input2):
1302513025
# We're using at::allclose, which does not have a batching rule
1302613026
check_batched_gradgrad=False,
1302713027
sample_inputs_func=sample_inputs_linalg_svdvals,
13028-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack]),
13028+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off]),
1302913029
OpInfo('svd_lowrank',
1303013030
op=lambda *args, **kwargs: wrapper_set_seed(
1303113031
lambda a, b, **kwargs: torch.svd_lowrank(a @ b.mT, **kwargs),
@@ -13039,7 +13039,7 @@ def ref_pairwise_distance(input1, input2):
1303913039
supports_fwgrad_bwgrad=True,
1304013040
supports_forward_ad=True,
1304113041
sample_inputs_func=sample_inputs_svd_lowrank,
13042-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack,
13042+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off,
1304313043
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
1304413044
'TestCommon', 'test_noncontiguous_samples',
1304513045
device_type='cuda')],
@@ -13060,7 +13060,10 @@ def ref_pairwise_distance(input1, input2):
1306013060
supports_forward_ad=True,
1306113061
supports_fwgrad_bwgrad=True,
1306213062
sample_inputs_func=sample_inputs_pca_lowrank,
13063-
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack],
13063+
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCPUIfNoLapack, with_tf32_off,
13064+
DecorateInfo(toleranceOverride({torch.float32: tol(atol=1e-03, rtol=1e-03)}),
13065+
'TestCommon', 'test_noncontiguous_samples',
13066+
device_type='cuda')],
1306413067
skips=(
1306513068
# test does not work with passing lambda for op
1306613069
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),

0 commit comments

Comments
 (0)