Skip to content

Commit df7c0a0

Browse files
kshitij12345facebook-github-bot
authored andcommitted
[testing] assert no duplicate in method_tests for an OpInfo entry (#53492)
Summary: Assert no duplicate in method_tests for an OpInfo entry Pull Request resolved: #53492 Reviewed By: izdeby Differential Revision: D26882441 Pulled By: mruberry fbshipit-source-id: f0631ea2b46b74285c76365c679bd45abc917d63
1 parent 547f435 commit df7c0a0

2 files changed

Lines changed: 24 additions & 2 deletions

File tree

test/test_ops.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,9 @@
88
from torch.testing._internal.common_utils import \
99
(TestCase, run_tests, IS_SANDCASTLE, clone_input_helper, make_tensor)
1010
from torch.testing._internal.common_methods_invocations import \
11-
(op_db)
11+
(op_db, method_tests)
1212
from torch.testing._internal.common_device_type import \
13-
(instantiate_device_type_tests, ops, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
13+
(instantiate_device_type_tests, ops, onlyCPU, onlyOnCPUAndCUDA, skipCUDAIfRocm, OpDTypes)
1414
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
1515
from torch.autograd.gradcheck import gradcheck, gradgradcheck
1616

@@ -19,6 +19,9 @@
1919
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining
2020

2121

22+
# Get names of all the operators which have entry in `method_tests` (legacy testing infra)
23+
method_tested_operators = set(map(lambda test_details: test_details[0], method_tests()))
24+
2225
# Tests that apply to all operators
2326

2427
class TestOpInfo(TestCase):
@@ -56,6 +59,12 @@ def test_supported_dtypes(self, device, dtype, op):
5659
sample = samples[0]
5760
op(*sample.input, *sample.args, **sample.kwargs)
5861

62+
# Verifies that ops do not have an entry in
63+
# `method_tests` (legacy testing infra).
64+
@onlyCPU
65+
@ops(op_db, allowed_dtypes=[torch.float32])
66+
def test_duplicate_method_tests(self, device, dtype, op):
67+
self.assertFalse(op.name in method_tested_operators)
5968

6069
# gradcheck requires double precision
6170
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,

torch/testing/_internal/common_methods_invocations.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,6 +1574,7 @@ def _make_tensor_helper(shape, low=None, high=None):
15741574
skips=(
15751575
SkipInfo('TestCommon', 'test_variant_consistency_jit',
15761576
dtypes=[torch.bfloat16, torch.float16, torch.cfloat, torch.cdouble]),
1577+
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
15771578
# addmm does not correctly warn when resizing out= inputs
15781579
SkipInfo('TestCommon', 'test_out')),
15791580
sample_inputs_func=sample_inputs_addmm),
@@ -1857,6 +1858,7 @@ def _make_tensor_helper(shape, low=None, high=None):
18571858
dtypes=[torch.bool]),
18581859
# cumsum does not correctly warn when resizing out= inputs
18591860
SkipInfo('TestCommon', 'test_out'),
1861+
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
18601862
),
18611863
sample_inputs_func=sample_inputs_cumsum),
18621864
UnaryUfuncInfo('deg2rad',
@@ -1881,21 +1883,25 @@ def _make_tensor_helper(shape, low=None, high=None):
18811883
variant_test_name='no_rounding_mode',
18821884
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18831885
sample_inputs_func=sample_inputs_div,
1886+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
18841887
assert_autodiffed=True),
18851888
OpInfo('div',
18861889
variant_test_name='true_rounding',
18871890
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
18881891
sample_inputs_func=partial(sample_inputs_div, rounding_mode='true'),
1892+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
18891893
assert_autodiffed=True),
18901894
OpInfo('div',
18911895
variant_test_name='trunc_rounding',
18921896
dtypes=all_types_and(torch.half, torch.bfloat16),
18931897
sample_inputs_func=partial(sample_inputs_div, rounding_mode='trunc'),
1898+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
18941899
assert_autodiffed=True),
18951900
OpInfo('div',
18961901
variant_test_name='floor_rounding',
18971902
dtypes=all_types_and(torch.half, torch.bfloat16),
18981903
sample_inputs_func=partial(sample_inputs_div, rounding_mode='floor'),
1904+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
18991905
assert_autodiffed=True),
19001906
UnaryUfuncInfo('exp',
19011907
ref=np_unary_ufunc_integer_promotion_wrapper(np.exp),
@@ -2050,6 +2056,10 @@ def _make_tensor_helper(shape, low=None, high=None):
20502056
dtypes=all_types_and(torch.half, torch.bfloat16),
20512057
sample_inputs_func=sample_inputs_floor_divide,
20522058
decorators=[_wrap_warn_once("floor_divide is deprecated, and will be removed")],
2059+
skips=(
2060+
# `test_duplicate_method_tests` doesn't raise any warning, as it doesn't actually
2061+
# call the operator.
2062+
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
20532063
supports_autograd=False,
20542064
),
20552065
OpInfo('inverse',
@@ -2241,6 +2251,7 @@ def _make_tensor_helper(shape, low=None, high=None):
22412251
device_type='cuda', dtypes=[torch.complex128]),
22422252
SkipInfo('TestCommon', 'test_variant_consistency_jit',
22432253
dtypes=[torch.cfloat, torch.cdouble]),
2254+
SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),
22442255
),
22452256
supports_out=False),
22462257
OpInfo('masked_select',
@@ -2494,6 +2505,7 @@ def _make_tensor_helper(shape, low=None, high=None):
24942505
dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
24952506
supports_out=False,
24962507
test_inplace_grad=False,
2508+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
24972509
sample_inputs_func=sample_inputs_tensor_split,),
24982510
OpInfo('triangular_solve',
24992511
op=torch.triangular_solve,
@@ -2723,6 +2735,7 @@ def _make_tensor_helper(shape, low=None, high=None):
27232735
OpInfo('index_fill',
27242736
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
27252737
test_inplace_grad=False,
2738+
skips=(SkipInfo('TestOpInfo', 'test_duplicate_method_tests'),),
27262739
supports_out=False,
27272740
sample_inputs_func=sample_inputs_index_fill),
27282741
OpInfo('index_select',

0 commit comments

Comments
 (0)