xref: /aosp_15_r20/external/pytorch/torch/optim/rmsprop.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators
2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
3*da0073e9SAndroid Build Coastguard Workerr"""Implementation for the RMSprop algorithm."""
4*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Union
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerimport torch
7*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor
8*da0073e9SAndroid Build Coastguard Worker
9*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import (
10*da0073e9SAndroid Build Coastguard Worker    _capturable_doc,
11*da0073e9SAndroid Build Coastguard Worker    _default_to_fused_or_foreach,
12*da0073e9SAndroid Build Coastguard Worker    _differentiable_doc,
13*da0073e9SAndroid Build Coastguard Worker    _disable_dynamo_if_unsupported,
14*da0073e9SAndroid Build Coastguard Worker    _foreach_doc,
15*da0073e9SAndroid Build Coastguard Worker    _get_capturable_supported_devices,
16*da0073e9SAndroid Build Coastguard Worker    _get_scalar_dtype,
17*da0073e9SAndroid Build Coastguard Worker    _maximize_doc,
18*da0073e9SAndroid Build Coastguard Worker    _use_grad_for_differentiable,
19*da0073e9SAndroid Build Coastguard Worker    _view_as_real,
20*da0073e9SAndroid Build Coastguard Worker    Optimizer,
21*da0073e9SAndroid Build Coastguard Worker    ParamsT,
22*da0073e9SAndroid Build Coastguard Worker)
23*da0073e9SAndroid Build Coastguard Worker
24*da0073e9SAndroid Build Coastguard Worker
25*da0073e9SAndroid Build Coastguard Worker__all__ = ["RMSprop", "rmsprop"]
26*da0073e9SAndroid Build Coastguard Worker
27*da0073e9SAndroid Build Coastguard Worker
28*da0073e9SAndroid Build Coastguard Workerclass RMSprop(Optimizer):  # noqa: D101
29*da0073e9SAndroid Build Coastguard Worker    def __init__(
30*da0073e9SAndroid Build Coastguard Worker        self,
31*da0073e9SAndroid Build Coastguard Worker        params: ParamsT,
32*da0073e9SAndroid Build Coastguard Worker        lr: Union[float, Tensor] = 1e-2,
33*da0073e9SAndroid Build Coastguard Worker        alpha: float = 0.99,
34*da0073e9SAndroid Build Coastguard Worker        eps: float = 1e-8,
35*da0073e9SAndroid Build Coastguard Worker        weight_decay: float = 0,
36*da0073e9SAndroid Build Coastguard Worker        momentum: float = 0,
37*da0073e9SAndroid Build Coastguard Worker        centered=False,
38*da0073e9SAndroid Build Coastguard Worker        capturable=False,
39*da0073e9SAndroid Build Coastguard Worker        foreach: Optional[bool] = None,
40*da0073e9SAndroid Build Coastguard Worker        maximize: bool = False,
41*da0073e9SAndroid Build Coastguard Worker        differentiable: bool = False,
42*da0073e9SAndroid Build Coastguard Worker    ):  # noqa: D107
43*da0073e9SAndroid Build Coastguard Worker        if isinstance(lr, Tensor) and lr.numel() != 1:
44*da0073e9SAndroid Build Coastguard Worker            raise ValueError("Tensor lr must be 1-element")
45*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= lr:
46*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid learning rate: {lr}")
47*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= eps:
48*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid epsilon value: {eps}")
49*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= momentum:
50*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid momentum value: {momentum}")
51*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= weight_decay:
52*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
53*da0073e9SAndroid Build Coastguard Worker        if not 0.0 <= alpha:
54*da0073e9SAndroid Build Coastguard Worker            raise ValueError(f"Invalid alpha value: {alpha}")
55*da0073e9SAndroid Build Coastguard Worker
56*da0073e9SAndroid Build Coastguard Worker        defaults = dict(
57*da0073e9SAndroid Build Coastguard Worker            lr=lr,
58*da0073e9SAndroid Build Coastguard Worker            momentum=momentum,
59*da0073e9SAndroid Build Coastguard Worker            alpha=alpha,
60*da0073e9SAndroid Build Coastguard Worker            eps=eps,
61*da0073e9SAndroid Build Coastguard Worker            centered=centered,
62*da0073e9SAndroid Build Coastguard Worker            weight_decay=weight_decay,
63*da0073e9SAndroid Build Coastguard Worker            capturable=capturable,
64*da0073e9SAndroid Build Coastguard Worker            foreach=foreach,
65*da0073e9SAndroid Build Coastguard Worker            maximize=maximize,
66*da0073e9SAndroid Build Coastguard Worker            differentiable=differentiable,
67*da0073e9SAndroid Build Coastguard Worker        )
68*da0073e9SAndroid Build Coastguard Worker        super().__init__(params, defaults)
69*da0073e9SAndroid Build Coastguard Worker
70*da0073e9SAndroid Build Coastguard Worker    def __setstate__(self, state):  # noqa: D105
71*da0073e9SAndroid Build Coastguard Worker        super().__setstate__(state)
72*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
73*da0073e9SAndroid Build Coastguard Worker            group.setdefault("momentum", 0)
74*da0073e9SAndroid Build Coastguard Worker            group.setdefault("centered", False)
75*da0073e9SAndroid Build Coastguard Worker            group.setdefault("foreach", None)
76*da0073e9SAndroid Build Coastguard Worker            group.setdefault("maximize", False)
77*da0073e9SAndroid Build Coastguard Worker            group.setdefault("differentiable", False)
78*da0073e9SAndroid Build Coastguard Worker            group.setdefault("capturable", False)
79*da0073e9SAndroid Build Coastguard Worker            for p in group["params"]:
80*da0073e9SAndroid Build Coastguard Worker                p_state = self.state.get(p, [])
81*da0073e9SAndroid Build Coastguard Worker                if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
82*da0073e9SAndroid Build Coastguard Worker                    step_val = float(p_state["step"])
83*da0073e9SAndroid Build Coastguard Worker                    p_state["step"] = (
84*da0073e9SAndroid Build Coastguard Worker                        torch.tensor(
85*da0073e9SAndroid Build Coastguard Worker                            step_val, dtype=_get_scalar_dtype(), device=p.device
86*da0073e9SAndroid Build Coastguard Worker                        )
87*da0073e9SAndroid Build Coastguard Worker                        if group["capturable"]
88*da0073e9SAndroid Build Coastguard Worker                        else torch.tensor(step_val, dtype=_get_scalar_dtype())
89*da0073e9SAndroid Build Coastguard Worker                    )
90*da0073e9SAndroid Build Coastguard Worker
91*da0073e9SAndroid Build Coastguard Worker    def _init_group(
92*da0073e9SAndroid Build Coastguard Worker        self,
93*da0073e9SAndroid Build Coastguard Worker        group,
94*da0073e9SAndroid Build Coastguard Worker        params_with_grad,
95*da0073e9SAndroid Build Coastguard Worker        grads,
96*da0073e9SAndroid Build Coastguard Worker        square_avgs,
97*da0073e9SAndroid Build Coastguard Worker        momentum_buffer_list,
98*da0073e9SAndroid Build Coastguard Worker        grad_avgs,
99*da0073e9SAndroid Build Coastguard Worker        state_steps,
100*da0073e9SAndroid Build Coastguard Worker    ):
101*da0073e9SAndroid Build Coastguard Worker        has_complex = False
102*da0073e9SAndroid Build Coastguard Worker        for p in group["params"]:
103*da0073e9SAndroid Build Coastguard Worker            if p.grad is None:
104*da0073e9SAndroid Build Coastguard Worker                continue
105*da0073e9SAndroid Build Coastguard Worker            has_complex |= torch.is_complex(p)
106*da0073e9SAndroid Build Coastguard Worker            params_with_grad.append(p)
107*da0073e9SAndroid Build Coastguard Worker
108*da0073e9SAndroid Build Coastguard Worker            if p.grad.is_sparse:
109*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError("RMSprop does not support sparse gradients")
110*da0073e9SAndroid Build Coastguard Worker            grads.append(p.grad)
111*da0073e9SAndroid Build Coastguard Worker
112*da0073e9SAndroid Build Coastguard Worker            state = self.state[p]
113*da0073e9SAndroid Build Coastguard Worker
114*da0073e9SAndroid Build Coastguard Worker            # State initialization
115*da0073e9SAndroid Build Coastguard Worker            if len(state) == 0:
116*da0073e9SAndroid Build Coastguard Worker                state["step"] = (
117*da0073e9SAndroid Build Coastguard Worker                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
118*da0073e9SAndroid Build Coastguard Worker                    if group["capturable"]
119*da0073e9SAndroid Build Coastguard Worker                    else torch.zeros((), dtype=_get_scalar_dtype())
120*da0073e9SAndroid Build Coastguard Worker                )
121*da0073e9SAndroid Build Coastguard Worker                state["square_avg"] = torch.zeros_like(
122*da0073e9SAndroid Build Coastguard Worker                    p, memory_format=torch.preserve_format
123*da0073e9SAndroid Build Coastguard Worker                )
124*da0073e9SAndroid Build Coastguard Worker                if group["momentum"] > 0:
125*da0073e9SAndroid Build Coastguard Worker                    state["momentum_buffer"] = torch.zeros_like(
126*da0073e9SAndroid Build Coastguard Worker                        p, memory_format=torch.preserve_format
127*da0073e9SAndroid Build Coastguard Worker                    )
128*da0073e9SAndroid Build Coastguard Worker                if group["centered"]:
129*da0073e9SAndroid Build Coastguard Worker                    state["grad_avg"] = torch.zeros_like(
130*da0073e9SAndroid Build Coastguard Worker                        p, memory_format=torch.preserve_format
131*da0073e9SAndroid Build Coastguard Worker                    )
132*da0073e9SAndroid Build Coastguard Worker            square_avgs.append(state["square_avg"])
133*da0073e9SAndroid Build Coastguard Worker            state_steps.append(state["step"])
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker            if group["momentum"] > 0:
136*da0073e9SAndroid Build Coastguard Worker                momentum_buffer_list.append(state["momentum_buffer"])
137*da0073e9SAndroid Build Coastguard Worker            if group["centered"]:
138*da0073e9SAndroid Build Coastguard Worker                grad_avgs.append(state["grad_avg"])
139*da0073e9SAndroid Build Coastguard Worker
140*da0073e9SAndroid Build Coastguard Worker        return has_complex
141*da0073e9SAndroid Build Coastguard Worker
142*da0073e9SAndroid Build Coastguard Worker    @_use_grad_for_differentiable
143*da0073e9SAndroid Build Coastguard Worker    def step(self, closure=None):
144*da0073e9SAndroid Build Coastguard Worker        """Perform a single optimization step.
145*da0073e9SAndroid Build Coastguard Worker
146*da0073e9SAndroid Build Coastguard Worker        Args:
147*da0073e9SAndroid Build Coastguard Worker            closure (Callable, optional): A closure that reevaluates the model
148*da0073e9SAndroid Build Coastguard Worker                and returns the loss.
149*da0073e9SAndroid Build Coastguard Worker        """
150*da0073e9SAndroid Build Coastguard Worker        self._cuda_graph_capture_health_check()
151*da0073e9SAndroid Build Coastguard Worker
152*da0073e9SAndroid Build Coastguard Worker        loss = None
153*da0073e9SAndroid Build Coastguard Worker        if closure is not None:
154*da0073e9SAndroid Build Coastguard Worker            with torch.enable_grad():
155*da0073e9SAndroid Build Coastguard Worker                loss = closure()
156*da0073e9SAndroid Build Coastguard Worker
157*da0073e9SAndroid Build Coastguard Worker        for group in self.param_groups:
158*da0073e9SAndroid Build Coastguard Worker            params_with_grad: List[Tensor] = []
159*da0073e9SAndroid Build Coastguard Worker            grads: List[Tensor] = []
160*da0073e9SAndroid Build Coastguard Worker            square_avgs: List[Tensor] = []
161*da0073e9SAndroid Build Coastguard Worker            grad_avgs: List[Tensor] = []
162*da0073e9SAndroid Build Coastguard Worker            momentum_buffer_list: List[Tensor] = []
163*da0073e9SAndroid Build Coastguard Worker            state_steps: List[Tensor] = []
164*da0073e9SAndroid Build Coastguard Worker
165*da0073e9SAndroid Build Coastguard Worker            has_complex = self._init_group(
166*da0073e9SAndroid Build Coastguard Worker                group,
167*da0073e9SAndroid Build Coastguard Worker                params_with_grad,
168*da0073e9SAndroid Build Coastguard Worker                grads,
169*da0073e9SAndroid Build Coastguard Worker                square_avgs,
170*da0073e9SAndroid Build Coastguard Worker                momentum_buffer_list,
171*da0073e9SAndroid Build Coastguard Worker                grad_avgs,
172*da0073e9SAndroid Build Coastguard Worker                state_steps,
173*da0073e9SAndroid Build Coastguard Worker            )
174*da0073e9SAndroid Build Coastguard Worker
175*da0073e9SAndroid Build Coastguard Worker            rmsprop(
176*da0073e9SAndroid Build Coastguard Worker                params_with_grad,
177*da0073e9SAndroid Build Coastguard Worker                grads,
178*da0073e9SAndroid Build Coastguard Worker                square_avgs,
179*da0073e9SAndroid Build Coastguard Worker                grad_avgs,
180*da0073e9SAndroid Build Coastguard Worker                momentum_buffer_list,
181*da0073e9SAndroid Build Coastguard Worker                state_steps,
182*da0073e9SAndroid Build Coastguard Worker                lr=group["lr"],
183*da0073e9SAndroid Build Coastguard Worker                alpha=group["alpha"],
184*da0073e9SAndroid Build Coastguard Worker                eps=group["eps"],
185*da0073e9SAndroid Build Coastguard Worker                weight_decay=group["weight_decay"],
186*da0073e9SAndroid Build Coastguard Worker                momentum=group["momentum"],
187*da0073e9SAndroid Build Coastguard Worker                centered=group["centered"],
188*da0073e9SAndroid Build Coastguard Worker                foreach=group["foreach"],
189*da0073e9SAndroid Build Coastguard Worker                maximize=group["maximize"],
190*da0073e9SAndroid Build Coastguard Worker                differentiable=group["differentiable"],
191*da0073e9SAndroid Build Coastguard Worker                capturable=group["capturable"],
192*da0073e9SAndroid Build Coastguard Worker                has_complex=has_complex,
193*da0073e9SAndroid Build Coastguard Worker            )
194*da0073e9SAndroid Build Coastguard Worker
195*da0073e9SAndroid Build Coastguard Worker        return loss
196*da0073e9SAndroid Build Coastguard Worker
197*da0073e9SAndroid Build Coastguard Worker
198*da0073e9SAndroid Build Coastguard WorkerRMSprop.__doc__ = (
199*da0073e9SAndroid Build Coastguard Worker    r"""Implements RMSprop algorithm.
200*da0073e9SAndroid Build Coastguard Worker
201*da0073e9SAndroid Build Coastguard Worker    .. math::
202*da0073e9SAndroid Build Coastguard Worker       \begin{aligned}
203*da0073e9SAndroid Build Coastguard Worker            &\rule{110mm}{0.4pt}                                                                 \\
204*da0073e9SAndroid Build Coastguard Worker            &\textbf{input}      : \alpha \text{ (alpha)},\: \gamma \text{ (lr)},
205*da0073e9SAndroid Build Coastguard Worker                \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)}                   \\
206*da0073e9SAndroid Build Coastguard Worker            &\hspace{13mm}   \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},\: centered\\
207*da0073e9SAndroid Build Coastguard Worker            &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
208*da0073e9SAndroid Build Coastguard Worker                \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0     \\[-1.ex]
209*da0073e9SAndroid Build Coastguard Worker            &\rule{110mm}{0.4pt}                                                                 \\
210*da0073e9SAndroid Build Coastguard Worker            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
211*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})           \\
212*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm}if \: \lambda \neq 0                                                    \\
213*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
214*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm}v_t           \leftarrow   \alpha v_{t-1} + (1 - \alpha) g^2_t
215*da0073e9SAndroid Build Coastguard Worker                \hspace{8mm}                                                                     \\
216*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm} \tilde{v_t} \leftarrow v_t                                             \\
217*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm}if \: centered                                                          \\
218*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t            \\
219*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} -  \big(g^{ave}_{t} \big)^2        \\
220*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm}if \: \mu > 0                                                           \\
221*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
222*da0073e9SAndroid Build Coastguard Worker                g_t/ \big(\sqrt{\tilde{v_t}} +  \epsilon \big)                                   \\
223*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t                \\
224*da0073e9SAndroid Build Coastguard Worker            &\hspace{5mm} else                                                                   \\
225*da0073e9SAndroid Build Coastguard Worker            &\hspace{10mm}\theta_t      \leftarrow   \theta_{t-1} -
226*da0073e9SAndroid Build Coastguard Worker                \gamma  g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big)  \hspace{3mm}              \\
227*da0073e9SAndroid Build Coastguard Worker            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
228*da0073e9SAndroid Build Coastguard Worker            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
229*da0073e9SAndroid Build Coastguard Worker            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
230*da0073e9SAndroid Build Coastguard Worker       \end{aligned}
231*da0073e9SAndroid Build Coastguard Worker
232*da0073e9SAndroid Build Coastguard Worker    For further details regarding the algorithm we refer to
233*da0073e9SAndroid Build Coastguard Worker    `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
234*da0073e9SAndroid Build Coastguard Worker    and centered version `Generating Sequences
235*da0073e9SAndroid Build Coastguard Worker    With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
236*da0073e9SAndroid Build Coastguard Worker    The implementation here takes the square root of the gradient average before
237*da0073e9SAndroid Build Coastguard Worker    adding epsilon (note that TensorFlow interchanges these two operations). The effective
238*da0073e9SAndroid Build Coastguard Worker    learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
239*da0073e9SAndroid Build Coastguard Worker    is the scheduled learning rate and :math:`v` is the weighted moving average
240*da0073e9SAndroid Build Coastguard Worker    of the squared gradient.
241*da0073e9SAndroid Build Coastguard Worker    """
242*da0073e9SAndroid Build Coastguard Worker    + rf"""
243*da0073e9SAndroid Build Coastguard Worker    Args:
244*da0073e9SAndroid Build Coastguard Worker        params (iterable): iterable of parameters to optimize or dicts defining
245*da0073e9SAndroid Build Coastguard Worker            parameter groups
246*da0073e9SAndroid Build Coastguard Worker        lr (float, Tensor, optional): learning rate (default: 1e-2)
247*da0073e9SAndroid Build Coastguard Worker        momentum (float, optional): momentum factor (default: 0)
248*da0073e9SAndroid Build Coastguard Worker        alpha (float, optional): smoothing constant (default: 0.99)
249*da0073e9SAndroid Build Coastguard Worker        eps (float, optional): term added to the denominator to improve
250*da0073e9SAndroid Build Coastguard Worker            numerical stability (default: 1e-8)
251*da0073e9SAndroid Build Coastguard Worker        centered (bool, optional) : if ``True``, compute the centered RMSProp,
252*da0073e9SAndroid Build Coastguard Worker            the gradient is normalized by an estimation of its variance
253*da0073e9SAndroid Build Coastguard Worker        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
254*da0073e9SAndroid Build Coastguard Worker        {_foreach_doc}
255*da0073e9SAndroid Build Coastguard Worker        {_maximize_doc}
256*da0073e9SAndroid Build Coastguard Worker        {_capturable_doc}
257*da0073e9SAndroid Build Coastguard Worker        {_differentiable_doc}
258*da0073e9SAndroid Build Coastguard Worker
259*da0073e9SAndroid Build Coastguard Worker    """
260*da0073e9SAndroid Build Coastguard Worker)
261*da0073e9SAndroid Build Coastguard Worker
262*da0073e9SAndroid Build Coastguard Worker
def _single_tensor_rmsprop(
    params: List[Tensor],
    grads: List[Tensor],
    square_avgs: List[Tensor],
    grad_avgs: List[Tensor],
    momentum_buffer_list: List[Tensor],
    state_steps: List[Tensor],
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one RMSprop update to each parameter, one tensor at a time.

    All list arguments are indexed in parallel with ``params``. Parameters,
    ``square_avgs``, ``state_steps`` and (when used) ``grad_avgs`` /
    ``momentum_buffer_list`` are updated in place; ``grads`` are never
    modified. ``grad_avgs`` is only read when ``centered`` is set and
    ``momentum_buffer_list`` only when ``momentum > 0``. Complex tensors are
    processed through their real views.
    """
    for idx, p in enumerate(params):
        step_t = state_steps[idx]

        # Under torch.compile the compiler performs the cudagraph checks
        # itself (see note [torch.compile x capturable]); only check eagerly.
        if not torch._utils.is_compiling() and capturable:
            supported = _get_capturable_supported_devices()
            assert (
                p.device.type == step_t.device.type
                and p.device.type in supported
            ), f"If capturable=True, params and state_steps must be on supported devices: {supported}."

        # Negation allocates a new tensor, so the caller's gradient is untouched.
        g = -grads[idx] if maximize else grads[idx]
        sq_avg = square_avgs[idx]

        step_t += 1

        if weight_decay != 0:
            # Out-of-place: fold L2 decay into a fresh gradient tensor.
            g = g.add(p, alpha=weight_decay)

        complex_view = torch.is_complex(p)
        if complex_view:
            p, g, sq_avg = (torch.view_as_real(t) for t in (p, g, sq_avg))

        # v_t = alpha * v_{t-1} + (1 - alpha) * g_t^2
        sq_avg.mul_(alpha).addcmul_(g, g, value=1 - alpha)

        if not centered:
            denom = sq_avg.sqrt()
        else:
            g_avg = grad_avgs[idx]
            if complex_view:
                g_avg = torch.view_as_real(g_avg)
            # Running mean of the gradient, then subtract its square from v_t.
            g_avg.lerp_(g, 1 - alpha)
            denom = sq_avg.addcmul(g_avg, g_avg, value=-1).sqrt_()

        # Only the non-differentiable path adds eps in place on ``denom``.
        denom = denom.add(eps) if differentiable else denom.add_(eps)

        if momentum <= 0:
            p.addcdiv_(g, denom, value=-lr)
            continue

        buf = momentum_buffer_list[idx]
        if complex_view:
            buf = torch.view_as_real(buf)
        buf.mul_(momentum).addcdiv_(g, denom)
        p.add_(buf, alpha=-lr)
332*da0073e9SAndroid Build Coastguard Worker
333*da0073e9SAndroid Build Coastguard Worker
334*da0073e9SAndroid Build Coastguard Workerdef _multi_tensor_rmsprop(
335*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
336*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
337*da0073e9SAndroid Build Coastguard Worker    square_avgs: List[Tensor],
338*da0073e9SAndroid Build Coastguard Worker    grad_avgs: List[Tensor],
339*da0073e9SAndroid Build Coastguard Worker    momentum_buffer_list: List[Tensor],
340*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
341*da0073e9SAndroid Build Coastguard Worker    *,
342*da0073e9SAndroid Build Coastguard Worker    lr: float,
343*da0073e9SAndroid Build Coastguard Worker    alpha: float,
344*da0073e9SAndroid Build Coastguard Worker    eps: float,
345*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
346*da0073e9SAndroid Build Coastguard Worker    momentum: float,
347*da0073e9SAndroid Build Coastguard Worker    centered: bool,
348*da0073e9SAndroid Build Coastguard Worker    maximize: bool,
349*da0073e9SAndroid Build Coastguard Worker    differentiable: bool,
350*da0073e9SAndroid Build Coastguard Worker    capturable: bool,
351*da0073e9SAndroid Build Coastguard Worker    has_complex: bool,
352*da0073e9SAndroid Build Coastguard Worker):
353*da0073e9SAndroid Build Coastguard Worker    if len(params) == 0:
354*da0073e9SAndroid Build Coastguard Worker        return
355*da0073e9SAndroid Build Coastguard Worker
356*da0073e9SAndroid Build Coastguard Worker    assert not differentiable, "_foreach ops don't support autograd"
357*da0073e9SAndroid Build Coastguard Worker
358*da0073e9SAndroid Build Coastguard Worker    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
359*da0073e9SAndroid Build Coastguard Worker    if not torch._utils.is_compiling() and capturable:
360*da0073e9SAndroid Build Coastguard Worker        capturable_supported_devices = _get_capturable_supported_devices()
361*da0073e9SAndroid Build Coastguard Worker        assert all(
362*da0073e9SAndroid Build Coastguard Worker            p.device.type == step.device.type
363*da0073e9SAndroid Build Coastguard Worker            and p.device.type in capturable_supported_devices
364*da0073e9SAndroid Build Coastguard Worker            for p, step in zip(params, state_steps)
365*da0073e9SAndroid Build Coastguard Worker        ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
366*da0073e9SAndroid Build Coastguard Worker
367*da0073e9SAndroid Build Coastguard Worker    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
368*da0073e9SAndroid Build Coastguard Worker        [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps]  # type: ignore[list-item]
369*da0073e9SAndroid Build Coastguard Worker    )
370*da0073e9SAndroid Build Coastguard Worker    for (
371*da0073e9SAndroid Build Coastguard Worker        (
372*da0073e9SAndroid Build Coastguard Worker            grouped_params_,
373*da0073e9SAndroid Build Coastguard Worker            grouped_grads_,
374*da0073e9SAndroid Build Coastguard Worker            grouped_square_avgs_,
375*da0073e9SAndroid Build Coastguard Worker            grouped_grad_avgs_,
376*da0073e9SAndroid Build Coastguard Worker            grouped_momentum_buffer_list_,
377*da0073e9SAndroid Build Coastguard Worker            grouped_state_steps_,
378*da0073e9SAndroid Build Coastguard Worker        )
379*da0073e9SAndroid Build Coastguard Worker    ), _ in grouped_tensors.values():
380*da0073e9SAndroid Build Coastguard Worker        grouped_params = cast(List[Tensor], grouped_params_)
381*da0073e9SAndroid Build Coastguard Worker        grouped_grads = cast(List[Tensor], grouped_grads_)
382*da0073e9SAndroid Build Coastguard Worker        grouped_square_avgs = cast(List[Tensor], grouped_square_avgs_)
383*da0073e9SAndroid Build Coastguard Worker        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)
384*da0073e9SAndroid Build Coastguard Worker
385*da0073e9SAndroid Build Coastguard Worker        if has_complex:
386*da0073e9SAndroid Build Coastguard Worker            state_and_grads = [grouped_grads, grouped_square_avgs]
387*da0073e9SAndroid Build Coastguard Worker            if momentum > 0:
388*da0073e9SAndroid Build Coastguard Worker                grouped_momentum_buffer_list = cast(
389*da0073e9SAndroid Build Coastguard Worker                    List[Tensor], grouped_momentum_buffer_list_
390*da0073e9SAndroid Build Coastguard Worker                )
391*da0073e9SAndroid Build Coastguard Worker                state_and_grads.append(grouped_momentum_buffer_list)
392*da0073e9SAndroid Build Coastguard Worker            if centered:
393*da0073e9SAndroid Build Coastguard Worker                grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
394*da0073e9SAndroid Build Coastguard Worker                state_and_grads.append(grouped_grad_avgs)
395*da0073e9SAndroid Build Coastguard Worker            _view_as_real(grouped_params, *state_and_grads)
396*da0073e9SAndroid Build Coastguard Worker
397*da0073e9SAndroid Build Coastguard Worker        if maximize:
398*da0073e9SAndroid Build Coastguard Worker            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]
399*da0073e9SAndroid Build Coastguard Worker
400*da0073e9SAndroid Build Coastguard Worker        # Update steps
401*da0073e9SAndroid Build Coastguard Worker        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
402*da0073e9SAndroid Build Coastguard Worker        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
403*da0073e9SAndroid Build Coastguard Worker        # wrapped it once now. The alpha is required to assure we go to the right overload.
404*da0073e9SAndroid Build Coastguard Worker        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
405*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(
406*da0073e9SAndroid Build Coastguard Worker                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
407*da0073e9SAndroid Build Coastguard Worker            )
408*da0073e9SAndroid Build Coastguard Worker        else:
409*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(grouped_state_steps, 1)
410*da0073e9SAndroid Build Coastguard Worker
411*da0073e9SAndroid Build Coastguard Worker        if weight_decay != 0:
412*da0073e9SAndroid Build Coastguard Worker            # Re-use the intermediate memory (grouped_grads) already allocated for maximize
413*da0073e9SAndroid Build Coastguard Worker            if maximize:
414*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
415*da0073e9SAndroid Build Coastguard Worker            else:
416*da0073e9SAndroid Build Coastguard Worker                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
417*da0073e9SAndroid Build Coastguard Worker                    grouped_grads, grouped_params, alpha=weight_decay
418*da0073e9SAndroid Build Coastguard Worker                )
419*da0073e9SAndroid Build Coastguard Worker
420*da0073e9SAndroid Build Coastguard Worker        torch._foreach_mul_(grouped_square_avgs, alpha)
421*da0073e9SAndroid Build Coastguard Worker        torch._foreach_addcmul_(
422*da0073e9SAndroid Build Coastguard Worker            grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
423*da0073e9SAndroid Build Coastguard Worker        )
424*da0073e9SAndroid Build Coastguard Worker
425*da0073e9SAndroid Build Coastguard Worker        if centered:
426*da0073e9SAndroid Build Coastguard Worker            grouped_grad_avgs = cast(List[Tensor], grouped_grad_avgs_)
427*da0073e9SAndroid Build Coastguard Worker            torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
428*da0073e9SAndroid Build Coastguard Worker            avg = torch._foreach_addcmul(
429*da0073e9SAndroid Build Coastguard Worker                grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
430*da0073e9SAndroid Build Coastguard Worker            )
431*da0073e9SAndroid Build Coastguard Worker            torch._foreach_sqrt_(avg)
432*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(avg, eps)
433*da0073e9SAndroid Build Coastguard Worker        else:
434*da0073e9SAndroid Build Coastguard Worker            avg = torch._foreach_sqrt(grouped_square_avgs)
435*da0073e9SAndroid Build Coastguard Worker            torch._foreach_add_(avg, eps)
436*da0073e9SAndroid Build Coastguard Worker
437*da0073e9SAndroid Build Coastguard Worker        if momentum > 0:
438*da0073e9SAndroid Build Coastguard Worker            grouped_momentum_buffer_list = cast(
439*da0073e9SAndroid Build Coastguard Worker                List[Tensor], grouped_momentum_buffer_list_
440*da0073e9SAndroid Build Coastguard Worker            )
441*da0073e9SAndroid Build Coastguard Worker            torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
442*da0073e9SAndroid Build Coastguard Worker            torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
443*da0073e9SAndroid Build Coastguard Worker            # If LR is a tensor, the else branch will internally call item()
444*da0073e9SAndroid Build Coastguard Worker            # which will cause silent incorrectness if we are capturing
445*da0073e9SAndroid Build Coastguard Worker            if capturable and isinstance(lr, torch.Tensor):
446*da0073e9SAndroid Build Coastguard Worker                momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
447*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(grouped_params, momentum_lr)
448*da0073e9SAndroid Build Coastguard Worker            else:
449*da0073e9SAndroid Build Coastguard Worker                torch._foreach_add_(
450*da0073e9SAndroid Build Coastguard Worker                    grouped_params, grouped_momentum_buffer_list, alpha=-lr
451*da0073e9SAndroid Build Coastguard Worker                )
452*da0073e9SAndroid Build Coastguard Worker        else:
453*da0073e9SAndroid Build Coastguard Worker            # If LR is a tensor, the else branch will internally call item()
454*da0073e9SAndroid Build Coastguard Worker            # which will cause silent incorrectness if we are capturing
455*da0073e9SAndroid Build Coastguard Worker            if capturable and isinstance(lr, torch.Tensor):
456*da0073e9SAndroid Build Coastguard Worker                torch._foreach_div_(avg, -lr)
457*da0073e9SAndroid Build Coastguard Worker                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
458*da0073e9SAndroid Build Coastguard Worker            else:
459*da0073e9SAndroid Build Coastguard Worker                torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
460*da0073e9SAndroid Build Coastguard Worker
461*da0073e9SAndroid Build Coastguard Worker
462*da0073e9SAndroid Build Coastguard Worker@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
463*da0073e9SAndroid Build Coastguard Workerdef rmsprop(
464*da0073e9SAndroid Build Coastguard Worker    params: List[Tensor],
465*da0073e9SAndroid Build Coastguard Worker    grads: List[Tensor],
466*da0073e9SAndroid Build Coastguard Worker    square_avgs: List[Tensor],
467*da0073e9SAndroid Build Coastguard Worker    grad_avgs: List[Tensor],
468*da0073e9SAndroid Build Coastguard Worker    momentum_buffer_list: List[Tensor],
469*da0073e9SAndroid Build Coastguard Worker    state_steps: List[Tensor],
470*da0073e9SAndroid Build Coastguard Worker    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
471*da0073e9SAndroid Build Coastguard Worker    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
472*da0073e9SAndroid Build Coastguard Worker    foreach: Optional[bool] = None,
473*da0073e9SAndroid Build Coastguard Worker    maximize: bool = False,
474*da0073e9SAndroid Build Coastguard Worker    differentiable: bool = False,
475*da0073e9SAndroid Build Coastguard Worker    capturable: bool = False,
476*da0073e9SAndroid Build Coastguard Worker    has_complex: bool = False,
477*da0073e9SAndroid Build Coastguard Worker    *,
478*da0073e9SAndroid Build Coastguard Worker    lr: float,
479*da0073e9SAndroid Build Coastguard Worker    alpha: float,
480*da0073e9SAndroid Build Coastguard Worker    eps: float,
481*da0073e9SAndroid Build Coastguard Worker    weight_decay: float,
482*da0073e9SAndroid Build Coastguard Worker    momentum: float,
483*da0073e9SAndroid Build Coastguard Worker    centered: bool,
484*da0073e9SAndroid Build Coastguard Worker):
485*da0073e9SAndroid Build Coastguard Worker    r"""Functional API that performs rmsprop algorithm computation.
486*da0073e9SAndroid Build Coastguard Worker
487*da0073e9SAndroid Build Coastguard Worker    See :class:`~torch.optim.RMSProp` for details.
488*da0073e9SAndroid Build Coastguard Worker    """
489*da0073e9SAndroid Build Coastguard Worker    # this check is slow during compilation, so we skip it
490*da0073e9SAndroid Build Coastguard Worker    # if it's strictly needed we can add this check back in dynamo
491*da0073e9SAndroid Build Coastguard Worker    if not torch._utils.is_compiling() and not all(
492*da0073e9SAndroid Build Coastguard Worker        isinstance(t, torch.Tensor) for t in state_steps
493*da0073e9SAndroid Build Coastguard Worker    ):
494*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
495*da0073e9SAndroid Build Coastguard Worker            "API has changed, `state_steps` argument must contain a list of singleton tensors"
496*da0073e9SAndroid Build Coastguard Worker        )
497*da0073e9SAndroid Build Coastguard Worker
498*da0073e9SAndroid Build Coastguard Worker    if foreach is None:
499*da0073e9SAndroid Build Coastguard Worker        _, foreach = _default_to_fused_or_foreach(
500*da0073e9SAndroid Build Coastguard Worker            params, differentiable, use_fused=False
501*da0073e9SAndroid Build Coastguard Worker        )
502*da0073e9SAndroid Build Coastguard Worker
503*da0073e9SAndroid Build Coastguard Worker    if foreach and torch.jit.is_scripting():
504*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
505*da0073e9SAndroid Build Coastguard Worker
506*da0073e9SAndroid Build Coastguard Worker    if foreach and not torch.jit.is_scripting():
507*da0073e9SAndroid Build Coastguard Worker        func = _multi_tensor_rmsprop
508*da0073e9SAndroid Build Coastguard Worker    else:
509*da0073e9SAndroid Build Coastguard Worker        func = _single_tensor_rmsprop
510*da0073e9SAndroid Build Coastguard Worker
511*da0073e9SAndroid Build Coastguard Worker    func(
512*da0073e9SAndroid Build Coastguard Worker        params,
513*da0073e9SAndroid Build Coastguard Worker        grads,
514*da0073e9SAndroid Build Coastguard Worker        square_avgs,
515*da0073e9SAndroid Build Coastguard Worker        grad_avgs,
516*da0073e9SAndroid Build Coastguard Worker        momentum_buffer_list,
517*da0073e9SAndroid Build Coastguard Worker        state_steps,
518*da0073e9SAndroid Build Coastguard Worker        lr=lr,
519*da0073e9SAndroid Build Coastguard Worker        alpha=alpha,
520*da0073e9SAndroid Build Coastguard Worker        eps=eps,
521*da0073e9SAndroid Build Coastguard Worker        weight_decay=weight_decay,
522*da0073e9SAndroid Build Coastguard Worker        momentum=momentum,
523*da0073e9SAndroid Build Coastguard Worker        centered=centered,
524*da0073e9SAndroid Build Coastguard Worker        maximize=maximize,
525*da0073e9SAndroid Build Coastguard Worker        capturable=capturable,
526*da0073e9SAndroid Build Coastguard Worker        differentiable=differentiable,
527*da0073e9SAndroid Build Coastguard Worker        has_complex=has_complex,
528*da0073e9SAndroid Build Coastguard Worker    )
529