
Commit a0cf556

wanchaol authored and facebook-github-bot committed
[optimizer] refactor SGD to use functional API (#45597)
Summary: Pull Request resolved: #45597
Test Plan: Imported from OSS
Reviewed By: izdeby
Differential Revision: D25932773
Pulled By: wanchaol
fbshipit-source-id: bc5f830d6812f847475b9bdcc67865d9968e3282
1 parent b96a651 commit a0cf556

2 files changed: 64 additions & 18 deletions


torch/optim/functional.py

Lines changed: 37 additions & 1 deletion
@@ -2,7 +2,7 @@
 import math
 import torch
 from torch import Tensor
-from typing import List
+from typing import List, Optional

 # TODO: use foreach API in optim.functional to do all the computation

@@ -96,3 +96,39 @@ def adam(params: List[Tensor],
         step_size = lr / bias_correction1

         param.addcdiv_(exp_avg, denom, value=-step_size)
+
+
+def sgd(params: List[Tensor],
+        d_p_list: List[Tensor],
+        momentum_buffer_list: List[Optional[Tensor]],
+        weight_decay: float,
+        momentum: float,
+        lr: float,
+        dampening: float,
+        nesterov: bool):
+    r"""Functional API that performs SGD algorithm computation.
+
+    See :class:`~torch.optim.SGD` for details.
+    """
+
+    for i, param in enumerate(params):
+
+        d_p = d_p_list[i]
+        if weight_decay != 0:
+            d_p = d_p.add(param, alpha=weight_decay)
+
+        if momentum != 0:
+            buf = momentum_buffer_list[i]
+
+            if buf is None:
+                buf = torch.clone(d_p).detach()
+                momentum_buffer_list[i] = buf
+            else:
+                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
+
+            if nesterov:
+                d_p = d_p.add(buf, alpha=momentum)
+            else:
+                d_p = buf
+
+        param.add_(d_p, alpha=-lr)
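
For context only (not part of the commit): a minimal sketch of calling the new functional entry point directly on plain tensors, assuming the module is importable as torch.optim.functional. The gradients and hyperparameter values below are illustrative.

import torch
from torch.optim import functional as F

# Two parameters with stand-in gradients (normally produced by a backward pass).
params = [torch.randn(3, requires_grad=True) for _ in range(2)]
d_p_list = [torch.ones_like(p) for p in params]
momentum_buffer_list = [None, None]   # no momentum state yet

# In-place parameter updates must run outside autograd, mirroring Optimizer.step().
with torch.no_grad():
    F.sgd(params,
          d_p_list,
          momentum_buffer_list,
          weight_decay=0.0,
          momentum=0.9,
          lr=0.1,
          dampening=0.0,
          nesterov=False)

# sgd() mutates params in place and fills momentum_buffer_list with the newly
# created buffers; the caller (e.g. torch.optim.SGD) is responsible for writing
# them back into its own state.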

torch/optim/sgd.py

Lines changed: 27 additions & 17 deletions
@@ -1,4 +1,5 @@
 import torch
+from . import functional as F
 from .optimizer import Optimizer, required


@@ -86,29 +87,38 @@ def step(self, closure=None):
                 loss = closure()

         for group in self.param_groups:
+            params_with_grad = []
+            d_p_list = []
+            momentum_buffer_list = []
             weight_decay = group['weight_decay']
             momentum = group['momentum']
             dampening = group['dampening']
             nesterov = group['nesterov']
+            lr = group['lr']

             for p in group['params']:
-                if p.grad is None:
-                    continue
-                d_p = p.grad
-                if weight_decay != 0:
-                    d_p = d_p.add(p, alpha=weight_decay)
-                if momentum != 0:
-                    param_state = self.state[p]
-                    if 'momentum_buffer' not in param_state:
-                        buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()
-                    else:
-                        buf = param_state['momentum_buffer']
-                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
-                    if nesterov:
-                        d_p = d_p.add(buf, alpha=momentum)
-                    else:
-                        d_p = buf
+                if p.grad is not None:
+                    params_with_grad.append(p)
+                    d_p_list.append(p.grad)

-                p.add_(d_p, alpha=-group['lr'])
+                    state = self.state[p]
+                    if 'momentum_buffer' not in state:
+                        momentum_buffer_list.append(None)
+                    else:
+                        momentum_buffer_list.append(state['momentum_buffer'])
+
+            F.sgd(params_with_grad,
+                  d_p_list,
+                  momentum_buffer_list,
+                  weight_decay,
+                  momentum,
+                  lr,
+                  dampening,
+                  nesterov)
+
+            # update momentum_buffers in state
+            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
+                state = self.state[p]
+                state['momentum_buffer'] = momentum_buffer

         return loss
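
Also for context (not part of the commit): the public optimizer interface is unchanged by this refactor. A standard training loop such as the illustrative sketch below should behave the same, with step() now gathering gradients and momentum buffers and delegating the arithmetic to F.sgd.

import torch

model = torch.nn.Linear(10, 1)   # illustrative model, not from the commit
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

x = torch.randn(4, 10)
y = torch.randn(4, 1)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
optimizer.step()   # collects params/grads/buffers per group, calls F.sgd, writes buffers back

# The momentum buffer created during step() is stored back in optimizer.state.
print(optimizer.state[next(model.parameters())]['momentum_buffer'].shape)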
