-
Notifications
You must be signed in to change notification settings - Fork 430
Add ability to set tf.function args for scipy optimizer #2064
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
uri-granta
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Just one comment.
gpflow/optimizers/scipy.py
Outdated
| step_callback: Optional[StepCallback] = None, | ||
| compile: bool = True, | ||
| allow_unused_variables: bool = False, | ||
| tffun_args: Dict[str, Any] = {}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- More general to allow any mapping, not just dicts.
- Generally best to avoid mutable default arguments, since if you mutate them in (or outside) the function they end up being changed for all subsequent calls. For dicts it's conventional to use None (since there isn't a built in immutable dict structure):
| tffun_args: Dict[str, Any] = {}, | |
| tffun_args: Optional[Mapping[str, Any]] = None, | |
| .... | |
| if ttfun_args is None: | |
| ttfun_args = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reasoned that tffun_args wasn't ever going to be mutated; but I agree it is better to use None. I have made both those suggested improvements.
gpflow/optimizers/scipy.py
Outdated
| variables: Sequence[tf.Variable], | ||
| compile: bool = True, | ||
| allow_unused_variables: bool = False, | ||
| tffun_args: Dict[str, Any] = {}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(ditto)
| "tffun_args", | ||
| [{}, dict(jit_compile=True), dict(jit_compile=False, other_arg="dummy")], | ||
| ) | ||
| def test_scipy__tffun_args( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice test
Yes. Wouldn't adding a dictionary with arguments to tensorflow |
I am not. Suggestions welcome :-)
The current version of However if we want to keep it simple, we could just add direct support for XLA compilation by, e.g., modifying the
It is a bit messy/confusing to have two different dictionaries. I am not sure how to add a generic mechanism in a backwards compatible way. We could combine the two dictionaries into one (but if we rename the existing |
Personally, I think the current approach is the best one (though the docstring for
I therefore don't think minimize(
m3.training_loss,
variables=m3.trainable_variables,
options=dict(maxiter=50),
compile=True,
tffun_args=dict(jit_compile=True),
) |
|
I have added an example to the docstring. |
Codecov ReportPatch coverage:
Additional details and impacted files@@ Coverage Diff @@
## develop #2064 +/- ##
===========================================
+ Coverage 98.16% 98.18% +0.01%
===========================================
Files 97 97
Lines 5458 5462 +4
===========================================
+ Hits 5358 5363 +5
+ Misses 100 99 -1
☔ View full report in Codecov by Sentry. |
PR type: enhancement
Related issue(s)/PRs: None
Summary
Proposed changes
Scipyoptimizer. This can be achieved by passingjit_compile=Trueargument totf.function.tf.functioncall inScipy.minimize.Scipy.minimizeis a normal dictionary and not keyword-args. There is already a kwargs argument for scipy optimizer itself.What alternatives have you considered?
Considered changing the existing
compile: boolargument toScipy.minimize()to an enum; with options for "no-compilation", "default-compilation" and "XLA-compilation" -- or something similar. However, decided that it is better to add a more general purpose mechanism allowing setting of anytf.functionargument. There are other existing arguments (and there might be others in the future) that some users may want to set.Minimal working example
Release notes
Fully backwards compatible: yes
PR checklist
make format)make check-all)