Commit 23d590e
More flexible test parametrization with @reparametrize (#138369)
**Background:** The `@parametrize` decorator enjoys widespread usage as a convenient tool for ensuring extensive test coverage. One particular feature that makes this easy is the ability to stack such decorators, testing over the cross-product of inputs. Example:
```python
class MyTestClass(TestCase):
@parametrize("x", range(3))
@parametrize("y", [False, True])
def test_foo(self, x, y):
# Invoked with:
# x=0, y=False
# x=1, y=False
# x=2, y=False
# x=0, y=True
# x=1, y=True
# x=2, y=True
...
```
Note that the `@ops` and `@modules` decorators employ the same underlying machinery for parametrizing over `OpInfo` / `ModuleInfo` entries. These decorators also parametrize over op-specific `device` / `dtype` info *according to what is supported for each op*.
```python
class MyTestClass(TestCase):
@ops(op_db)
def test_foo(self, op, device, dtype):
# Invoked each OpInfo in the db along with each device / dtype that corresponds
# with this op according to the OpInfo entry.
...
```
Note that this in contrast to the naive cross product between ops and devices / dtypes, which would generate too many tests. Certain use cases benefit from a similar type of flexible parametrization that is more intelligent than simple cross-product composition. It is expensive to generate / run too many tests, even if the unneeded ones are skipped appropriately.
This PR attempts to generalize such flexible parametrization and satisfy these use cases through the introduction of a `@reparametrize` decorator, which operates on an existing parametrizer and allows for customized on-the-fly parametrization through the use of an `adapter_fn`. Examples:
```python
# adapter_fn that adds a new arg
def include_is_even_arg(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
new_param_kwargs = dict(param_kwargs)
new_param_kwargs["is_even"] = is_even
is_even_suffix = "_even" if is_even else "_odd"
new_test_name = f"{test_name}{is_even_suffix}"
yield (new_test_name, new_param_kwargs)
# adapter_fn that excludes certain values
def exclude_odds(test_name, param_kwargs):
x = param_kwargs["x"]
is_even = x % 2 == 0
yield None if not is_even else (test_name, param_kwargs)
class MyTestClass(TestCase):
@reparametrize(parametrize("x", range(5)), include_is_even_arg)
def test_foo(self, x, is_even):
# Invoked with both the x value and the new is_even arg
...
@reparametrize(parametrize("x", range(5)), exclude_odds)
def test_bar(self, x):
# Only invoked with even x values
...
```
For a more real-world use case, imagine you want to write a set of OpInfo tests that parametrize over additional op-specific things beyond `device` / `dtype` (in NJT's case, this includes contiguity type, whether to operate over the batch / ragged / other dims, etc.). The `@reparametrize` decorator allows you to customize the `@ops` parametrization to add in these additional args as they make sense on a per-op basis.
Pull Request resolved: #138369
Approved by: https://github.com/janeyx991 parent ebaa774 commit 23d590e
2 files changed
Lines changed: 100 additions & 3 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
17 | 17 | | |
18 | 18 | | |
19 | 19 | | |
20 | | - | |
21 | | - | |
22 | | - | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
23 | 25 | | |
24 | 26 | | |
25 | 27 | | |
| |||
1651 | 1653 | | |
1652 | 1654 | | |
1653 | 1655 | | |
| 1656 | + | |
| 1657 | + | |
| 1658 | + | |
| 1659 | + | |
| 1660 | + | |
| 1661 | + | |
| 1662 | + | |
| 1663 | + | |
| 1664 | + | |
| 1665 | + | |
| 1666 | + | |
| 1667 | + | |
| 1668 | + | |
| 1669 | + | |
| 1670 | + | |
| 1671 | + | |
| 1672 | + | |
| 1673 | + | |
| 1674 | + | |
| 1675 | + | |
| 1676 | + | |
| 1677 | + | |
| 1678 | + | |
| 1679 | + | |
| 1680 | + | |
| 1681 | + | |
| 1682 | + | |
| 1683 | + | |
| 1684 | + | |
| 1685 | + | |
| 1686 | + | |
| 1687 | + | |
| 1688 | + | |
| 1689 | + | |
| 1690 | + | |
| 1691 | + | |
| 1692 | + | |
| 1693 | + | |
| 1694 | + | |
| 1695 | + | |
1654 | 1696 | | |
1655 | 1697 | | |
1656 | 1698 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
702 | 702 | | |
703 | 703 | | |
704 | 704 | | |
| 705 | + | |
| 706 | + | |
| 707 | + | |
| 708 | + | |
| 709 | + | |
| 710 | + | |
| 711 | + | |
| 712 | + | |
| 713 | + | |
| 714 | + | |
| 715 | + | |
| 716 | + | |
| 717 | + | |
| 718 | + | |
| 719 | + | |
| 720 | + | |
| 721 | + | |
| 722 | + | |
| 723 | + | |
| 724 | + | |
| 725 | + | |
| 726 | + | |
| 727 | + | |
| 728 | + | |
| 729 | + | |
| 730 | + | |
| 731 | + | |
| 732 | + | |
| 733 | + | |
| 734 | + | |
| 735 | + | |
| 736 | + | |
| 737 | + | |
| 738 | + | |
| 739 | + | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
| 745 | + | |
| 746 | + | |
| 747 | + | |
| 748 | + | |
| 749 | + | |
| 750 | + | |
| 751 | + | |
| 752 | + | |
| 753 | + | |
| 754 | + | |
| 755 | + | |
| 756 | + | |
| 757 | + | |
| 758 | + | |
| 759 | + | |
705 | 760 | | |
706 | 761 | | |
707 | 762 | | |
| |||
0 commit comments