-from .optimizer import Optimizer, required, _use_grad_for_differentiable  # type: ignore[attr-defined]
+from .optimizer import Optimizer, required, _use_grad_for_differentiable, _differentiable_doc, _maximize_doc  # type: ignore[attr-defined]
 import torch
 from typing import List, Optional
 from torch import Tensor
@@ -7,8 +7,6 @@


 class LARS(Optimizer):
-    """Implements LARS algorithm."""
-
     def __init__(
         self,
         params,
@@ -55,6 +53,12 @@ def __setstate__(self, state):

     @_use_grad_for_differentiable
     def step(self, closure=None):
+        """Performs a single optimization step.
+
+        Args:
+            closure (Callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
         loss = None
         if closure is not None:
             with torch.enable_grad():
@@ -94,6 +98,35 @@ def step(self, closure=None):

         return loss

+LARS.__doc__ = r"""Implements LARS algorithm.
+
+    For further details regarding the algorithm we refer to `Large Batch Training of Convolutional Networks`_.
+    """ + r"""
+    Args:
+        params (iterable): iterable of parameters to optimize or dicts defining
+            parameter groups
+        lr (float): learning rate
+        momentum (float, optional): momentum factor (default: 0)
+        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
+        nesterov (bool, optional): enables Nesterov momentum (default: False)
+        trust_coefficient (float, optional): coefficient for computing LR (default: 0.001)
+        eps (float, optional): term added to the denominator to improve
+            numerical stability (default: 1e-8)
+        {maximize}
+        {differentiable}
+
+    .. _Large Batch Training of Convolutional Networks:
+        https://arxiv.org/abs/1708.03888
+
+    """.format(maximize=_maximize_doc, differentiable=_differentiable_doc) + r"""
+
+    Example:
+        >>> # xdoctest: +SKIP
+        >>> optimizer = torch.optim.LARS(model.parameters(), lr=0.1, momentum=0.9)
+        >>> optimizer.zero_grad()
+        >>> loss_fn(model(input), target).backward()
+        >>> optimizer.step()
+    """

 def lars(
     params: List[Tensor],
@@ -109,6 +142,9 @@ def lars(
     eps: float,
     maximize: bool,
 ):
+    r"""Functional API that performs LARS algorithm computation.
+    See :class:`~torch.optim.LARS` for details.
+    """
     if torch.jit.is_scripting():
         raise RuntimeError('torch.jit.script not supported with foreach optimizers')

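For readers unfamiliar with the update rule that `trust_coefficient` and `eps` feed into, below is a minimal sketch of the layer-wise trust-ratio step described in the paper the docstring links (arXiv:1708.03888). It is illustrative only: the names `lars_single_param`, `param`, `grad`, and `buf` are invented for this example, and the actual `step()` in this patch may differ in details such as Nesterov momentum, the `maximize` flag, and how zero norms are handled.

```python
import torch

def lars_single_param(param, grad, buf, *, lr, momentum=0.9, weight_decay=0.0,
                      trust_coefficient=0.001, eps=1e-8):
    # Layer-wise "trust ratio" from the LARS paper:
    #   local_lr = trust_coefficient * ||w|| / (||g|| + weight_decay * ||w|| + eps)
    p_norm = torch.norm(param)
    g_norm = torch.norm(grad)
    if p_norm > 0 and g_norm > 0:
        local_lr = trust_coefficient * p_norm / (g_norm + weight_decay * p_norm + eps)
    else:
        local_lr = 1.0
    # Rescale the (weight-decayed) gradient by the trust ratio, then apply
    # plain heavy-ball momentum and step with the global learning rate.
    d_p = (grad + weight_decay * param) * local_lr
    buf.mul_(momentum).add_(d_p)
    param.add_(buf, alpha=-lr)

# Example usage on a single tensor (hypothetical values):
w = torch.randn(10)
g = torch.randn(10)
momentum_buf = torch.zeros_like(w)
lars_single_param(w, g, momentum_buf, lr=0.1)
```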