Added check and test for betas parameter for Adam Optimizer #751
lazypanda1 wants to merge 2 commits into pyro-ppl:dev from
Conversation
Hello, thanks for the contribution! Since we just wrap the Adam optimizer from PyTorch, I think a more appropriate place for this kind of check would be in PyTorch?
fritzo left a comment:
Thanks, this looks helpful. Looks like there's a lint error on Travis.
```python
return PyroOptim(torch.optim.Adam, optim_args, _adam_checker)


def _adam_checker(optim_args):
    assert (optim_args['betas'][0] >= 0.0 and optim_args['betas'][0] < 1.0 and
            optim_args['betas'][1] >= 0.0 and optim_args['betas'][1] < 1.0)
```
nit: Could you convert these to `if ...: raise ValueError` rather than asserts? We try to reserve assertions for developer-facing errors and ValueErrors for user-facing errors.
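For illustration, a minimal sketch of what the suggested `if ...: raise ValueError` version could look like. The `optim_args` shape mirrors the diff above; the `(0.9, 0.999)` fallback matches PyTorch's Adam defaults and is an assumption about how missing betas would be handled, not part of the original diff:

```python
def _adam_checker(optim_args):
    # User-facing validation: reject betas outside [0, 1) with a ValueError
    # instead of an assert. The (0.9, 0.999) fallback mirrors PyTorch's
    # Adam defaults and is an assumption, not part of the original diff.
    betas = optim_args.get('betas', (0.9, 0.999))
    for i, beta in enumerate(betas):
        if not 0.0 <= beta < 1.0:
            raise ValueError(
                "Invalid beta at index {}: {}; betas must lie in [0, 1)".format(i, beta))
```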
```diff
     such dictionaries
     """
-    def __init__(self, optim_constructor, optim_args):
+    def __init__(self, optim_constructor, optim_args, checker=None):
```
nit: maybe call this `arg_checker`?
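To make the suggestion concrete, here is a hedged sketch of how an optional `arg_checker` might be wired into `PyroOptim.__init__`; the attribute names and the dict check are illustrative assumptions, not necessarily Pyro's actual internals:

```python
class PyroOptim(object):
    def __init__(self, optim_constructor, optim_args, arg_checker=None):
        # Illustrative attribute names; not necessarily Pyro's actual internals.
        self.pt_optim_constructor = optim_constructor
        self.pt_optim_args = optim_args
        # Validate eagerly when optim_args is a plain dict, so a bad
        # hyperparameter fails at construction time rather than mid-training.
        if arg_checker is not None and isinstance(optim_args, dict):
            arg_checker(optim_args)
```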
👍 I agree the PyTorch folks would appreciate better error checking, and then a larger community could benefit from this fix.
@lazypanda1 After discussing with the rest of the Pyro team, we feel this fix really belongs upstream in PyTorch so more people can benefit. The PyTorch folks are very friendly about accepting PRs (we contribute regularly). I think you can simply put your
This PR adds a check to prevent division-by-zero errors and give users friendlier error messages when using the Adam optimizer.
Currently, if one specifies the first beta value of the Adam optimizer as `1.0`, Pyro fails with the error message `ZeroDivisionError: float division by zero`. According to the definition of Adam, beta values should lie in the range [0, 1).
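For context on where the division by zero comes from, a minimal arithmetic sketch (the `lr` value is illustrative): Adam rescales the learning rate by a bias-correction term `1 - beta1**step`, which is exactly zero whenever `beta1 == 1.0`:

```python
# Adam's bias correction divides by (1 - beta1**step); with beta1 == 1.0
# that denominator is zero on every step, hence the error below.
lr, beta1, step = 0.001, 1.0, 1
bias_correction1 = 1 - beta1 ** step   # 0.0
step_size = lr / bias_correction1      # ZeroDivisionError: float division by zero
```

Recent PyTorch releases validate `betas` in the optimizer constructor and raise a `ValueError` up front, which is the kind of upstream fix suggested in this conversation.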