xref: /aosp_15_r20/external/pytorch/torch/optim/adadelta.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators
2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, cast, Dict, List, Optional, Union
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerimport torch
6*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
7*da0073e9SAndroid Build Coastguard Worker
8*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
9*da0073e9SAndroid Build Coastguard Worker    _capturable_doc,
10*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
11*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
12*da0073e9SAndroid Build Coastguard Worker    _disable_dynamo_if_unsupported,
13*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
14*da0073e9SAndroid Build Coastguard Worker    _get_capturable_supported_devices,
15*da0073e9SAndroid Build Coastguard Worker    _get_scalar_dtype,
16*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
17*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
18*da0073e9SAndroid Build Coastguard Worker    _view_as_real,
19*da0073e9SAndroid Build Coastguard Worker    Optimizer,
20*da0073e9SAndroid Build Coastguard Worker    ParamsT,
21*da0073e9SAndroid Build Coastguard Worker)
22*da0073e9SAndroid Build Coastguard Worker
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker__all__ = ["Adadelta", "adadelta"]
25*da0073e9SAndroid Build Coastguard Worker
26*da0073e9SAndroid Build Coastguard Worker
class Adadelta(Optimizer):
    # NOTE: the public class docstring is attached after the class body
    # (``Adadelta.__doc__ = ...``) so it can interpolate the shared
    # ``*_doc`` argument fragments imported from .optimizer.
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1.0,
        rho: float = 0.9,
        eps: float = 1e-6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        *,
        capturable: bool = False,
        maximize: bool = False,
        differentiable: bool = False,
    ):
        # Validate hyperparameters up front so a misconfigured optimizer
        # fails at construction time rather than mid-training.
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= rho <= 1.0:
            raise ValueError(f"Invalid rho value: {rho}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            rho=rho,
            eps=eps,
            weight_decay=weight_decay,
            maximize=maximize,
            capturable=capturable,
            foreach=foreach,
            differentiable=differentiable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, filling in defaults for options that did
        not exist when the checkpoint was saved."""
        super().__setstate__(state)
        for group in self.param_groups:
            # Older checkpoints predate these options; default them so the
            # rest of the code can assume the keys exist.
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                # Older checkpoints stored ``step`` as a Python number;
                # convert it to a tensor.  When capturable, the step tensor
                # lives on the param's device (needed for CUDA-graph capture);
                # otherwise it stays a host scalar tensor.
                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
                    step_val = float(p_state["step"])
                    p_state["step"] = (
                        torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                        if group["capturable"]
                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
                    )

    def _init_group(
        self,
        group: Dict[str, Any],
        params_with_grad: List[Tensor],
        grads: List[Tensor],
        square_avgs: List[Tensor],
        acc_deltas: List[Tensor],
        state_steps: List[Tensor],
    ) -> bool:
        """Gather one param group's tensors into the given output lists,
        lazily creating per-parameter state on first use.

        Returns:
            True if any parameter in the group is complex.
        """
        has_complex = False
        p: Tensor
        for p in group["params"]:
            if p.grad is None:
                # Parameters without gradients are simply skipped this step.
                continue
            has_complex |= torch.is_complex(p)
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError("Adadelta does not support sparse gradients")
            grads.append(p.grad)

            state = self.state[p]

            # Lazy state initialization
            if len(state) == 0:
                # Keep ``step`` on the param's device only when capturable;
                # a host scalar is cheaper for the common (non-captured) case.
                state["step"] = (
                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
                    if group["capturable"]
                    else torch.zeros((), dtype=_get_scalar_dtype())
                )

                # Running average of squared gradients (v_t in the paper).
                state["square_avg"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )
                # Running average of squared parameter updates (u_t).
                state["acc_delta"] = torch.zeros_like(
                    p, memory_format=torch.preserve_format
                )

            square_avgs.append(state["square_avg"])
            acc_deltas.append(state["acc_delta"])
            state_steps.append(state["step"])

        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            square_avgs: List[Tensor] = []
            acc_deltas: List[Tensor] = []
            state_steps: List[Tensor] = []
            (
                lr,
                rho,
                eps,
                weight_decay,
                foreach,
                maximize,
                differentiable,
                capturable,
            ) = (
                group["lr"],
                group["rho"],
                group["eps"],
                group["weight_decay"],
                group["foreach"],
                group["maximize"],
                group["differentiable"],
                group["capturable"],
            )

            has_complex = self._init_group(
                group, params_with_grad, grads, square_avgs, acc_deltas, state_steps
            )

            # Dispatch to the functional API below, which selects the
            # single-tensor or foreach implementation.
            adadelta(
                params_with_grad,
                grads,
                square_avgs,
                acc_deltas,
                state_steps,
                lr=lr,
                rho=rho,
                eps=eps,
                weight_decay=weight_decay,
                foreach=foreach,
                maximize=maximize,
                differentiable=differentiable,
                capturable=capturable,
                has_complex=has_complex,
            )

        return loss
189*da0073e9SAndroid Build Coastguard Worker
190*da0073e9SAndroid Build Coastguard Worker
# The docstring is assigned after the class definition so the f-string can
# interpolate the shared ``*_doc`` argument fragments (kept in .optimizer so
# their wording stays consistent across all optimizers).
Adadelta.__doc__ = (
    r"""Implements Adadelta algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)},
                \: f(\theta) \text{ (objective)}, \: \rho \text{ (decay)},
                \: \lambda \text{ (weight decay)}                                                \\
            &\textbf{initialize} :  v_0  \leftarrow 0 \: \text{ (square avg)},
                \: u_0 \leftarrow 0 \: \text{ (accumulate variables)}                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm} v_t      \leftarrow v_{t-1} \rho + g^2_t (1 - \rho)                    \\
            &\hspace{5mm}\Delta x_t    \leftarrow   \frac{\sqrt{u_{t-1} +
                \epsilon }}{ \sqrt{v_t + \epsilon}  }g_t \hspace{21mm}                           \\
            &\hspace{5mm} u_t  \leftarrow   u_{t-1}  \rho +
                 \Delta x^2_t  (1 - \rho)                                                        \\
            &\hspace{5mm}\theta_t      \leftarrow   \theta_{t-1} - \gamma  \Delta x_t            \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `ADADELTA: An Adaptive Learning Rate Method`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        rho (float, optional): coefficient used for computing a running average
            of squared gradients (default: 0.9). A higher value of `rho` will
            result in a slower average, which can be helpful for preventing
            oscillations in the learning process.
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-6).
        lr (float, Tensor, optional): coefficient that scale delta before it is applied
            to the parameters (default: 1.0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_capturable_doc}
        {_maximize_doc}
        {_differentiable_doc}

    .. _ADADELTA\: An Adaptive Learning Rate Method:
        https://arxiv.org/abs/1212.5701

    """
)
243*da0073e9SAndroid Build Coastguard Worker
244*da0073e9SAndroid Build Coastguard Worker
245*da0073e9SAndroid Build Coastguard Workerdef _single_tensor_adadelta(
246*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
247*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
248*da0073e9SAndroid Build Coastguard Worker    square_avgs: List[Tensor],
249*da0073e9SAndroid Build Coastguard Worker    acc_deltas: List[Tensor],
250*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
251*da0073e9SAndroid Build Coastguard Worker    *,
252*da0073e9SAndroid Build Coastguard Worker    lr: float,
253*da0073e9SAndroid Build Coastguard Worker    rho: float,
254*da0073e9SAndroid Build Coastguard Worker    eps: float,
255*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
256*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
257*da0073e9SAndroid Build Coastguard Worker    differentiable: bool,
258*da0073e9SAndroid Build Coastguard Worker    capturable: bool,
259*da0073e9SAndroid Build Coastguard Worker    has_complex: bool,
260*da0073e9SAndroid Build Coastguard Worker):
261*da0073e9SAndroid Build Coastguard Worker    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
262*da0073e9SAndroid Build Coastguard Worker    if not torch._utils.is_compiling() and capturable:
263*da0073e9SAndroid Build Coastguard Worker        capturable_supported_devices = _get_capturable_supported_devices(
264*da0073e9SAndroid Build Coastguard Worker            supports_xla=False
265*da0073e9SAndroid Build Coastguard Worker        )
266*da0073e9SAndroid Build Coastguard Worker        assert all(
267*da0073e9SAndroid Build Coastguard Worker            p.device.type == step.device.type
268*da0073e9SAndroid Build Coastguard Worker            and p.device.type in capturable_supported_devices
269*da0073e9SAndroid Build Coastguard Worker            for p, step in zip(params, state_steps)
270*da0073e9SAndroid Build Coastguard Worker        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
271*da0073e9SAndroid Build Coastguard Worker
272*da0073e9SAndroid Build Coastguard Worker    for param, grad, square_avg, acc_delta, step in zip(
273*da0073e9SAndroid Build Coastguard Worker        params, grads, square_avgs, acc_deltas, state_steps
274*da0073e9SAndroid Build Coastguard Worker    ):
275*da0073e9SAndroid Build Coastguard Worker        step += 1
276*da0073e9SAndroid Build Coastguard Worker        grad = grad if not maximize else -grad
277*da0073e9SAndroid Build Coastguard Worker
278*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0:
279*da0073e9SAndroid Build Coastguard Worker            grad = grad.add(param, alpha=weight_decay)
280*da0073e9SAndroid Build Coastguard Worker
281*da0073e9SAndroid Build Coastguard Worker        if torch.is_complex(param):
282*da0073e9SAndroid Build Coastguard Worker            square_avg = torch.view_as_real(square_avg)
283*da0073e9SAndroid Build Coastguard Worker            acc_delta = torch.view_as_real(acc_delta)
284*da0073e9SAndroid Build Coastguard Worker            grad = torch.view_as_real(grad)
285*da0073e9SAndroid Build Coastguard Worker
286*da0073e9SAndroid Build Coastguard Worker        square_avg.mul_(rho).addcmul_(grad, grad, value=1 - rho)
287*da0073e9SAndroid Build Coastguard Worker        std = square_avg.add(eps).sqrt_()
288*da0073e9SAndroid Build Coastguard Worker        delta = acc_delta.add(eps).sqrt_()
289*da0073e9SAndroid Build Coastguard Worker        if differentiable:
290*da0073e9SAndroid Build Coastguard Worker            delta = delta.clone()
291*da0073e9SAndroid Build Coastguard Worker        delta.div_(std).mul_(grad)
292*da0073e9SAndroid Build Coastguard Worker        acc_delta.mul_(rho).addcmul_(delta, delta, value=1 - rho)
293*da0073e9SAndroid Build Coastguard Worker
294*da0073e9SAndroid Build Coastguard Worker        if torch.is_complex(param):
295*da0073e9SAndroid Build Coastguard Worker            delta = torch.view_as_complex(delta)
296*da0073e9SAndroid Build Coastguard Worker        param.add_(delta, alpha=-lr)
297*da0073e9SAndroid Build Coastguard Worker
298*da0073e9SAndroid Build Coastguard Worker
def _multi_tensor_adadelta(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    acc_deltas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    rho: float,
    eps: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Horizontally fused Adadelta update using ``torch._foreach_*`` kernels.

    Tensors are grouped by (device, dtype) and each group is updated with
    batched foreach ops; params and state are mutated in place.  Not usable
    with ``differentiable=True`` (foreach ops don't support autograd).
    """
    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

    if len(params) == 0:
        return

    # foreach kernels require all operands in one call to share device/dtype,
    # so partition the inputs first.
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, square_avgs, acc_deltas, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_square_avgs_,
        device_acc_deltas_,
        device_state_steps_,
    ), _ in grouped_tensors.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_square_avgs = cast(List[Tensor], device_square_avgs_)
        device_acc_deltas = cast(List[Tensor], device_acc_deltas_)
        device_state_steps = cast(List[Tensor], device_state_steps_)
        if has_complex:
            # Replace complex tensors with their real views in place, so the
            # foreach math below operates component-wise.
            _view_as_real(
                device_params, device_grads, device_square_avgs, device_acc_deltas
            )

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if maximize:
            # Negate into fresh tensors; the caller's grads are left untouched.
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        # v_t = rho * v_{t-1} + (1 - rho) * g_t^2
        torch._foreach_mul_(device_square_avgs, rho)
        torch._foreach_addcmul_(
            device_square_avgs, device_grads, device_grads, value=1 - rho
        )

        # std = sqrt(v_t + eps)
        std = torch._foreach_add(device_square_avgs, eps)
        torch._foreach_sqrt_(std)

        # deltas = sqrt(u_{t-1} + eps) / std * g_t
        deltas = torch._foreach_add(device_acc_deltas, eps)
        torch._foreach_sqrt_(deltas)
        torch._foreach_div_(deltas, std)
        torch._foreach_mul_(deltas, device_grads)

        # u_t = rho * u_{t-1} + (1 - rho) * deltas^2
        torch._foreach_mul_(device_acc_deltas, rho)
        torch._foreach_addcmul_(device_acc_deltas, deltas, deltas, value=1 - rho)

        # If LR is a tensor, the else branch will internally call item()
        # which will cause silent incorrectness if we are capturing
        if capturable and isinstance(lr, torch.Tensor):
            torch._foreach_mul_(deltas, -lr)
            torch._foreach_add_(device_params, deltas)
        else:
            torch._foreach_add_(device_params, deltas, alpha=-lr)
