
Commit 44b98c0

janeyx99 authored and pytorchmergebot committed
[BE] migrate all assertRaises tests to OptimizerInfo test_errors (#116315)
Removes a part of the sparse adam test and the following three tests: `test_fused_optimizer_raises`, `test_duplicate_params_across_param_groups`, `test_duplicate_params_in_one_param_group`.

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129d)]$ python test/test_optim.py -k test_fused_optimizer_raises -k test_duplicate_params_across_param_groups -k test_duplicate_params_in_one_param_group
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
...
----------------------------------------------------------------------
Ran 3 tests in 0.023s

OK
```

Increases coverage by running the duplicate-param checks on ALL the optims instead of just one each.

Also fixes a SparseAdam bug: a raw Tensor passed as `params` was accidentally unbound via `list(params)` (equivalent to `torch.unbind`) instead of being wrapped as `[params]`. This bug was caught by migrating the weird warning handling to just one easy warning context manager, which checks that nothing else gets raised.

The new test_errors does not run slower than before; overhead is still king:

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (d2d129d)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
..........................
----------------------------------------------------------------------
Ran 26 tests in 10.337s

OK
```

Compared to test_errors BEFORE my commit :p

```
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa69)]$ python test/test_optim.py -k test_errors
/home/janeyx/.conda/envs/pytorch-3.10/lib/python3.10/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.0
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
.............sssssssssssss
----------------------------------------------------------------------
Ran 26 tests in 11.980s

OK (skipped=13)
(pytorch-3.10) [janeyx@devgpu023.odn1 ~/local/pytorch (b47aa69)]$
```

Pull Request resolved: #116315
Approved by: https://github.com/mikaylagawarecki
1 parent 8abeacd commit 44b98c0

4 files changed

Lines changed: 286 additions & 303 deletions
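
For context on the SparseAdam bug described in the commit message: iterating a Tensor (which is what `list(params)` does) splits it along dim 0, exactly like `torch.unbind`. A minimal standalone sketch of the difference (not part of the diff):

```python
import torch

t = torch.zeros(3)

# list() iterates the Tensor along dim 0 -- the same result as
# torch.unbind -- yielding three 0-dim tensors instead of one parameter.
unbound = list(t)
assert len(unbound) == 3 and all(e.dim() == 0 for e in unbound)
assert len(torch.unbind(t)) == 3  # equivalent behavior

# Wrapping the Tensor keeps it as a single 3-element parameter, which is
# what the fixed deprecation path in torch/optim/optimizer.py now does.
wrapped = [t]
assert len(wrapped) == 1 and wrapped[0].shape == (3,)
```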


test/optim/test_optim.py

Lines changed: 0 additions & 30 deletions
```diff
@@ -906,14 +906,6 @@ def test_sparse_adam(self):
             sparse_only=True,
             maximize=True,
         )
-        import warnings
-        with warnings.catch_warnings(record=True) as ws:
-            SparseAdam(torch.zeros(3))
-            self.assertEqual(len(ws), 1)
-            for warning in ws:
-                self.assertEqual(len(warning.message.args), 1)
-                self.assertRegex(warning.message.args[0],
-                                 "Passing in a raw Tensor as ``params`` to SparseAdam ")
 
     # ROCm precision is too low to pass this test
     def test_adadelta(self):
@@ -1438,20 +1430,6 @@ def closure():
         self.assertEqual(type(res1), type(res2))
 
 
-    def test_duplicate_params_in_one_param_group(self):
-        param = Parameter(torch.randn(1))
-        with self.assertWarnsOnceRegex(UserWarning, '.*a parameter group with duplicate parameters.*'):
-            Adamax([param, param], lr=0.01)
-
-    def test_duplicate_params_across_param_groups(self):
-        param = Parameter(torch.randn(1))
-        self.assertRaisesRegex(
-            ValueError,
-            'some parameters appear in more than one parameter group',
-            lambda: Adadelta([{'params': param}, {'params': param}])
-        )
-
     def test_fused_optimizer_does_not_step_if_foundinf(self):
         if not torch.cuda.is_available():
             self.skipTest("CUDA is required.")
@@ -1621,14 +1599,6 @@ def local_post_hook(opt: Optimizer, args: Tuple[Any], kwargs: Dict[Any, Any]):
             opt2.step()
         self.assertListEqual(data, [0, 1, 2, 5, 0, 1, 2, 5, 0, 1, 2, 5])
 
-    def test_fused_optimizer_raises(self):
-        if not torch.cuda.is_available():
-            self.skipTest("Requires CUDA devices")
-        for optimizer_ctor in (Adam, AdamW):
-            with self.assertRaisesRegex(RuntimeError, "`fused` and `foreach` cannot be `True` together."):
-                optimizer_ctor([torch.empty((), device="cuda")], foreach=True, fused=True)
-            with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
-                optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)
 
     @staticmethod
     def _state_dict_pre_hook(optimizer: Optimizer) -> None:
```
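
The duplicate-param checks deleted above now live as OptimizerInfo error inputs consumed by `test_errors` (next file). As a rough, hedged sketch of what such records could look like, the stand-in classes below mirror only the attribute names the test reads (`optimizer_error_input`, `error_on`, `error_type`, `error_regex`, `params`, `kwargs`); the real definitions live in `torch/testing/_internal/common_optimizers.py` and may differ:

```python
# Stand-in record types -- hypothetical, for illustration only.
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Type

import torch
from torch.nn import Parameter


class OptimizerErrorEnum(Enum):
    CONSTRUCTION_ERROR = 0
    STEP_ERROR = 1


@dataclass
class OptimizerInput:
    params: Any
    kwargs: Dict[str, Any] = field(default_factory=dict)


@dataclass
class ErrorOptimizerInput:
    optimizer_error_input: OptimizerInput
    error_on: OptimizerErrorEnum
    error_type: Type[Exception]
    error_regex: str


# The migrated duplicate-param checks, now expressible for every
# optimizer rather than hardcoded to Adamax and Adadelta:
param = Parameter(torch.randn(1))
duplicate_param_error_inputs = [
    # Duplicates inside one param group only warn at construction...
    ErrorOptimizerInput(
        OptimizerInput(params=[param, param], kwargs=dict(lr=0.01)),
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=UserWarning,
        error_regex=".*a parameter group with duplicate parameters.*",
    ),
    # ...while duplicates across param groups raise at construction.
    ErrorOptimizerInput(
        OptimizerInput(params=[{"params": param}, {"params": param}]),
        error_on=OptimizerErrorEnum.CONSTRUCTION_ERROR,
        error_type=ValueError,
        error_regex="some parameters appear in more than one parameter group",
    ),
]
```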

test/test_optim.py

Lines changed: 12 additions & 5 deletions
```diff
@@ -26,7 +26,6 @@ def test_optim_infos_do_not_specify_global_cliquey_kwargs(self, device, dtype, o
         self.assertFalse(any(f for f in global_cliquey_flags if f in optim_input.kwargs))
 
 
-    @onlyCPU
     @optims([optim for optim in optim_db if optim.optim_error_inputs_func is not None])
     def test_errors(self, device, dtype, optim_info):
         optim_cls = optim_info.optim_cls
@@ -36,12 +35,20 @@ def test_errors(self, device, dtype, optim_info):
             optim_input = error_input.optimizer_error_input
             params, kwargs = optim_input.params, optim_input.kwargs
             if error_input.error_on == OptimizerErrorEnum.CONSTRUCTION_ERROR:
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim_cls(params, **kwargs)
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim_cls(params, **kwargs)
             elif error_input.error_on == OptimizerErrorEnum.STEP_ERROR:
                 optim = optim_cls(params, **kwargs)
-                with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
-                    optim.step()
+                if issubclass(error_input.error_type, Warning):
+                    with self.assertWarnsRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
+                else:
+                    with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
+                        optim.step()
             else:
                 raise NotImplementedError(f"Unknown error type {error_input.error_on}")
```
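Why the new `issubclass(error_input.error_type, Warning)` branches are needed: warnings are reported, not raised, so `assertRaisesRegex` never sees them. A minimal standalone demonstration with plain `unittest` (no PyTorch required), using messages from this PR:

```python
import unittest
import warnings


class WarningVsErrorDemo(unittest.TestCase):
    def test_warning_needs_assert_warns(self):
        # A warning does not unwind the stack, so only assertWarnsRegex
        # can observe it; assertRaisesRegex would fail here.
        with self.assertWarnsRegex(FutureWarning, "raw Tensor"):
            warnings.warn("Passing in a raw Tensor as ``params``", FutureWarning)

    def test_error_needs_assert_raises(self):
        # An exception propagates normally, so assertRaisesRegex applies.
        with self.assertRaisesRegex(ValueError, "more than one parameter group"):
            raise ValueError("some parameters appear in more than one parameter group")


if __name__ == "__main__":
    unittest.main()
```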

torch/optim/optimizer.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -260,6 +260,7 @@ def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:
                     "is deprecated. In the future, this will raise an error. "
                     "Please wrap your Tensor in an iterable instead."),
                     FutureWarning)
+                params = [params]
             else:
                 raise TypeError("params argument given to the optimizer should be "
                                 "an iterable of Tensors or dicts, but got " +
```
