8 | 8 | import numpy as np |
9 | 9 | from torch._six import inf, istuple |
10 | 10 | from torch.autograd import Variable |
| 11 | +import collections.abc |
11 | 12 |
12 | 13 | from typing import List, Tuple, Dict, Any |
13 | 14 |
16 | 17 | floating_and_complex_types, floating_and_complex_types_and, |
17 | 18 | all_types_and_complex_and, all_types_and) |
18 | 19 | from torch.testing._internal.common_device_type import \ |
19 | | - (skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, skipCUDAIfRocm, |
20 | | - expectedAlertNondeterministic, precisionOverride, onlyCPU) |
| 20 | + (skipIf, skipCUDAIfNoMagma, skipCPUIfNoLapack, skipCPUIfNoMkl, |
| 21 | + skipCUDAIfRocm, expectedAlertNondeterministic, precisionOverride, onlyCPU) |
21 | 22 | from torch.testing._internal.common_cuda import tf32_is_not_fp32 |
22 | 23 | from torch.testing._internal.common_utils import \ |
23 | 24 | (prod_single_zero, random_square_matrix_of_rank, |
24 | 25 | random_symmetric_matrix, random_symmetric_psd_matrix, |
25 | 26 | random_symmetric_pd_matrix, make_nonzero_det, |
26 | 27 | random_fullrank_matrix_distinct_singular_value, set_rng_seed, |
27 | 28 | TEST_WITH_ROCM, IS_WINDOWS, IS_MACOS, make_tensor, TEST_SCIPY, |
28 | | - torch_to_numpy_dtype_dict, TEST_WITH_SLOW) |
| 29 | + torch_to_numpy_dtype_dict, slowTest) |
29 | 30 |
30 | 31 | from distutils.version import LooseVersion |
31 | 32 |
32 | 33 | if TEST_SCIPY: |
33 | 34 | import scipy.special |
34 | 35 |
35 | | -class SkipInfo(object): |
36 | | - """Describes which test, or type of tests, should be skipped when testing |
37 | | - an operator. Any test that matches all provided arguments will be skipped. |
38 | | - The skip will only be checked if the active_if argument is True.""" |
39 | 36 |
40 | | - __slots__ = ['cls_name', 'test_name', 'device_type', 'dtypes', 'active_if'] |
| 37 | +class DecorateInfo(object): |
| 38 | + """Describes which test, or type of tests, should be wrapped in the given |
| 39 | + decorators when testing an operator. Any test that matches all provided |
| 40 | + arguments will be decorated. The decorators will only be applied if the |
| 41 | + active_if argument is True.""" |
41 | 42 |
42 | | - def __init__(self, cls_name=None, test_name=None, *, |
| 43 | + __slots__ = ['decorators', 'cls_name', 'test_name', 'device_type', 'dtypes', 'active_if'] |
| 44 | + |
| 45 | + def __init__(self, decorators, cls_name=None, test_name=None, *, |
43 | 46 | device_type=None, dtypes=None, active_if=True): |
| 47 | + self.decorators = list(decorators) if isinstance(decorators, collections.abc.Sequence) else [decorators] |
44 | 48 | self.cls_name = cls_name |
45 | 49 | self.test_name = test_name |
46 | 50 | self.device_type = device_type |
47 | 51 | self.dtypes = dtypes |
48 | 52 | self.active_if = active_if |
49 | 53 |
| 54 | + def is_active(self, cls_name, test_name, device_type, dtype): |
| 55 | + return ( |
| 56 | + self.active_if and |
| 57 | + (self.cls_name is None or self.cls_name == cls_name) and |
| 58 | + (self.test_name is None or self.test_name == test_name) and |
| 59 | + (self.device_type is None or self.device_type == device_type) and |
| 60 | + (self.dtypes is None or dtype in self.dtypes) |
| 61 | + ) |
| 62 | + |
| 63 | + |
| 64 | +class SkipInfo(DecorateInfo): |
| 65 | + """Describes which test, or type of tests, should be skipped when testing |
| 66 | + an operator. Any test that matches all provided arguments will be skipped. |
| 67 | + The skip will only be checked if the active_if argument is True.""" |
| 68 | + |
| 69 | + def __init__(self, cls_name=None, test_name=None, *, |
| 70 | + device_type=None, dtypes=None, active_if=True): |
| 71 | + super().__init__(decorators=skipIf(True, "Skipped!"), cls_name=cls_name, |
| 72 | + test_name=test_name, device_type=device_type, dtypes=dtypes, |
| 73 | + active_if=active_if) |
| 74 | + |
50 | 75 | class SampleInput(object): |
51 | 76 | """Represents sample inputs to a function.""" |
52 | 77 |
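For reference, a minimal sketch of how the new classes compose, assuming `DecorateInfo` and `SkipInfo` are importable from `torch.testing._internal.common_methods_invocations` (the module this diff modifies). The assertions restate the matching rules from `is_active`: fields left as `None` act as wildcards, and all set fields must match.

```python
import torch
from torch.testing._internal.common_methods_invocations import (
    DecorateInfo, SkipInfo)
from torch.testing._internal.common_utils import slowTest

# Wrap only TestGradients.test_fn_gradgrad in @slowTest, on any device/dtype.
slow_gradgrad = DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad')

# A SkipInfo is just a DecorateInfo whose decorator is skipIf(True, "Skipped!").
cuda_skip = SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda')

# None fields are wildcards: only the fields that were set must match.
assert slow_gradgrad.is_active('TestGradients', 'test_fn_gradgrad', 'cpu', torch.float32)
assert not cuda_skip.is_active('TestGradients', 'test_fn_gradgrad', 'cpu', torch.float32)
assert cuda_skip.is_active('TestGradients', 'test_fn_gradgrad', 'cuda', torch.float64)
```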
@@ -190,18 +215,8 @@ def sample_inputs(self, device, dtype, requires_grad=False): |
190 | 215 |
191 | 216 | # Returns True if the test should be skipped and False otherwise |
192 | 217 | def should_skip(self, cls_name, test_name, device_type, dtype): |
193 | | - for si in self.skips: |
194 | | - if not si.active_if: |
195 | | - continue |
196 | | - |
197 | | - cls_name_match = si.cls_name is None or cls_name == si.cls_name |
198 | | - name_match = si.test_name is None or test_name == si.test_name |
199 | | - device_type_match = si.device_type is None or device_type == si.device_type |
200 | | - dtype_match = si.dtypes is None or dtype in si.dtypes |
201 | | - if cls_name_match and name_match and device_type_match and dtype_match: |
202 | | - return True |
203 | | - |
204 | | - return False |
| 218 | + return any(si.is_active(cls_name, test_name, device_type, dtype) |
| 219 | + for si in self.skips) |
205 | 220 |
206 | 221 | def supported_dtypes(self, device_type): |
207 | 222 | if device_type == 'cpu': |
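With the matching logic moved into `DecorateInfo.is_active`, `should_skip` reduces to an `any()` over the op's skips. A hedged sketch of the equivalent behavior, using a hypothetical `_FakeOp` stand-in rather than a full `OpInfo`:

```python
import torch
from torch.testing._internal.common_methods_invocations import SkipInfo

class _FakeOp:
    # Stand-in exposing only the fields should_skip reads; not a real OpInfo.
    def __init__(self, skips):
        self.skips = skips

    def should_skip(self, cls_name, test_name, device_type, dtype):
        # Identical to the new implementation in the diff.
        return any(si.is_active(cls_name, test_name, device_type, dtype)
                   for si in self.skips)

op = _FakeOp(skips=(SkipInfo('TestGradients', device_type='cuda'),
                    SkipInfo(dtypes=[torch.bfloat16])))
assert op.should_skip('TestGradients', 'test_fn_grad', 'cuda', torch.float32)
assert op.should_skip('TestCommon', 'test_out', 'cpu', torch.bfloat16)
assert not op.should_skip('TestCommon', 'test_out', 'cpu', torch.float32)
```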
@@ -500,21 +515,18 @@ def __init__(self, |
500 | 515 | ref=None, # Reference implementation (probably in np.fft namespace) |
501 | 516 | dtypes=floating_and_complex_types(), |
502 | 517 | ndimensional: bool, # Whether dim argument can be a tuple |
503 | | - skips=None, |
504 | 518 | decorators=None, |
505 | 519 | **kwargs): |
506 | | - skips = skips if skips is not None else [] |
507 | | - |
508 | | - # gradgrad is quite slow |
509 | | - if not TEST_WITH_SLOW: |
510 | | - skips.append(SkipInfo('TestGradients', 'test_fn_gradgrad')) |
511 | | - |
512 | | - decorators = decorators if decorators is not None else [] |
513 | | - decorators += [skipCPUIfNoMkl, skipCUDAIfRocm] |
| 520 | + decorators = list(decorators) if decorators is not None else [] |
| 521 | + decorators += [ |
| 522 | + skipCPUIfNoMkl, |
| 523 | + skipCUDAIfRocm, |
| 524 | + # gradgrad is quite slow |
| 525 | + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), |
| 526 | + ] |
514 | 527 |
515 | 528 | super().__init__(name=name, |
516 | 529 | dtypes=dtypes, |
517 | | - skips=skips, |
518 | 530 | decorators=decorators, |
519 | 531 | **kwargs) |
520 | 532 | self.ref = ref if ref is not None else _getattr_qual(np, name) |
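The net effect for the spectral ops: instead of unconditionally skipping gradgrad unless `TEST_WITH_SLOW` is set, the test is now decorated with `slowTest` and runs again in the slow-test job. The `decorators` list also becomes mixed, holding both plain decorators (applied to every test) and `DecorateInfo` entries (applied only to matching tests). A sketch of how a harness might split such a list; `active_decorators` is a hypothetical helper, not part of this diff:

```python
from torch.testing._internal.common_methods_invocations import DecorateInfo

def active_decorators(op, cls_name, test_name, device_type, dtype):
    """Collect the decorators that apply to one generated test."""
    result = []
    for d in op.decorators:
        if isinstance(d, DecorateInfo):
            # Test-specific: apply its decorators only on a match.
            if d.is_active(cls_name, test_name, device_type, dtype):
                result.extend(d.decorators)
        else:
            # Plain decorators (e.g. skipCPUIfNoMkl) apply to every test.
            result.append(d)
    return result
```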
@@ -1318,27 +1330,33 @@ def sample_inputs_fliplr_flipud(op_info, device, dtype, requires_grad): |
1318 | 1330 | test_inplace_grad=False, |
1319 | 1331 | supports_tensor_out=False, |
1320 | 1332 | sample_inputs_func=sample_inputs_svd, |
1321 | | - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], |
1322 | | - skips=( |
| 1333 | + decorators=[ |
| 1334 | + skipCUDAIfNoMagma, |
| 1335 | + skipCPUIfNoLapack, |
1323 | 1336 | # gradgrad checks are slow |
1324 | | - SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), |
| 1337 | + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), |
| 1338 | + ], |
| 1339 | + skips=( |
1325 | 1340 | # cuda gradchecks are very slow |
1326 | 1341 | # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 |
1327 | | - SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), |
| 1342 | + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)), |
1328 | 1343 | OpInfo('linalg.svd', |
1329 | 1344 | op=torch.linalg.svd, |
1330 | 1345 | aten_name='linalg_svd', |
1331 | 1346 | dtypes=floating_and_complex_types(), |
1332 | 1347 | test_inplace_grad=False, |
1333 | 1348 | supports_tensor_out=False, |
1334 | 1349 | sample_inputs_func=sample_inputs_linalg_svd, |
1335 | | - decorators=[skipCUDAIfNoMagma, skipCPUIfNoLapack], |
1336 | | - skips=( |
| 1350 | + decorators=[ |
| 1351 | + skipCUDAIfNoMagma, |
| 1352 | + skipCPUIfNoLapack, |
1337 | 1353 | # gradgrad checks are slow |
1338 | | - SkipInfo('TestGradients', 'test_fn_gradgrad', active_if=(not TEST_WITH_SLOW)), |
| 1354 | + DecorateInfo(slowTest, 'TestGradients', 'test_fn_gradgrad'), |
| 1355 | + ], |
| 1356 | + skips=( |
1339 | 1357 | # cuda gradchecks are very slow |
1340 | 1358 | # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775 |
1341 | | - SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'))), |
| 1359 | + SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)), |
1342 | 1360 | OpInfo('pinverse', |
1343 | 1361 | op=torch.pinverse, |
1344 | 1362 | dtypes=floating_and_complex_types(), |
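For the two svd entries the result is the same split: the slow gradgrad check moves into `decorators` as a `DecorateInfo(slowTest, ...)`, while the genuine CUDA skip stays in `skips`. Roughly, the generated tests end up behaving like this hand-written sketch (illustrative only; the real tests are generated by the device-type framework, and `unittest.skip` stands in for the internal `skipIf(True, "Skipped!")`):

```python
import unittest
from torch.testing._internal.common_utils import slowTest

class TestGradients(unittest.TestCase):
    @slowTest                      # from DecorateInfo(slowTest, ...)
    def test_fn_gradgrad_svd_cpu(self):
        ...                        # runs only in the slow-test job
                                   # (typically PYTORCH_TEST_WITH_SLOW=1)

    @unittest.skip("Skipped!")     # from SkipInfo(..., device_type='cuda')
    def test_fn_gradgrad_svd_cuda(self):
        ...
```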