
Added check and test for betas parameter for Adam Optimizer #751

Closed
lazypanda1 wants to merge 2 commits into pyro-ppl:dev from lazypanda1:adam_optimizer_param_fix

Conversation

@lazypanda1

This PR adds a check to prevent division-by-zero errors and to give users a friendlier error message when using the Adam optimizer.

Currently, if the first beta value of the Adam optimizer is set to 1.0, Pyro fails with ZeroDivisionError: float division by zero. By the definition of Adam, beta values should lie in the range [0, 1).
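To illustrate, a minimal sketch of the failure against plain torch.optim (pre-fix behavior; the zero comes from Adam's bias-correction term 1 - beta1**step, which vanishes when beta1 == 1.0):

    import torch

    # beta1 == 1.0 makes the bias correction 1 - beta1**step equal to zero,
    # so the step-size computation divides by zero.
    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.Adam([param], lr=0.001, betas=(1.0, 0.999))
    param.grad = torch.ones(1)
    optimizer.step()  # ZeroDivisionError: float division by zero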

@CLAassistant

CLAassistant commented Feb 8, 2018

CLA assistant check
All committers have signed the CLA.

@martinjankowiak
Collaborator

hello. thanks for the contribution! since we just wrap the adam in pytorch i think a more appropriate place for this kind of check would be in pytorch?

@fritzo fritzo left a comment
Member


Thanks, this looks helpful. Looks like there's a lint error on Travis.

Comment thread: pyro/optim/__init__.py (Outdated)

    return PyroOptim(torch.optim.Adam, optim_args, _adam_checker)

def _adam_checker(optim_args):
    assert (optim_args['betas'][0] >= 0.0 and optim_args['betas'][0] < 1.0)
Member


nit: Could you convert these to if ...: raise ValueError rather than asserts? We try to reserve assertions for developer-facing errors and ValueErrors for user-facing errors.

Author


Done
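For concreteness, a sketch of the converted check, assuming the _adam_checker shape from the diff above (error message wording is illustrative):

    def _adam_checker(optim_args):
        # User-facing validation: raise ValueError rather than assert.
        beta1, beta2 = optim_args['betas']
        if not 0.0 <= beta1 < 1.0:
            raise ValueError("Adam betas should be in [0, 1), got beta1={}".format(beta1))
        if not 0.0 <= beta2 < 1.0:
            raise ValueError("Adam betas should be in [0, 1), got beta2={}".format(beta2))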

Comment thread: pyro/optim/optim.py (Outdated)

        such dictionaries
        """
-   def __init__(self, optim_constructor, optim_args):
+   def __init__(self, optim_constructor, optim_args, checker=None):
Member


nit: maybe call this arg_checker?

Author


Done
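For reference, a minimal sketch of how the renamed arg_checker could be wired into PyroOptim (the real class carries more machinery; everything beyond the diffed signature is an assumption):

    class PyroOptim(object):
        def __init__(self, optim_constructor, optim_args, arg_checker=None):
            # Validate user-supplied args once, up front, before any step.
            if arg_checker is not None:
                arg_checker(optim_args)  # e.g. _adam_checker raising ValueError
            self.pt_optim_constructor = optim_constructor
            self.pt_optim_args = optim_args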

@fritzo
Member

fritzo commented Feb 8, 2018

👍 I agree the PyTorch folks would appreciate better error checking; then a larger community could benefit from this fix.

@fritzo
Member

fritzo commented Feb 8, 2018

@lazypanda1 After discussing with the rest of the Pyro team, we feel this fix really belongs upstream in PyTorch so that more people can benefit. The PyTorch folks are very friendly about accepting PRs (we contribute regularly). I think you can simply put your ValueErrors in Adam.__init__(): https://github.com/pytorch/pytorch/blob/master/torch/optim/adam.py#L29 If you send a PR this week, it will likely make it into the PyTorch 0.4 release and be available in the next major release of Pyro.
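A sketch of what that upstream check might look like in Adam.__init__ (placement per the link above; message wording and surrounding code are assumptions, and the real class also implements step(), omitted here):

    from torch.optim import Optimizer

    class Adam(Optimizer):
        def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                     weight_decay=0):
            # Reject betas outside [0, 1) before any optimization step runs.
            if not 0.0 <= betas[0] < 1.0:
                raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
            if not 0.0 <= betas[1] < 1.0:
                raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
            defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
            super(Adam, self).__init__(params, defaults)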
