Skip to content

Commit 23d590e

Browse files
jbschlosserpytorchmergebot
authored andcommitted
More flexible test parametrization with @reparametrize (#138369)
**Background:** The `@parametrize` decorator enjoys widespread usage as a convenient tool for ensuring extensive test coverage. One particular feature that makes this easy is the ability to stack such decorators, testing over the cross-product of inputs. Example: ```python class MyTestClass(TestCase): @parametrize("x", range(3)) @parametrize("y", [False, True]) def test_foo(self, x, y): # Invoked with: # x=0, y=False # x=1, y=False # x=2, y=False # x=0, y=True # x=1, y=True # x=2, y=True ... ``` Note that the `@ops` and `@modules` decorators employ the same underlying machinery for parametrizing over `OpInfo` / `ModuleInfo` entries. These decorators also parametrize over op-specific `device` / `dtype` info *according to what is supported for each op*. ```python class MyTestClass(TestCase): @ops(op_db) def test_foo(self, op, device, dtype): # Invoked each OpInfo in the db along with each device / dtype that corresponds # with this op according to the OpInfo entry. ... ``` Note that this in contrast to the naive cross product between ops and devices / dtypes, which would generate too many tests. Certain use cases benefit from a similar type of flexible parametrization that is more intelligent than simple cross-product composition. It is expensive to generate / run too many tests, even if the unneeded ones are skipped appropriately. This PR attempts to generalize such flexible parametrization and satisfy these use cases through the introduction of a `@reparametrize` decorator, which operates on an existing parametrizer and allows for customized on-the-fly parametrization through the use of an `adapter_fn`. Examples: ```python # adapter_fn that adds a new arg def include_is_even_arg(test_name, param_kwargs): x = param_kwargs["x"] is_even = x % 2 == 0 new_param_kwargs = dict(param_kwargs) new_param_kwargs["is_even"] = is_even is_even_suffix = "_even" if is_even else "_odd" new_test_name = f"{test_name}{is_even_suffix}" yield (new_test_name, new_param_kwargs) # adapter_fn that excludes certain values def exclude_odds(test_name, param_kwargs): x = param_kwargs["x"] is_even = x % 2 == 0 yield None if not is_even else (test_name, param_kwargs) class MyTestClass(TestCase): @reparametrize(parametrize("x", range(5)), include_is_even_arg) def test_foo(self, x, is_even): # Invoked with both the x value and the new is_even arg ... @reparametrize(parametrize("x", range(5)), exclude_odds) def test_bar(self, x): # Only invoked with even x values ... ``` For a more real-world use case, imagine you want to write a set of OpInfo tests that parametrize over additional op-specific things beyond `device` / `dtype` (in NJT's case, this includes contiguity type, whether to operate over the batch / ragged / other dims, etc.). The `@reparametrize` decorator allows you to customize the `@ops` parametrization to add in these additional args as they make sense on a per-op basis. Pull Request resolved: #138369 Approved by: https://github.com/janeyx99
1 parent ebaa774 commit 23d590e

2 files changed

Lines changed: 100 additions & 3 deletions

File tree

test/test_testing.py

Lines changed: 45 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,11 @@
1717
import torch
1818

1919
from torch.testing import make_tensor
20-
from torch.testing._internal.common_utils import \
21-
(IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22-
parametrize, subtest, instantiate_parametrized_tests, dtype_name, TEST_WITH_ROCM, decorateIf, skipIfRocm)
20+
from torch.testing._internal.common_utils import (
21+
IS_FBCODE, IS_JETSON, IS_MACOS, IS_SANDCASTLE, IS_WINDOWS, TestCase, run_tests, slowTest,
22+
parametrize, reparametrize, subtest, instantiate_parametrized_tests, dtype_name,
23+
TEST_WITH_ROCM, decorateIf, skipIfRocm
24+
)
2325
from torch.testing._internal.common_device_type import \
2426
(PYTORCH_TESTING_DEVICE_EXCEPT_FOR_KEY, PYTORCH_TESTING_DEVICE_ONLY_FOR_KEY, dtypes,
2527
get_device_type_test_bases, instantiate_device_type_tests, onlyCPU, onlyCUDA, onlyNativeDeviceTypes,
@@ -1651,6 +1653,46 @@ def test_two_things_custom_names_alternate(self, x, y):
16511653
test_names = _get_test_names_for_test_class(TestParametrized)
16521654
self.assertEqual(expected_test_names, test_names)
16531655

1656+
def test_reparametrize(self):
1657+
1658+
def include_is_even_arg(test_name, param_kwargs):
1659+
x = param_kwargs["x"]
1660+
is_even = x % 2 == 0
1661+
new_param_kwargs = dict(param_kwargs)
1662+
new_param_kwargs["is_even"] = is_even
1663+
is_even_suffix = "_even" if is_even else "_odd"
1664+
new_test_name = f"{test_name}{is_even_suffix}"
1665+
yield (new_test_name, new_param_kwargs)
1666+
1667+
def exclude_odds(test_name, param_kwargs):
1668+
x = param_kwargs["x"]
1669+
is_even = x % 2 == 0
1670+
yield None if not is_even else (test_name, param_kwargs)
1671+
1672+
class TestParametrized(TestCase):
1673+
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
1674+
def test_foo(self, x, is_even):
1675+
pass
1676+
1677+
@reparametrize(parametrize("x", range(5)), exclude_odds)
1678+
def test_bar(self, x):
1679+
pass
1680+
1681+
instantiate_parametrized_tests(TestParametrized)
1682+
1683+
expected_test_names = [
1684+
'TestParametrized.test_bar_x_0',
1685+
'TestParametrized.test_bar_x_2',
1686+
'TestParametrized.test_bar_x_4',
1687+
'TestParametrized.test_foo_x_0_even',
1688+
'TestParametrized.test_foo_x_1_odd',
1689+
'TestParametrized.test_foo_x_2_even',
1690+
'TestParametrized.test_foo_x_3_odd',
1691+
'TestParametrized.test_foo_x_4_even',
1692+
]
1693+
test_names = _get_test_names_for_test_class(TestParametrized)
1694+
self.assertEqual(expected_test_names, test_names)
1695+
16541696
def test_subtest_names(self):
16551697

16561698
class TestParametrized(TestCase):

torch/testing/_internal/common_utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,61 @@ def decorator_fn(_, decorators=decorators):
702702
'Note that this may result from reuse of a generator.')
703703

704704

705+
class reparametrize(_TestParametrizer):
706+
"""
707+
Decorator for adjusting the way an existing parametrizer operates. This class runs
708+
the given adapter_fn on each parametrization produced by the given parametrizer,
709+
allowing for on-the-fly parametrization more flexible than the default,
710+
product-based composition that occurs when stacking parametrization decorators.
711+
712+
If the adapter_fn returns None for a given test parametrization, that parametrization
713+
will be excluded. Otherwise, it's expected that the adapter_fn returns an iterable of
714+
modified parametrizations, with tweaked test names and parameter kwargs.
715+
716+
Examples::
717+
718+
def include_is_even_arg(test_name, param_kwargs):
719+
x = param_kwargs["x"]
720+
is_even = x % 2 == 0
721+
new_param_kwargs = dict(param_kwargs)
722+
new_param_kwargs["is_even"] = is_even
723+
is_even_suffix = "_even" if is_even else "_odd"
724+
new_test_name = f"{test_name}{is_even_suffix}"
725+
yield (new_test_name, new_param_kwargs)
726+
727+
...
728+
729+
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
730+
def test_foo(self, x, is_even):
731+
...
732+
733+
def exclude_odds(test_name, param_kwargs):
734+
x = param_kwargs["x"]
735+
is_even = x % 2 == 0
736+
yield None if not is_even else (test_name, param_kwargs)
737+
738+
...
739+
740+
@reparametrize(parametrize("x", range(5)), exclude_odds)
741+
def test_bar(self, x):
742+
...
743+
744+
"""
745+
def __init__(self, parametrizer, adapter_fn):
746+
self.parametrizer = parametrizer
747+
self.adapter_fn = adapter_fn
748+
749+
def _parametrize_test(self, test, generic_cls, device_cls):
750+
for (gen_test, test_name, param_kwargs, decorator_fn) in \
751+
self.parametrizer._parametrize_test(test, generic_cls, device_cls):
752+
adapted = self.adapter_fn(test_name, param_kwargs)
753+
if adapted is not None:
754+
for adapted_item in adapted:
755+
if adapted_item is not None:
756+
new_test_name, new_param_kwargs = adapted_item
757+
yield (gen_test, new_test_name, new_param_kwargs, decorator_fn)
758+
759+
705760
class decorateIf(_TestParametrizer):
706761
"""
707762
Decorator for applying parameter-specific conditional decoration.

0 commit comments

Comments
 (0)