# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
        return has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
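        # A minimal closure sketch (hypothetical model/loss/input names), for
        # callers that want the loss re-evaluated inside step():
        #
        #   def closure():
        #       optimizer.zero_grad()
        #       loss = loss_fn(model(input), target)
        #       loss.backward()
        #       return loss
        #
        #   optimizer.step(closure)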
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """
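
# A minimal usage sketch (hypothetical model/data names; ASGD is constructed
# and stepped like any other torch.optim optimizer):
#
#   model = torch.nn.Linear(10, 2)
#   optimizer = torch.optim.ASGD(model.parameters(), lr=1e-2, t0=1e6)
#   for input, target in dataset:
#       optimizer.zero_grad()
#       loss = torch.nn.functional.mse_loss(model(input), target)
#       loss.backward()
#       optimizer.step()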


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)

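        # Recompute the schedules for the next step. Both branches below
        # implement the same formulas:
        #   eta = lr / (1 + lambd * lr * step) ** alpha
        #   mu  = 1 / max(1, step - t0)
        # i.e. averaging only kicks in once step exceeds t0.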
        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            grouped_params_,
            grouped_grads_,
            grouped_axs_,
            grouped_mus_,
            grouped_etas_,
            grouped_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_axs = cast(List[Tensor], grouped_axs_)
        grouped_mus = cast(List[Tensor], grouped_mus_)
        grouped_etas = cast(List[Tensor], grouped_etas_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha kwarg is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * lambd (+ param * weight_decay when weight_decay != 0)
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64,
        # and our grouping code requires dtypes to match for all tensors in a group
        # (and it should, since we use the mus in other places). All dtypes need to
        # match, so we could introduce a cast in a loop, but since this only adds one
        # additional kernel launch, this looks like the cleaner and faster solution.
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # update grouped_mus: mu = 1 / max(1, step - t0)
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in grouped_state_steps
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
                for step in grouped_state_steps
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript; see issue #70627.
    # Setting these as kwargs for now, as the functional API is compiled by torch/distributed/optim.
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs the ASGD algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
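    # A minimal sketch of direct functional use (hypothetical single-parameter
    # state; most callers should prefer the ASGD class, which tracks this state
    # automatically):
    #
    #   asgd([param], [param.grad], [ax], [mu], [eta], [step],
    #        lambd=1e-4, lr=1e-2, t0=1e6, alpha=0.75, weight_decay=0.0)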
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )