
Commit 6a77ecf

OpInfo: Add DecorateInfo class similar to SkipInfo for decorators

1 parent: ce30dba
2 files changed: 73 additions & 46 deletions

torch/testing/_internal/common_device_type.py

Lines changed: 16 additions & 7 deletions
@@ -249,25 +249,34 @@ def instantiate_test_helper(cls, name, *, test, dtype, op):
             # op-specific decorators to the original test.
             # Test-specific decorators are applied to the original test,
             # however.
-            if op is not None and op.decorators is not None:
+            if op is not None:
+                active_decorators = []
+                if op.should_skip(generic_cls.__name__, name, cls.device_type, dtype):
+                    active_decorators.append(skipIf(True, "Skipped!"))
+
+                if op.decorators is not None:
+                    for decorator in op.decorators:
+                        # Can't use isinstance as it would cause a circular import
+                        if decorator.__class__.__name__ == 'DecorateInfo':
+                            if decorator.is_active(generic_cls.__name__, name, cls.device_type, dtype):
+                                active_decorators += decorator.decorators
+                        else:
+                            active_decorators.append(decorator)
+
                 @wraps(test)
                 def test_wrapper(*args, **kwargs):
                     return test(*args, **kwargs)

-                for decorator in op.decorators:
+                for decorator in active_decorators:
                     test_wrapper = decorator(test_wrapper)

                 test_fn = test_wrapper
             else:
                 test_fn = test

             # Constructs the test
-            @wraps(test)
+            @wraps(test_fn)
             def instantiated_test(self, name=name, test=test_fn, dtype=dtype, op=op):
-                if op is not None and op.should_skip(generic_cls.__name__, name,
-                                                     self.device_type, dtype):
-                    self.skipTest("Skipped!")
-
                 device_arg: str = cls.get_primary_device()
                 if hasattr(test_fn, 'num_required_devices'):
                     device_arg = cls.get_all_devices()
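
Note: the new control flow in instantiate_test_helper is easier to follow outside the diff. Below is a minimal, self-contained sketch of the pattern with stand-in names (DummyDecorateInfo, build_test, and tag are illustrative, not from the commit); the real code matches on the class name rather than using isinstance only to avoid a circular import:

from functools import wraps

class DummyDecorateInfo:
    # Simplified stand-in for the DecorateInfo added in
    # common_methods_invocations.py: a list of decorators plus a filter.
    def __init__(self, decorators, test_name=None):
        self.decorators = decorators if isinstance(decorators, list) else [decorators]
        self.test_name = test_name

    def is_active(self, test_name):
        return self.test_name is None or self.test_name == test_name

def build_test(test, decorators, test_name):
    # Collect plain decorators unconditionally; unpack DecorateInfo-style
    # entries only when their filter matches the test being instantiated.
    active_decorators = []
    for decorator in decorators:
        if isinstance(decorator, DummyDecorateInfo):
            if decorator.is_active(test_name):
                active_decorators += decorator.decorators
        else:
            active_decorators.append(decorator)

    # Wrap the original test so decorators never mutate the shared function.
    @wraps(test)
    def test_wrapper(*args, **kwargs):
        return test(*args, **kwargs)

    for decorator in active_decorators:
        test_wrapper = decorator(test_wrapper)
    return test_wrapper

# Example: a decorator that only fires for a matching test name.
def tag(fn):
    fn.tagged = True
    return fn

test = build_test(lambda: None, [DummyDecorateInfo(tag, test_name='test_foo')], 'test_foo')
assert getattr(test, 'tagged', False)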

torch/testing/_internal/common_methods_invocations.py

Lines changed: 57 additions & 39 deletions
@@ -8,6 +8,7 @@
 import numpy as np
 from torch._six import inf, istuple
 from torch.autograd import Variable
+import collections.abc

 from typing import List, Tuple, Dict, Any

@@ -16,37 +17,61 @@
     floating_and_complex_types, floating_and_complex_types_and,
     all_types_and_complex_and, all_types_and)
 from torch.testing._internal.common_device_type import \
-    (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm,
-     expectedAlertNondeterministic, precisionOverride, onlyCPU)
+    (skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl,
+     skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride, onlyCPU)
 from torch.testing._internal.common_cuda import tf32_is_not_fp32
 from torch.testing._internal.common_utils import \
     (prod_single_zero, random_square_matrix_of_rank,
      random_symmetric_matrix, random_symmetric_psd_matrix,
      random_symmetric_pd_matrix, make_nonzero_det,
      random_fullrank_matrix_distinct_singular_value, set_rng_seed,
      TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY,
-     torch_to_numpy_dtype_dict, TEST_WITH_SLOW)
+     torch_to_numpy_dtype_dict, slowTest)

 from distutils.version import LooseVersion

 if TEST_SCIPY:
     import scipy.special

-class SkipInfo(object):
-    """Describes which test, or type of tests, should be skipped when testing
-    an operator. Any test that matches all provided arguments will be skipped.
-    The skip will only be checked if the active_if argument is True."""

-    __slots__ = ['cls_name', 'test_name', 'device_type', 'dtypes', 'active_if']
+class DecorateInfo(object):
+    """Describes which test, or type of tests, should be wrapped in the given
+    decorators when testing an operator. Any test that matches all provided
+    arguments will be decorated. The decorators will only be applied if the
+    active_if argument is True."""

-    def __init__(self, cls_name=None, test_name=None, *,
+    __slots__ = ['decorators', 'cls_name', 'test_name', 'device_type', 'dtypes', 'active_if']
+
+    def __init__(self, decorators, cls_name=None, test_name=None, *,
                  device_type=None, dtypes=None, active_if=True):
+        self.decorators = list(decorators) if isinstance(decorators, collections.abc.Sequence) else [decorators]
         self.cls_name = cls_name
         self.test_name = test_name
         self.device_type = device_type
         self.dtypes = dtypes
         self.active_if = active_if

+    def is_active(self, cls_name, test_name, device_type, dtype):
+        return (
+            self.active_if and
+            (self.cls_name is None or self.cls_name == cls_name) and
+            (self.test_name is None or self.test_name == test_name) and
+            (self.device_type is None or self.device_type == device_type) and
+            (self.dtypes is None or dtype in self.dtypes)
+        )
+
+
+class SkipInfo(DecorateInfo):
+    """Describes which test, or type of tests, should be skipped when testing
+    an operator. Any test that matches all provided arguments will be skipped.
+    The skip will only be checked if the active_if argument is True."""
+
+    def __init__(self, cls_name=None, test_name=None, *,
+                 device_type=None, dtypes=None, active_if=True):
+        super().__init__(decorators=skipIf(True, "Skipped!"), cls_name=cls_name,
+                         test_name=test_name, device_type=device_type, dtypes=dtypes,
+                         active_if=active_if)
+
 class SampleInput(object):
     """Represents sample inputs to a function."""
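
Note: the matching rule in is_active is worth seeing in action. A sketch of the same predicate as a free function over a plain dict, so it runs without the torch harness (the dict fields mirror DecorateInfo's slots; the values are illustrative):

def is_active(info, cls_name, test_name, device_type, dtype):
    # Same predicate as DecorateInfo.is_active: None fields match anything.
    return (
        info['active_if'] and
        (info['cls_name'] is None or info['cls_name'] == cls_name) and
        (info['test_name'] is None or info['test_name'] == test_name) and
        (info['device_type'] is None or info['device_type'] == device_type) and
        (info['dtypes'] is None or dtype in info['dtypes'])
    )

cuda_skip = dict(active_if=True, cls_name='TestGradients',
                 test_name='test_fn_gradgrad', device_type='cuda', dtypes=None)
print(is_active(cuda_skip, 'TestGradients', 'test_fn_gradgrad', 'cuda', None))  # True
print(is_active(cuda_skip, 'TestGradients', 'test_fn_gradgrad', 'cpu', None))   # False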

@@ -190,18 +215,8 @@ def sample_inputs(self, device, dtype, requires_grad=False):

     # Returns True if the test should be skipped and False otherwise
     def should_skip(self, cls_name, test_name, device_type, dtype):
-        for si in self.skips:
-            if not si.active_if:
-                continue
-
-            cls_name_match = si.cls_name is None or cls_name == si.cls_name
-            name_match = si.test_name is None or test_name == si.test_name
-            device_type_match = si.device_type is None or device_type == si.device_type
-            dtype_match = si.dtypes is None or dtype in si.dtypes
-            if cls_name_match and name_match and device_type_match and dtype_match:
-                return True
-
-        return False
+        return any(si.is_active(cls_name, test_name, device_type, dtype)
+                   for si in self.skips)

     def supported_dtypes(self, device_type):
         if device_type == 'cpu':
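
Note: the should_skip rewrite preserves behavior. The old loop's active_if guard is now folded into is_active, and any() over a generator short-circuits on the first match exactly as the old early return did. A toy equivalence check, with FakeSkip standing in for SkipInfo:

class FakeSkip:
    def __init__(self, active):
        self.active = active
    def is_active(self, cls_name, test_name, device_type, dtype):
        return self.active

skips = [FakeSkip(False), FakeSkip(True), FakeSkip(False)]

def should_skip_loop(skips):
    # Old style: explicit loop with early return.
    for si in skips:
        if si.is_active('TestGradients', 'test_fn_gradgrad', 'cpu', None):
            return True
    return False

def should_skip_any(skips):
    # New style: any() over the same predicate.
    return any(si.is_active('TestGradients', 'test_fn_gradgrad', 'cpu', None)
               for si in skips)

assert should_skip_loop(skips) == should_skip_any(skips) == True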
@@ -500,21 +515,18 @@ def __init__(self,
                  ref=None,  # Reference implementation (probably in np.fft namespace)
                  dtypes=floating_and_complex_types(),
                  ndimensional: bool,  # Whether dim argument can be a tuple
-                 skips=None,
                  decorators=None,
                  **kwargs):
-        skips = skips if skips is not None else []
-
-        # gradgrad is quite slow
-        if not TEST_WITH_SLOW:
-            skips.append(SkipInfo('TestGradients', 'test_fn_gradgrad'))
-
-        decorators = decorators if decorators is not None else []
-        decorators += [skipCPUIfNoMkl, skipCUDAIfRocm]
+        decorators = list(decorators) if decorators is not None else []
+        decorators += [
+            skipCPUIfNoMkl,
+            skipCUDAIfRocm,
+            # gradgrad is quite slow
+            DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
+        ]

         super().__init__(name=name,
                          dtypes=dtypes,
-                         skips=skips,
                          decorators=decorators,
                          **kwargs)
         self.ref = ref if ref is not None else _getattr_qual(np, name)
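
Note: one subtlety in the SpectralFuncInfo change is that decorators = list(decorators) copies the argument before +=, where the old skips code appended into the caller's own list. A small illustration of the aliasing the copy avoids (function names are illustrative):

def extend_in_place(decorators=None):
    decorators = decorators if decorators is not None else []
    decorators += ['extra']  # mutates the argument if one was passed
    return decorators

def extend_with_copy(decorators=None):
    decorators = list(decorators) if decorators is not None else []
    decorators += ['extra']  # only the local copy grows
    return decorators

shared = ['base']
extend_in_place(shared)
print(shared)   # ['base', 'extra'] -- caller's list was changed

shared = ['base']
extend_with_copy(shared)
print(shared)   # ['base'] -- caller's list untouched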
@@ -1318,27 +1330,33 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad):
            test_inplace_grad=False,
            supports_tensor_out=False,
            sample_inputs_func=sample_inputs_svd,
-           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
-           skips=(
+           decorators=[
+               skipCUDAIfNoMagma,
+               skipCPUIfNoLapack,
                # gradgrad checks are slow
-               SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)),
+               DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
+           ],
+           skips=(
                # cuda gradchecks are very slow
                # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
-               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))),
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
     OpInfo('linalg.svd',
            op=torch.linalg.svd,
            aten_name='linalg_svd',
            dtypes=floating_and_complex_types(),
            test_inplace_grad=False,
            supports_tensor_out=False,
            sample_inputs_func=sample_inputs_linalg_svd,
-           decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack],
-           skips=(
+           decorators=[
+               skipCUDAIfNoMagma,
+               skipCPUIfNoLapack,
                # gradgrad checks are slow
-               SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)),
+               DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'),
+           ],
+           skips=(
                # cuda gradchecks are very slow
                # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
-               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))),
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)),
     OpInfo('pinverse',
            op=torch.pinverse,
            dtypes=floating_and_complex_types(),
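
Note: the trailing comma added before the closing parentheses in skips=(...) is load-bearing. Once the first SkipInfo moves into decorators, skips holds a single element, and parentheses alone do not make a one-element tuple. A two-line reminder of the pitfall:

single = ('only_skip')
print(type(single).__name__)        # str -- iterating this char-by-char would be a bug
single_tuple = ('only_skip',)
print(type(single_tuple).__name__)  # tuple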