397*da0073e9SAndroid Build Coastguard Worker
398*da0073e9SAndroid Build Coastguard Worker
399*da0073e9SAndroid Build Coastguard Worker@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adadelta)
400*da0073e9SAndroid Build Coastguard Workerdef adadelta(
401*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
402*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
403*da0073e9SAndroid Build Coastguard Worker    square_avgs: List[Tensor],
404*da0073e9SAndroid Build Coastguard Worker    acc_deltas: List[Tensor],
405*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
406*da0073e9SAndroid Build Coastguard Worker    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
407*da0073e9SAndroid Build Coastguard Worker    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
408*da0073e9SAndroid Build Coastguard Worker    capturable: bool = False,
409*da0073e9SAndroid Build Coastguard Worker    foreach: Optional[bool] = None,
410*da0073e9SAndroid Build Coastguard Worker    differentiable: bool = False,
411*da0073e9SAndroid Build Coastguard Worker    has_complex: bool = False,
412*da0073e9SAndroid Build Coastguard Worker    *,
413*da0073e9SAndroid Build Coastguard Worker    lr: float,
414*da0073e9SAndroid Build Coastguard Worker    rho: float,
415*da0073e9SAndroid Build Coastguard Worker    eps: float,
416*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
417*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
418*da0073e9SAndroid Build Coastguard Worker):
419*da0073e9SAndroid Build Coastguard Worker    r"""Functional API that performs Adadelta algorithm computation.
420*da0073e9SAndroid Build Coastguard Worker
421*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.optim.Adadelta` for details.
422*da0073e9SAndroid Build Coastguard Worker    """
423*da0073e9SAndroid Build Coastguard Worker
424*da0073e9SAndroid Build Coastguard Worker    # this check is slow during compilation, so we skip it
425*da0073e9SAndroid Build Coastguard Worker    # if it's strictly needed we can add this check back in dynamo
426*da0073e9SAndroid Build Coastguard Worker    if not torch._utils.is_compiling() and not all(
427*da0073e9SAndroid Build Coastguard Worker        isinstance(t, torch.Tensor) for t in state_steps
428*da0073e9SAndroid Build Coastguard Worker    ):
429*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
430*da0073e9SAndroid Build Coastguard Worker            "API has changed, `state_steps` argument must contain a list of singleton tensors"
431*da0073e9SAndroid Build Coastguard Worker        )
432*da0073e9SAndroid Build Coastguard Worker
433*da0073e9SAndroid Build Coastguard Worker    # We still respect when the user inputs False for foreach.
434*da0073e9SAndroid Build Coastguard Worker    if foreach is None:
435*da0073e9SAndroid Build Coastguard Worker        _, foreach = _default_to_fused_or_foreach(
436*da0073e9SAndroid Build Coastguard Worker            params, differentiable, use_fused=False
437*da0073e9SAndroid Build Coastguard Worker        )
438*da0073e9SAndroid Build Coastguard Worker
439*da0073e9SAndroid Build Coastguard Worker    if foreach and torch.jit.is_scripting():
440*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
441*da0073e9SAndroid Build Coastguard Worker
442*da0073e9SAndroid Build Coastguard Worker    if foreach and not torch.jit.is_scripting():
443*da0073e9SAndroid Build Coastguard Worker        func = _multi_tensor_adadelta
444*da0073e9SAndroid Build Coastguard Worker    else:
445*da0073e9SAndroid Build Coastguard Worker        func = _single_tensor_adadelta
446*da0073e9SAndroid Build Coastguard Worker
447*da0073e9SAndroid Build Coastguard Worker    func(
448*da0073e9SAndroid Build Coastguard Worker        params,
449*da0073e9SAndroid Build Coastguard Worker        grads,
450*da0073e9SAndroid Build Coastguard Worker        square_avgs,
451*da0073e9SAndroid Build Coastguard Worker        acc_deltas,
452*da0073e9SAndroid Build Coastguard Worker        state_steps,
453*da0073e9SAndroid Build Coastguard Worker        lr=lr,
454*da0073e9SAndroid Build Coastguard Worker        rho=rho,
455*da0073e9SAndroid Build Coastguard Worker        eps=eps,
456*da0073e9SAndroid Build Coastguard Worker        weight_decay=weight_decay,
457*da0073e9SAndroid Build Coastguard Worker        maximize=maximize,
458*da0073e9SAndroid Build Coastguard Worker        differentiable=differentiable,
459*da0073e9SAndroid Build Coastguard Worker        capturable=capturable,
460*da0073e9SAndroid Build Coastguard Worker        has_complex=has_complex,
461*da0073e9SAndroid Build Coastguard Worker    )
462