Commit c8c1910

feat: LARS optimizer
1 parent 64077ce commit c8c1910

3 files changed

Lines changed: 225 additions & 0 deletions

File tree

  test/test_optim.py
  torch/optim/__init__.py
  torch/optim/lars.py

test/test_optim.py

Lines changed: 51 additions & 0 deletions
@@ -1786,6 +1786,49 @@ def test_fused_optimizer_raises(self):
         with self.assertRaisesRegex(RuntimeError, "`fused` does not support `differentiable`"):
             optimizer_ctor([torch.empty((), device="cuda")], differentiable=True, fused=True)
 
+    def test_lars(self):
+        # ASK: What's the reason behind two identical calls? (See SGD tests)
+        self._test_basic_cases(
+            lambda weight, bias, maximize: optim.LARS([weight, bias], lr=1e-3, maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, maximize: optim.LARS(
+                self._build_params_dict(weight, bias, lr=1e-2),
+                lr=1e-3, maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, maximize: optim.LARS(
+                self._build_params_dict_single(weight, bias, lr=1e-2),
+                lr=1e-3, maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, maximize: optim.LARS(
+                self._build_params_dict_single(weight, bias, lr=1e-2), maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, maximize:
+                optim.LARS([weight, bias], lr=1e-3, momentum=0.5, weight_decay=1, dampening=0.0, nesterov=True, maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        self._test_basic_cases(
+            lambda weight, bias, maximize:
+                optim.LARS([weight, bias], lr=1e-3, trust_coefficient=0.01, eps=1e-5, maximize=maximize),
+            constructor_accepts_maximize=True, constructor_accepts_foreach=False,
+        )
+        with self.assertRaisesRegex(ValueError, "Invalid learning rate: -0.1"):
+            optim.LARS(None, lr=-0.1)
+        with self.assertRaisesRegex(ValueError, "Invalid weight decay value: -0.5"):
+            optim.LARS(None, lr=1e-2, weight_decay=-0.5)
+        with self.assertRaisesRegex(ValueError, "Invalid momentum value: -0.5"):
+            optim.LARS(None, lr=1e-2, momentum=-0.5)
+        with self.assertRaisesRegex(ValueError, "Nesterov momentum requires a momentum and zero dampening"):
+            optim.LARS(None, lr=1e-2, nesterov=True, momentum=0.1, dampening=0.1)
+        with self.assertRaisesRegex(ValueError, "Nesterov momentum requires a momentum and zero dampening"):
+            optim.LARS(None, lr=1e-2, nesterov=True, momentum=0.0, dampening=0.0)
 
 class SchedulerTestNet(torch.nn.Module):
     def __init__(self):

@@ -4542,6 +4585,14 @@ def test_radam(self):
             ),
         )
 
+    def test_lars(self):
+        p = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        grad = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        mbuff = torch.rand(10, requires_grad=True, dtype=torch.float64)
+        state = {'momentum_buffer': mbuff}
+        gradcheck(_diff_fn, (p, grad, state, torch.optim.LARS, {'lr': 0.9, 'differentiable': True}, *state.values()))
+
+
 
     @unittest.skipIf(not TEST_CUDA, "test requires CUDA")
     def test_defaults_changed_to_foreach(self):

torch/optim/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,7 @@
 from .lbfgs import LBFGS
 from . import lr_scheduler
 from . import swa_utils
+from .lars import LARS
 
 del adadelta
 del adagrad

@@ -36,3 +37,4 @@
 del optimizer
 del nadam
 del lbfgs
+del lars

torch/optim/lars.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@ (new file)

from .optimizer import Optimizer, required, _use_grad_for_differentiable  # type: ignore[attr-defined]
import torch
from typing import List, Optional
from torch import Tensor

__all__ = ["LARS", "lars"]


class LARS(Optimizer):
    """Implements the LARS (Layer-wise Adaptive Rate Scaling) algorithm."""

    def __init__(
        self,
        params,
        lr=required,
        momentum: float = 0,
        dampening: float = 0,
        weight_decay: float = 0,
        nesterov: bool = False,
        *,
        trust_coefficient: float = 0.001,
        eps: float = 1e-8,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        if lr is not required and lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight decay value: {weight_decay}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            trust_coefficient=trust_coefficient,
            eps=eps,
            maximize=maximize,
            differentiable=differentiable,
        )

        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("nesterov", False)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)

    @_use_grad_for_differentiable
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            grads = []
            momentum_buffer_list = []

            for p in group["params"]:
                if p.grad is not None:
                    params_with_grad.append(p)
                    grads.append(p.grad)

                    state = self.state[p]
                    momentum_buffer_list.append(state.get("momentum_buffer"))

            lars(
                params_with_grad,
                grads,
                momentum_buffer_list,
                lr=group["lr"],
                momentum=group["momentum"],
                dampening=group["dampening"],
                weight_decay=group["weight_decay"],
                nesterov=group["nesterov"],
                trust_coefficient=group["trust_coefficient"],
                eps=group["eps"],
                maximize=group["maximize"],
            )

            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state["momentum_buffer"] = momentum_buffer

        return loss


def lars(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    *,
    lr: float,
    momentum: float,
    dampening: float,
    weight_decay: float,
    nesterov: bool,
    trust_coefficient: float,
    eps: float,
    maximize: bool,
):
    if torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if not torch.jit.is_scripting():
        func = _single_tensor_lars

    func(
        params,
        grads,
        momentum_buffer_list,
        lr=lr,
        momentum=momentum,
        dampening=dampening,
        weight_decay=weight_decay,
        nesterov=nesterov,
        trust_coefficient=trust_coefficient,
        eps=eps,
        maximize=maximize,
    )


def _single_tensor_lars(
    params: List[Tensor],
    grads: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    *,
    lr: float,
    momentum: float,
    dampening: float,
    weight_decay: float,
    nesterov: bool,
    trust_coefficient: float,
    eps: float,
    maximize: bool,
):
    for i, param in enumerate(params):
        d_p = grads[i] if not maximize else -grads[i]

        p_norm = torch.norm(param.data)
        g_norm = torch.norm(d_p.data)

        if weight_decay != 0:
            # LARS scaling: compute the layer-wise trust ratio from the parameter
            # and gradient norms, then apply weight decay and rescale the update.
            if p_norm * g_norm > 0:
                lars_lr = trust_coefficient * p_norm / (g_norm + p_norm * weight_decay + eps)

                d_p = d_p.add(param, alpha=weight_decay)
                d_p.mul_(lars_lr)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            d_p = d_p.add(buf, alpha=momentum) if nesterov else buf

        param.add_(d_p, alpha=-lr)
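
For reference, the optimizer added here follows the standard torch.optim interface: when weight_decay is non-zero, _single_tensor_lars rescales each parameter's update by the layer-wise trust ratio trust_coefficient * ||p|| / (||g|| + weight_decay * ||p|| + eps) before applying the usual SGD-style momentum step. A minimal usage sketch, assuming this commit is applied; the model, data, and hyperparameter values below are illustrative placeholders, not part of the commit:

import torch
import torch.nn as nn
import torch.optim as optim

# Toy model and data, purely illustrative.
model = nn.Linear(10, 1)
inputs, targets = torch.randn(64, 10), torch.randn(64, 1)

# Construct the LARS optimizer introduced by this commit; the keyword
# arguments mirror the defaults defined in torch/optim/lars.py above.
optimizer = optim.LARS(model.parameters(), lr=0.1, momentum=0.9,
                       weight_decay=1e-4, trust_coefficient=0.001)

loss_fn = nn.MSELoss()
for _ in range(10):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()  # trust-ratio scaling, then the momentum update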
