
Commit 109c83b

tmp: docs + linter
1 parent c8c1910 commit 109c83b

1 file changed: torch/optim/lars.py (39 additions, 3 deletions)
@@ -1,4 +1,4 @@
-from .optimizer import Optimizer, required, _use_grad_for_differentiable # type: ignore[attr-defined]
+from .optimizer import Optimizer, required, _use_grad_for_differentiable, _differentiable_doc, _maximize_doc # type: ignore[attr-defined]
 import torch
 from typing import List, Optional
 from torch import Tensor
@@ -7,8 +7,6 @@


 class LARS(Optimizer):
-    """Implements LARS algorithm."""
-
     def __init__(
         self,
         params,
@@ -55,6 +53,12 @@ def __setstate__(self, state):

     @_use_grad_for_differentiable
     def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
         loss = None
         if closure is not None:
             with torch.enable_grad():
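For context on the closure contract this new docstring documents: a closure re-runs the forward pass and returns the loss so that step() can re-evaluate the model. A minimal sketch, using SGD as a stand-in since LARS only lands with this patch; model, loss_fn, inputs, and targets are placeholder names:

import torch

# Placeholder setup; any nn.Module and loss function would do.
model = torch.nn.Linear(4, 1)
inputs, targets = torch.randn(8, 4), torch.randn(8, 1)
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # stand-in for LARS

def closure():
    # Re-evaluates the model and returns the loss, per the docstring above.
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    return loss

loss = optimizer.step(closure)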
@@ -94,6 +98,35 @@ def step(self, closure=None):

         return loss

+LARS.__doc__ = r"""Implements LARS algorithm.
+
+    For further details regarding the algorithm we refer to `Large Batch Training of Convolutional Networks`_.
+    """ + r"""
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+        trust_coefficient (float, optional): coefficient for computing LR (default: 0.001)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        {maximize}
+        {differentiable}
+
+    .. _Large Batch Training of Convolutional Networks:
+        https://arxiv.org/abs/1708.03888
+
+    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc) + r"""
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> optimizer = torch.optim.LARS(model.parameters(), lr=0.1, momentum=0.9)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+    """

 def lars(
     params: List[Tensor],
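The trust_coefficient and eps entries above parameterize the per-layer "local learning rate" from the LARS paper. The patch does not show the update loop itself, so the following is only an illustrative sketch of one common formulation (local_lr = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)), not the shipped implementation:

import torch

def local_lr(param: torch.Tensor, grad: torch.Tensor,
             trust_coefficient: float = 0.001,
             weight_decay: float = 0.0,
             eps: float = 1e-8) -> torch.Tensor:
    # Illustrative approximation of the LARS trust ratio (arXiv:1708.03888);
    # real implementations differ in where weight decay enters the denominator.
    p_norm = torch.linalg.vector_norm(param)
    g_norm = torch.linalg.vector_norm(grad)
    return trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)

w, g = torch.randn(128, 64), torch.randn(128, 64)
print(local_lr(w, g))  # scales the global lr for this layer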
@@ -109,6 +142,9 @@ def lars(
     eps: float,
     maximize: bool,
 ):
+    r"""Functional API that performs LARS algorithm computation.
+    See :class:`~torch.optim.LARS` for details.
+    """
     if torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')

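A side note on the LARS.__doc__ construction in the second-to-last hunk: the docstring is assembled from three raw-string segments, and str.format runs only on the middle one, so segments containing doctest text (or literal braces) pass through untouched while shared fragments like _maximize_doc are spliced in once. A minimal standalone sketch of the same pattern, with hypothetical stand-ins for the fragments this patch imports from .optimizer:

# Hypothetical stand-ins for _maximize_doc and _differentiable_doc.
_maximize_doc = "maximize (bool, optional): maximize the objective (default: False)"
_differentiable_doc = "differentiable (bool, optional): differentiable step (default: False)"

class MyOpt:
    pass

MyOpt.__doc__ = r"""Implements a toy optimizer.
""" + r"""
Args:
    lr (float): learning rate
    {maximize}
    {differentiable}
""".format(maximize=_maximize_doc, differentiable=_differentiable_doc) + r"""
Example:
    >>> # left unformatted so any literal {braces} here survive
"""
print(MyOpt.__doc__)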