[dist_optim] add distributed functional Adam optimizer #50624
wanchaol wants to merge 5 commits into gh/wanchaol/156/base from
Conversation
Add TorchScript compatible Adam functional optimizer to distributed optimizer [ghstack-poisoned]
    # Define a TorchScript compatible Functional Adam Optimizer
    # where we use these optimizer in a functional way.
    # Instead of using the `param.grad` when updating parameters,
    # we explicitly let the user pass gradients to the `step` function
Just to clarify: according to the comments below, the "user" here is the DistributedOptimizer API, not the RPC application user, right? The call to the optimizer should remain the same for the RPC user?
Yes, that's right. The DistributedOptimizer API actually passes those grads to the step function; let me update the comment to clarify.
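To illustrate the pattern under discussion, here is a hypothetical, minimal sketch (plain Python, not the actual PyTorch code) of a functional optimizer whose `step` takes gradients explicitly instead of reading `param.grad`. The class name `FunctionalSGD` and the SGD update rule are illustrative stand-ins, not the Adam implementation from this PR.

```python
# Hypothetical sketch: a functional optimizer that receives gradients
# explicitly in `step`, rather than reading them from `param.grad`.
class FunctionalSGD:
    def __init__(self, params, lr=0.05):
        # Parameters are plain floats here for simplicity.
        self.param_group = {"params": list(params)}
        self.lr = lr

    def step(self, gradients):
        params = self.param_group["params"]
        # The caller (e.g. the DistributedOptimizer API) is responsible
        # for passing one gradient per parameter, in order.
        if len(params) != len(gradients):
            raise ValueError(
                "the gradients passed in does not equal to the size of the "
                f"parameters! Params length: {len(params)}. "
                f"Gradients length: {len(gradients)}"
            )
        # Apply a plain SGD update with the explicitly supplied gradients.
        self.param_group["params"] = [
            p - self.lr * g for p, g in zip(params, gradients)
        ]

opt = FunctionalSGD([1.0], lr=0.1)
opt.step([0.5])  # param becomes 1.0 - 0.1 * 0.5 = 0.95
```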
    + f"Gradients length: {len(gradients)}"
    )

    for param, gradient in zip(self.param_group['params'], gradients):
General question: it looks like the similar code in torch/optim/adam.py uses `for p in group['params']` and then accesses the grad with `p.grad`. I'm assuming we can't do this since we need the grads explicitly, because dist autograd doesn't populate `p.grad`?
Yeah, in the distributed autograd context we don't populate `p.grad`; instead we call `dist_autograd.get_gradients(autograd_ctx_id)` to get the list of gradients locally.
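To make the flow concrete, here is a hypothetical pure-Python stand-in (no torch required): the dict `grads_map` plays the role of the mapping returned by `dist_autograd.get_gradients(ctx_id)`, and the helper collects gradients in parameter order before they are handed to the functional optimizer's `step`. All names here are illustrative, not the real API.

```python
def gather_gradients(grads_map, param_names):
    # Collect gradients in parameter order, mirroring how the distributed
    # optimizer pulls grads from the autograd context instead of p.grad.
    return [grads_map[name] for name in param_names]

params = {"w": 1.0, "b": 0.0}       # hypothetical local parameters
grads_map = {"w": 0.5, "b": -0.1}   # stand-in for get_gradients(ctx_id)
grads = gather_gradients(grads_map, list(params))
```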
rohan-varma left a comment
LGTM overall, although I mostly compared the changes to torch/optim/adam.py and torch/distributed/optim/functional_adagrad.py and checked for parity. I don't have context on the changes in torch/optim/functional.py, so please get someone else to look at that.
@pritamdamania87 Would be great if you get a chance to take a look at these changes as well.
    # update the steps for each param group update
    state['step'] += 1
    # record the step after step update
    state_steps.append(state['step'].item())
I'm guessing all the logic up until the point we call F.adam is aiming to emulate torch/optim/adam.py, although is there any automated way to guarantee this? Could we dedupe the similar parts into helper functions and call those helper functions here? Alternatively, are we guaranteed that the dist optimizer tests will raise an error if the implementations diverge at all?
Yes, they are similar indeed, but they differ across optimizers since each optimizer needs different state (and this also differs from the original adam.py because of TorchScript limitations on some syntax), so it's hard to generalize across them. That said, I think we can guarantee the implementation will not diverge on the functional part, since we share the computation, and test_optim has good coverage of it. On the state management side, do you think we should introduce some sort of flag to disable/enable the TorchScript support and compare the results in the test?
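The shared computation being discussed is the core Adam math. As a hedged sketch, here is a pure-Python, single-scalar version of one Adam update mirroring the standard formulation used by torch/optim/adam.py (exponential moving averages plus bias correction); the function name and scalar form are illustrative, not the PR's actual code.

```python
import math

def adam_step(param, grad, exp_avg, exp_avg_sq, step,
              lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a scalar parameter (illustrative sketch)."""
    step += 1
    # Exponential moving averages of the gradient and its square.
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad
    # Bias corrections for the zero-initialized moment estimates.
    bias_correction1 = 1 - beta1 ** step
    bias_correction2 = 1 - beta2 ** step
    denom = math.sqrt(exp_avg_sq) / math.sqrt(bias_correction2) + eps
    step_size = lr / bias_correction1
    param = param - step_size * exp_avg / denom
    return param, exp_avg, exp_avg_sq, step

p, m, v, t = adam_step(1.0, 0.5, 0.0, 0.0, 0)
# after one step, p is approximately 0.9
```

Because both the TorchScript-compatible path and the eager path call into the same functional computation, a divergence in the shared math would show up in existing optimizer tests.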
    @dist_init()
    def test_dist_optim(self):
        self._test_dist_optim_base(optim.SGD, lr=0.05)
Can we just move this to test_dist_optim_functional?
Ah, never mind, this is the regular optimizer, not the TorchScripted one.
Add TorchScript compatible Adam functional optimizer to distributed optimizer Differential Revision: [D25932770](https://our.internmc.facebook.com/intern/diff/D25932770) [ghstack-poisoned]
Summary: Pull Request resolved: pytorch#50624 Add TorchScript compatible Adam functional optimizer to distributed optimizer Test Plan: Imported from OSS Reviewed By: rohan-varma Differential Revision: D25932770 Pulled By: wanchaol fbshipit-source-id: cab3f1164c76186969c284a2c52481b79bbb7190
Stack from ghstack:
Add TorchScript compatible Adam functional optimizer to distributed optimizer
Differential Revision: D25932770