xref: /aosp_15_r20/external/pytorch/torch/optim/adagrad.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Union
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerimport torch
5*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
8*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
9*da0073e9SAndroid Build Coastguard Worker    _device_dtype_check_for_fused,
10*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
11*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
12*da0073e9SAndroid Build Coastguard Worker    _get_scalar_dtype,
13*da0073e9SAndroid Build Coastguard Worker    _get_value,
14*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
15*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
16*da0073e9SAndroid Build Coastguard Worker    _view_as_real,
17*da0073e9SAndroid Build Coastguard Worker    Optimizer,
18*da0073e9SAndroid Build Coastguard Worker    ParamsT,
19*da0073e9SAndroid Build Coastguard Worker)
20*da0073e9SAndroid Build Coastguard Worker
21*da0073e9SAndroid Build Coastguard Worker
# Public API of this module: the Adagrad optimizer class and its functional form.
__all__ = ["Adagrad", "adagrad"]
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
class Adagrad(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lr_decay: float = 0,
        weight_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        """Validate hyperparameters, register param groups and create per-param state.

        Raises:
            ValueError: if ``lr`` is a tensor that does not have exactly one
                element, or if any of ``lr``, ``lr_decay``, ``weight_decay``,
                ``initial_accumulator_value`` or ``eps`` fails ``0.0 <= value``.
            RuntimeError: if ``fused`` is requested together with
                ``differentiable`` or with ``foreach=True``.
        """
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= lr_decay:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            eps=eps,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
            # Device/dtype support of the params is validated lazily, once, on the
            # first step() (see _init_group and step).
            self._need_device_dtype_check_for_fused = True

        # Adagrad creates its state eagerly rather than on the first step:
        #   * "step": a 0-dim counter. For fused groups it lives on the param's
        #     device with the fused scalar dtype; otherwise it is created without
        #     an explicit device (default device).
        #   * "sum": the squared-gradient accumulator, filled with
        #     initial_accumulator_value (as v + vj for complex params).
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                init_value = (
                    complex(initial_accumulator_value, initial_accumulator_value)
                    if torch.is_complex(p)
                    else initial_accumulator_value
                )
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        """Restore pickled state, filling in group options and tensor step counters.

        Groups missing ``foreach``/``maximize``/``differentiable``/``fused`` get
        their defaults. If the stored ``step`` values are not tensors, every
        state's ``step`` is converted to a scalar tensor.
        """
        super().__setstate__(state)
        # Pre-bind "fused" so it is defined even when there are no param groups
        # (silences the mypy "Name 'fused' may be undefined" error). After the
        # loop it holds the last group's setting.
        fused = None
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)

        state_values = list(self.state.values())
        # Only the first state entry is inspected to decide whether conversion is needed.
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(
                    float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused)
                )

    def share_memory(self):
        """Move every param's ``sum`` accumulator into shared memory via ``share_memory_()``."""
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
        """Collect the tensors of ``group`` whose params have a gradient.

        For each param with a non-None ``.grad``, appends the param, its grad,
        and its ``sum``/``step`` state to the corresponding output lists.

        While the one-time fused validation is pending, EVERY such param of a
        fused group is passed to ``_device_dtype_check_for_fused``. The pending
        flag is cleared by :meth:`step` after all groups have been processed,
        so params in later groups are validated too.

        Returns:
            tuple[bool, bool]: ``(has_sparse_grad, has_complex)`` for the
            collected params.
        """
        has_sparse_grad, has_complex = False, False
        check_fused = group["fused"] and getattr(
            self, "_need_device_dtype_check_for_fused", True
        )
        for p in group["params"]:
            if p.grad is None:
                continue
            if check_fused:
                _device_dtype_check_for_fused(p, cuda_unsupported=True)
            has_sparse_grad |= p.grad.is_sparse
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            grads.append(p.grad)
            state = self.state[p]
            state_sums.append(state["sum"])
            state_steps.append(state["step"])

        return has_sparse_grad, has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None

        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            state_sums: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_sparse_grad, has_complex = self._init_group(
                group, params_with_grad, grads, state_sums, state_steps
            )

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        # Every param with a grad in every fused group has now passed the fused
        # device/dtype check, so later steps can skip it. This is done here
        # rather than inside _init_group, so that a multi-group optimizer
        # validates all of its groups, not only the first param of the first one.
        if getattr(self, "_need_device_dtype_check_for_fused", True) and any(
            g["fused"] for g in self.param_groups
        ):
            self._need_device_dtype_check_for_fused = False

        return loss
