Commit 44b98c0
[BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)
Removes a part of the sparse adam test and the following three tests: `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, `test_duplicate_params_in_one_param_group`
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129d)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
...
----------------------------------------------------------------------
Ran 3 tests in 0.023s
OK
```
Increases coverage by testing the duplicate param tests on ALL the optims instead of just one each. Also fixes SparseAdam bug which was accidentally calling torch.unbind through list instead of putting params in a list. This bug was caught by migrating the weird warning stuff to just one easy warning context manager, which checks that nothing else gets raised.
The new test_errors does not run slower than before, overhead is still king:
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129d)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..........................
----------------------------------------------------------------------
Ran 26 tests in 10.337s
OK
```
Compared to test_errors BEFORE my commit :p
```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa69)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
.............sssssssssssss
----------------------------------------------------------------------
Ran 26 tests in 11.980s
OK (skipped=13)
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa69)]$
```
Pull Request resolved: #116315
Approved by: https://github.com/mikaylagawarecki1 parent 8abeacd commit 44b98c0
4 files changed
Lines changed: 286 additions & 303 deletions
File tree
- test
- optim
- torch
- optim
- testing/_internal
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
906 | 906 | | |
907 | 907 | | |
908 | 908 | | |
909 | | - | |
910 | | - | |
911 | | - | |
912 | | - | |
913 | | - | |
914 | | - | |
915 | | - | |
916 | | - | |
917 | 909 | | |
918 | 910 | | |
919 | 911 | | |
| |||
1438 | 1430 | | |
1439 | 1431 | | |
1440 | 1432 | | |
1441 | | - | |
1442 | | - | |
1443 | | - | |
1444 | | - | |
1445 | | - | |
1446 | | - | |
1447 | | - | |
1448 | | - | |
1449 | | - | |
1450 | | - | |
1451 | | - | |
1452 | | - | |
1453 | | - | |
1454 | | - | |
1455 | 1433 | | |
1456 | 1434 | | |
1457 | 1435 | | |
| |||
1621 | 1599 | | |
1622 | 1600 | | |
1623 | 1601 | | |
1624 | | - | |
1625 | | - | |
1626 | | - | |
1627 | | - | |
1628 | | - | |
1629 | | - | |
1630 | | - | |
1631 | | - | |
1632 | 1602 | | |
1633 | 1603 | | |
1634 | 1604 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
26 | 26 | | |
27 | 27 | | |
28 | 28 | | |
29 | | - | |
30 | 29 | | |
31 | 30 | | |
32 | 31 | | |
| |||
36 | 35 | | |
37 | 36 | | |
38 | 37 | | |
39 | | - | |
40 | | - | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
41 | 44 | | |
42 | 45 | | |
43 | | - | |
44 | | - | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
45 | 52 | | |
46 | 53 | | |
47 | 54 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
260 | 260 | | |
261 | 261 | | |
262 | 262 | | |
| 263 | + | |
263 | 264 | | |
264 | 265 | | |
265 | 266 | | |
| |||
0 commit comments