188*da0073e9SAndroid Build Coastguard Worker
189*da0073e9SAndroid Build Coastguard Worker
Adagrad.__doc__ = (
    r"""Implements the Adagrad algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{12mm}    \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)},
                \: \epsilon \text{ (epsilon)}                                                    \\
            &\textbf{initialize} :  state\_sum_0 \leftarrow \tau                          \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm} \tilde{\gamma}    \leftarrow \gamma / (1 +(t-1) \eta)                  \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
            &\hspace{5mm}state\_sum_t  \leftarrow  state\_sum_{t-1} + g^2_t                      \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon}            \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        initial_accumulator_value (float, optional): initial value of the
            sum of squares of gradients (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        fused (bool, optional): whether the fused implementation (CPU only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None). Please note that the fused implementation does not
            support sparse or complex gradients.

    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html

    """
)
240*da0073e9SAndroid Build Coastguard Worker
241*da0073e9SAndroid Build Coastguard Worker
def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.

    Dispatches the update to exactly one implementation:

    * ``_fused_adagrad`` when ``fused`` is truthy,
    * ``_multi_tensor_adagrad`` (foreach) when ``foreach`` is truthy,
    * ``_single_tensor_adagrad`` otherwise.

    When neither ``fused`` nor ``foreach`` is specified, ``foreach`` is chosen by
    ``_default_to_fused_or_foreach`` (called with ``use_fused=False``, so the
    fused path is never selected by default). Any option still ``None`` after
    that becomes ``False``.

    Args:
        params: parameters to update.
        grads: gradients, position-aligned with ``params``.
        state_sums: accumulators of squared gradients, aligned with ``params``.
        state_steps: step counters, aligned with ``params``; every entry must
            be a tensor.
        fused, foreach: implementation selection (see above).
        grad_scale, found_inf: forwarded to the chosen implementation; the
            single-tensor and foreach implementations assert both are ``None``.
        lr, weight_decay, lr_decay, eps, maximize, differentiable,
        has_sparse_grad, has_complex: forwarded unchanged to the chosen
            implementation.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-tensor entry, or if the
            foreach or fused path is requested while TorchScript scripting.
    """
    # Older callers passed plain numbers for the step counts; reject them explicitly.
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    # Whatever is still unspecified falls back to the single-tensor path.
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # Only the single-tensor implementation is usable under TorchScript.
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # NOTE(review): the `not torch.jit.is_scripting()` guards are redundant at runtime
    # given the raises above; they presumably let TorchScript drop the fused/foreach
    # branches when compiling this function — keep them when editing.
    if fused and not torch.jit.is_scripting():
        func = _fused_adagrad
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )
314*da0073e9SAndroid Build Coastguard Worker
315*da0073e9SAndroid Build Coastguard Worker
def _make_sparse(grad, grad_indices, values):
    """Return a sparse COO tensor with ``grad``'s shape, built from ``grad_indices`` and ``values``.

    Lets per-entry quantities derived from a sparse gradient (for example its
    squared values) be turned back into a sparse tensor of matching size.
    """
    return torch.sparse_coo_tensor(grad_indices, values, grad.size())
319*da0073e9SAndroid Build Coastguard Worker
320*da0073e9SAndroid Build Coastguard Worker
def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    r"""Apply one Adagrad update to each param, one tensor at a time.

    For every ``(param, grad, state_sum, step_t)``:

    * ``step_t`` is incremented in place;
    * the gradient is negated when ``maximize`` and, if ``weight_decay != 0``,
      ``weight_decay * param`` is added to it;
    * ``state_sum`` accumulates the squared gradient in place;
    * ``param`` is updated in place by ``-clr * grad / (sqrt(state_sum) + eps)``
      with ``clr = lr / (1 + (step - 1) * lr_decay)``.

    Sparse gradients are coalesced and only the entries they touch are updated.
    ``weight_decay != 0`` raises ``RuntimeError`` for them. Dense complex params
    are updated through real views of param, grad and state_sum.

    ``grad_scale`` and ``found_inf`` must be ``None``. ``has_sparse_grad`` and
    ``has_complex`` are not read here; they are accepted because ``adagrad()``
    passes the same keyword arguments to every implementation.
    """
    assert grad_scale is None and found_inf is None
    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        # Learning rate decayed by the number of steps already taken.
        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            # Entries of the updated accumulator at the gradient's indices only.
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            if torch.is_complex(param):
                # In-place ops on these real views write through to the complex
                # tensors, so no conversion back is needed afterwards.
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                # Out-of-place add keeps the autograd graph intact.
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
380*da0073e9SAndroid Build Coastguard Worker
381*da0073e9SAndroid Build Coastguard Worker
382*da0073e9SAndroid Build Coastguard Workerdef _multi_tensor_adagrad(
383*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
384*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
385*da0073e9SAndroid Build Coastguard Worker    state_sums: List[Tensor],
386*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
387*da0073e9SAndroid Build Coastguard Worker    grad_scale: Optional[Tensor],
388*da0073e9SAndroid Build Coastguard Worker    found_inf: Optional[Tensor],
389*da0073e9SAndroid Build Coastguard Worker    *,
390*da0073e9SAndroid Build Coastguard Worker    lr: float,
391*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
392*da0073e9SAndroid Build Coastguard Worker    lr_decay: float,
393*da0073e9SAndroid Build Coastguard Worker    eps: float,
394*da0073e9SAndroid Build Coastguard Worker    has_sparse_grad: bool,
395*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
396*da0073e9SAndroid Build Coastguard Worker    differentiable: bool,
397*da0073e9SAndroid Build Coastguard Worker    has_complex: bool,
398*da0073e9SAndroid Build Coastguard Worker):
399*da0073e9SAndroid Build Coastguard Worker    assert not differentiable, "_foreach ops don't support autograd"
400*da0073e9SAndroid Build Coastguard Worker    assert grad_scale is None and found_inf is None
401*da0073e9SAndroid Build Coastguard Worker
402*da0073e9SAndroid Build Coastguard Worker    # Foreach functions will throw errors if given empty lists
403*da0073e9SAndroid Build Coastguard Worker    if len(params) == 0:
404*da0073e9SAndroid Build Coastguard Worker        return
405*da0073e9SAndroid Build Coastguard Worker
406*da0073e9SAndroid Build Coastguard Worker    grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype(
407*da0073e9SAndroid Build Coastguard Worker        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
408*da0073e9SAndroid Build Coastguard Worker    )
409*da0073e9SAndroid Build Coastguard Worker    for (
410*da0073e9SAndroid Build Coastguard Worker        device_params_,
411*da0073e9SAndroid Build Coastguard Worker        device_grads_,
412*da0073e9SAndroid Build Coastguard Worker        device_state_sums_,
413*da0073e9SAndroid Build Coastguard Worker        device_state_steps_,
414*da0073e9SAndroid Build Coastguard Worker    ), _ in grouped_tensorlists.values():
415*da0073e9SAndroid Build Coastguard Worker        device_params = cast(List[Tensor], device_params_)
416*da0073e9SAndroid Build Coastguard Worker        device_grads = cast(List[Tensor], device_grads_)
417*da0073e9SAndroid Build Coastguard Worker        device_state_sums = cast(List[Tensor], device_state_sums_)
418*da0073e9SAndroid Build Coastguard Worker        device_state_steps = cast(List[Tensor], device_state_steps_)
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        device_has_sparse_grad = has_sparse_grad and any(
421*da0073e9SAndroid Build Coastguard Worker            grad.is_sparse for grad in device_grads
422*da0073e9SAndroid Build Coastguard Worker        )
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker        if device_has_sparse_grad:
425*da0073e9SAndroid Build Coastguard Worker            _single_tensor_adagrad(
426*da0073e9SAndroid Build Coastguard Worker                device_params,
427*da0073e9SAndroid Build Coastguard Worker                device_grads,
428*da0073e9SAndroid Build Coastguard Worker                device_state_sums,
429*da0073e9SAndroid Build Coastguard Worker                device_state_steps,
430*da0073e9SAndroid Build Coastguard Worker                lr=lr,
431*da0073e9SAndroid Build Coastguard Worker                weight_decay=weight_decay,
432*da0073e9SAndroid Build Coastguard Worker                lr_decay=lr_decay,
433*da0073e9SAndroid Build Coastguard Worker                eps=eps,
434*da0073e9SAndroid Build Coastguard Worker                has_sparse_grad=True,
435*da0073e9SAndroid Build Coastguard Worker                maximize=maximize,
436*da0073e9SAndroid Build Coastguard Worker                differentiable=differentiable,
437*da0073e9SAndroid Build Coastguard Worker                has_complex=has_complex,
438*da0073e9SAndroid Build Coastguard Worker                grad_scale=grad_scale,
439*da0073e9SAndroid Build Coastguard Worker                found_inf=found_inf,
440*da0073e9SAndroid Build Coastguard Worker            )
441*da0073e9SAndroid Build Coastguard Worker            continue
442*da0073e9SAndroid Build Coastguard Worker
443*da0073e9SAndroid Build Coastguard Worker        # Handle complex parameters
444*da0073e9SAndroid Build Coastguard Worker        if has_complex:
445*da0073e9SAndroid Build Coastguard Worker            _view_as_real(device_params, device_grads, device_state_sums)
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker        if maximize:
448*da0073e9SAndroid Build Coastguard Worker            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
449*da0073e9SAndroid Build Coastguard Worker
450*da0073e9SAndroid Build Coastguard Worker        # Update steps
451*da0073e9SAndroid Build Coastguard Worker        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
452*da0073e9SAndroid Build Coastguard Worker        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
453*da0073e9SAndroid Build Coastguard Worker        # wrapped it once now. The alpha is required to assure we go to the right overload.
454*da0073e9SAndroid Build Coastguard Worker        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
455*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(
456*da0073e9SAndroid Build Coastguard Worker                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
457*da0073e9SAndroid Build Coastguard Worker            )
458*da0073e9SAndroid Build Coastguard Worker        else:
459*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(device_state_steps, 1)
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0:
462*da0073e9SAndroid Build Coastguard Worker            # Re-use the intermediate memory (device_grads) already allocated for maximize
463*da0073e9SAndroid Build Coastguard Worker            if maximize:
464*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
465*da0073e9SAndroid Build Coastguard Worker            else:
466*da0073e9SAndroid Build Coastguard Worker                device_grads = torch._foreach_add(  # type: ignore[assignment]
467*da0073e9SAndroid Build Coastguard Worker                    device_grads, device_params, alpha=weight_decay
468*da0073e9SAndroid Build Coastguard Worker                )
469*da0073e9SAndroid Build Coastguard Worker
470*da0073e9SAndroid Build Coastguard Worker        minus_clr = [
471*da0073e9SAndroid Build Coastguard Worker            -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps
472*da0073e9SAndroid Build Coastguard Worker        ]
473*da0073e9SAndroid Build Coastguard Worker
474*da0073e9SAndroid Build Coastguard Worker        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)
475*da0073e9SAndroid Build Coastguard Worker
476*da0073e9SAndroid Build Coastguard Worker        std = torch._foreach_sqrt(device_state_sums)
477*da0073e9SAndroid Build Coastguard Worker        torch._foreach_add_(std, eps)
478*da0073e9SAndroid Build Coastguard Worker
479*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0 or maximize:
480*da0073e9SAndroid Build Coastguard Worker            # Again, re-use the intermediate memory (device_grads) already allocated
481*da0073e9SAndroid Build Coastguard Worker            torch._foreach_mul_(device_grads, minus_clr)
482*da0073e9SAndroid Build Coastguard Worker            numerator = device_grads
483*da0073e9SAndroid Build Coastguard Worker        else:
484*da0073e9SAndroid Build Coastguard Worker            numerator = torch._foreach_mul(device_grads, minus_clr)  # type: ignore[assignment]
485*da0073e9SAndroid Build Coastguard Worker
486*da0073e9SAndroid Build Coastguard Worker        torch._foreach_addcdiv_(device_params, numerator, std)
487*da0073e9SAndroid Build Coastguard Worker
488*da0073e9SAndroid Build Coastguard Worker
489*da0073e9SAndroid Build Coastguard Workerdef _fused_adagrad(
490*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
491*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
492*da0073e9SAndroid Build Coastguard Worker    state_sums: List[Tensor],
493*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
494*da0073e9SAndroid Build Coastguard Worker    grad_scale: Optional[Tensor],
495*da0073e9SAndroid Build Coastguard Worker    found_inf: Optional[Tensor],
496*da0073e9SAndroid Build Coastguard Worker    *,
497*da0073e9SAndroid Build Coastguard Worker    lr: float,
498*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
499*da0073e9SAndroid Build Coastguard Worker    lr_decay: float,
500*da0073e9SAndroid Build Coastguard Worker    eps: float,
501*da0073e9SAndroid Build Coastguard Worker    has_sparse_grad: bool,
502*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
503*da0073e9SAndroid Build Coastguard Worker    differentiable: bool,
504*da0073e9SAndroid Build Coastguard Worker    has_complex: bool,
505*da0073e9SAndroid Build Coastguard Worker) -> None:
506*da0073e9SAndroid Build Coastguard Worker    if not params:
507*da0073e9SAndroid Build Coastguard Worker        return
508*da0073e9SAndroid Build Coastguard Worker    if has_sparse_grad or has_complex:
509*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("`fused` does not support sparse grad or complex param")
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker    if differentiable:
512*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
513*da0073e9SAndroid Build Coastguard Worker            "adagrad with fused=True does not support differentiable=True"
514*da0073e9SAndroid Build Coastguard Worker        )
515*da0073e9SAndroid Build Coastguard Worker
516*da0073e9SAndroid Build Coastguard Worker    grad_scale_dict = (
517*da0073e9SAndroid Build Coastguard Worker        {grad_scale.device: grad_scale} if grad_scale is not None else None
518*da0073e9SAndroid Build Coastguard Worker    )
519*da0073e9SAndroid Build Coastguard Worker    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None
520*da0073e9SAndroid Build Coastguard Worker
521*da0073e9SAndroid Build Coastguard Worker    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
522*da0073e9SAndroid Build Coastguard Worker        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
523*da0073e9SAndroid Build Coastguard Worker    )
524*da0073e9SAndroid Build Coastguard Worker    for (device, _), (
525*da0073e9SAndroid Build Coastguard Worker        (
526*da0073e9SAndroid Build Coastguard Worker            device_params_,
527*da0073e9SAndroid Build Coastguard Worker            device_grads_,
528*da0073e9SAndroid Build Coastguard Worker            device_state_sums_,
529*da0073e9SAndroid Build Coastguard Worker            device_state_steps_,
530*da0073e9SAndroid Build Coastguard Worker        ),
531*da0073e9SAndroid Build Coastguard Worker        _,
532*da0073e9SAndroid Build Coastguard Worker    ) in grouped_tensors.items():
533*da0073e9SAndroid Build Coastguard Worker        device_params = cast(List[Tensor], device_params_)
534*da0073e9SAndroid Build Coastguard Worker        device_grads = cast(List[Tensor], device_grads_)
535*da0073e9SAndroid Build Coastguard Worker        device_state_sums = cast(List[Tensor], device_state_sums_)
536*da0073e9SAndroid Build Coastguard Worker        device_state_steps = cast(List[Tensor], device_state_steps_)
537*da0073e9SAndroid Build Coastguard Worker
538*da0073e9SAndroid Build Coastguard Worker        device_grad_scale, device_found_inf = None, None
539*da0073e9SAndroid Build Coastguard Worker        if grad_scale is not None and grad_scale_dict is not None:
540*da0073e9SAndroid Build Coastguard Worker            if device not in grad_scale_dict:
541*da0073e9SAndroid Build Coastguard Worker                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
542*da0073e9SAndroid Build Coastguard Worker            device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
543*da0073e9SAndroid Build Coastguard Worker        if found_inf is not None and found_inf_dict is not None:
544*da0073e9SAndroid Build Coastguard Worker            if found_inf not in found_inf_dict:
545*da0073e9SAndroid Build Coastguard Worker                found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
546*da0073e9SAndroid Build Coastguard Worker            device_found_inf = found_inf_dict[device]  # type: ignore[index]
547*da0073e9SAndroid Build Coastguard Worker        torch._foreach_add_(device_state_steps, 1)
548*da0073e9SAndroid Build Coastguard Worker        torch._fused_adagrad_(
549*da0073e9SAndroid Build Coastguard Worker            device_params,
550*da0073e9SAndroid Build Coastguard Worker            device_grads,
551*da0073e9SAndroid Build Coastguard Worker            device_state_sums,
552*da0073e9SAndroid Build Coastguard Worker            device_state_steps,
553*da0073e9SAndroid Build Coastguard Worker            lr=lr,
554*da0073e9SAndroid Build Coastguard Worker            lr_decay=lr_decay,
555*da0073e9SAndroid Build Coastguard Worker            weight_decay=weight_decay,
556*da0073e9SAndroid Build Coastguard Worker            eps=eps,
557*da0073e9SAndroid Build Coastguard Worker            maximize=maximize,
558*da0073e9SAndroid Build Coastguard Worker            grad_scale=device_grad_scale,
559*da0073e9SAndroid Build Coastguard Worker            found_inf=device_found_inf,
560*da0073e9SAndroid Build Coastguard Worker        )
561*da0073e9SAndroid Build Coastguard Worker        if device_found_inf is not None:
562*da0073e9SAndroid Build Coastguard Worker            torch._foreach_sub_(
563*da0073e9SAndroid Build Coastguard Worker                device_state_steps, [device_found_inf] * len(device_state_steps)
564*da0073e9SAndroid Build Coastguard Worker            )
565