# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from typing import cast, List, Optional, Tuple, Union

import torch
from torch import Tensor

from .optimizer import (
    _capturable_doc,
    _default_to_fused_or_foreach,
    _differentiable_doc,
    _disable_dynamo_if_unsupported,
    _foreach_doc,
    _get_capturable_supported_devices,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["ASGD", "asgd"]


class ASGD(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lambd: float = 1e-4,
        alpha: float = 0.75,
        t0: float = 1e6,
        weight_decay: float = 0,
        foreach: Optional[bool] = None,
        maximize: bool = False,
        differentiable: bool = False,
        capturable: bool = False,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr,
            lambd=lambd,
            alpha=alpha,
            t0=t0,
            weight_decay=weight_decay,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            capturable=capturable,
        )
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            group.setdefault("capturable", False)
            # Older checkpoints stored step/eta/mu as Python numbers; convert
            # them to 0-d tensors so downstream code sees a single format.
            for p in group["params"]:
                p_state = self.state.get(p, [])
                if len(p_state) != 0:
                    if not torch.is_tensor(p_state["step"]):
                        step_val = float(p_state["step"])
                        p_state["step"] = torch.tensor(
                            step_val, dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["eta"]):
                        p_state["eta"] = torch.tensor(
                            p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
                        )
                    if not torch.is_tensor(p_state["mu"]):
                        p_state["mu"] = torch.tensor(
                            p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
                        )

    def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
        has_complex = False
        for p in group["params"]:
            if p.grad is not None:
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError("ASGD does not support sparse gradients")
                grads.append(p.grad)

                state = self.state[p]
                # State initialization
                if len(state) == 0:
                    state["step"] = torch.zeros(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["eta"] = (
                        torch.as_tensor(
                            group["lr"], device=p.device, dtype=_get_scalar_dtype()
                        )
                        .clone()
                        .detach()
                    )
                    state["mu"] = torch.ones(
                        (), device=p.device, dtype=_get_scalar_dtype()
                    )
                    state["ax"] = torch.zeros_like(
                        p, memory_format=torch.preserve_format
                    )

                mus.append(state["mu"])
                axs.append(state["ax"])
                etas.append(state["eta"])
                state_steps.append(state["step"])
        return has_complex
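
    # Per-parameter state kept by ASGD, as initialized in _init_group above:
    #   "step" - number of optimization steps taken so far (a 0-d tensor)
    #   "eta"  - current effective step size (starts at lr, decayed each step)
    #   "mu"   - averaging coefficient (stays 1.0 until the step count passes t0)
    #   "ax"   - running average of the parameter, i.e. the averaged iterate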

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
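        # An illustrative closure (a sketch; `model`, `loss_fn`, `input`, and
        # `target` stand in for objects owned by the surrounding training loop):
        #
        #     def closure():
        #         optimizer.zero_grad()
        #         loss = loss_fn(model(input), target)
        #         loss.backward()
        #         return loss
        #
        #     optimizer.step(closure)
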
        self._cuda_graph_capture_health_check()

        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            mus: List[Tensor] = []
            axs: List[Tensor] = []
            etas: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_complex = self._init_group(
                group, params_with_grad, grads, mus, axs, etas, state_steps
            )

            asgd(
                params_with_grad,
                grads,
                axs,
                mus,
                etas,
                state_steps,
                lambd=group["lambd"],
                lr=group["lr"],
                t0=group["t0"],
                alpha=group["alpha"],
                weight_decay=group["weight_decay"],
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                capturable=group["capturable"],
                has_complex=has_complex,
            )

        return loss


ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.

    It has been proposed in `Acceleration of stochastic approximation by
    averaging`_.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lambd (float, optional): decay term (default: 1e-4)
        alpha (float, optional): power for eta update (default: 0.75)
        t0 (float, optional): point at which to start averaging (default: 1e6)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        {_capturable_doc}
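
    Example (an illustrative sketch; ``model``, ``loss_fn``, ``input`` and
    ``target`` are placeholders for your own module and data)::

        >>> optimizer = torch.optim.ASGD(model.parameters(), lr=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()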

    .. _Acceleration of stochastic approximation by averaging:
        https://dl.acm.org/citation.cfm?id=131098

    """
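

# Both implementations below realize the same per-step update; with step count
# t and the current eta/mu state values:
#
#   param <- param * (1 - lambd * eta) - eta * grad   (decayed SGD step)
#   ax    <- ax + mu * (param - ax)                   (running average)
#   eta   <- lr / (1 + lambd * lr * t) ** alpha       (step-size schedule)
#   mu    <- 1 / max(1, t - t0)                       (averaging schedule)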


def _single_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    for i, param in enumerate(params):
        grad = grads[i]
        grad = grad if not maximize else -grad
        mu = mus[i]
        ax = axs[i]
        eta = etas[i]
        step_t = state_steps[i]

        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
        if not torch._utils.is_compiling() and capturable:
            capturable_supported_devices = _get_capturable_supported_devices()
            assert (
                param.device.type
                == mu.device.type
                == eta.device.type
                == step_t.device.type
                and param.device.type in capturable_supported_devices
            ), (
                f"If capturable=True, params, mus, etas, and state_steps must be "
                f"on supported devices: {capturable_supported_devices}."
            )

        if torch.is_complex(param):
            grad = torch.view_as_real(grad)
            param = torch.view_as_real(param)
            ax = torch.view_as_real(ax)

        # update step
        step_t += 1

        if weight_decay != 0:
            grad = grad.add(param, alpha=weight_decay)

        if capturable:
            param.mul_(1 - lambd * eta)
            param.addcmul_(grad, eta, value=-1)  # update parameter
        else:
            eta_value = _get_value(eta)
            param.mul_(1 - lambd * eta_value)  # decay term
            param.add_(grad, alpha=-eta_value)  # update parameter

        # averaging
        if capturable or mu.item() != 1:
            ax.add_(param.sub(ax).mul_(mu))
        else:
            ax.copy_(param)
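
        # Recompute the schedules for the next step. A worked instance at the
        # defaults (lambd=1e-4, lr=1e-2, alpha=0.75, t0=1e6): at step 100,
        # eta = 1e-2 / (1 + 1e-6 * 100) ** 0.75 ~= 9.99925e-3 and
        # mu = 1 / max(1, 100 - 1e6) = 1, so averaging has not started yet;
        # once step > t0, mu = 1 / (step - t0) < 1 and ax starts averaging.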
        if capturable:
            eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
            mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
        else:
            step = _get_value(step_t)
            new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
            eta.copy_(new_eta)
            new_mu = torch.as_tensor(1 / max(1, step - t0))
            mu.copy_(new_mu)


def _multi_tensor_asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    if len(params) == 0:
        return

    assert not differentiable, "_foreach ops don't support autograd"

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch._utils.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == mu.device.type == eta.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, mu, eta, step in zip(params, mus, etas, state_steps)
        ), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, axs, mus, etas, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            grouped_params_,
            grouped_grads_,
            grouped_axs_,
            grouped_mus_,
            grouped_etas_,
            grouped_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        grouped_params = cast(List[Tensor], grouped_params_)
        grouped_grads = cast(List[Tensor], grouped_grads_)
        grouped_axs = cast(List[Tensor], grouped_axs_)
        grouped_mus = cast(List[Tensor], grouped_mus_)
        grouped_etas = cast(List[Tensor], grouped_etas_)
        grouped_state_steps = cast(List[Tensor], grouped_state_steps_)

        if has_complex:
            _view_as_real(grouped_params, grouped_grads, grouped_axs)

        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]
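        # Note: _foreach_neg above is out-of-place, so `grouped_grads` now
        # holds fresh tensors; the in-place _foreach_add_ in the weight_decay
        # branch below mutates those copies, not the user's .grad tensors.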

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to ensure we go to the right overload.
        if not torch._utils.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # intermediate = grad + param * lambd
        intermediate: Union[Tuple[Tensor, ...], List[Tensor]]
        if weight_decay != 0:
            if maximize:
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
                intermediate = grouped_grads
            else:
                intermediate = torch._foreach_add(
                    grouped_grads, grouped_params, alpha=weight_decay
                )

            torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
        else:
            intermediate = torch._foreach_add(
                grouped_grads, grouped_params, alpha=lambd
            )

        # update param
        # param * (1 - lambd * eta) - eta * grad
        # => param - param * lambd * eta - eta * grad
        # => param - eta * intermediate
        torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
        del intermediate

        # update grouped_axs
        # averaging: ax = ax + mu * (param - ax)
        # Note (mlazos): We can't use lerp here since it requires weight to be float64
        # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
        # we use the mus in other places)
        # all dtypes need to match, so we could introduce a cast in a loop
        # but since this only adds one additional kernel launch, this looks like the cleaner
        # and faster solution
        intermediate = torch._foreach_sub(grouped_params, grouped_axs)
        torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
        del intermediate
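
        # The capturable branch below rebuilds eta and mu entirely on-device,
        # one foreach kernel per algebraic step, so the update stays legal
        # inside a CUDA graph capture; the else branch instead computes them
        # as cheap host-side scalars and copies them in.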

        new_etas: Union[Tuple[Tensor, ...], List[Tensor]]
        new_mus: Union[Tuple[Tensor, ...], List[Tensor]]
        if capturable:
            # update grouped_mus
            new_mus = torch._foreach_sub(grouped_state_steps, t0)
            torch._foreach_maximum_(new_mus, 1.0)
            torch._foreach_reciprocal_(new_mus)
            torch._foreach_copy_(grouped_mus, new_mus)
            del new_mus

            # update eta = lr / ((1 + lambd * lr * step)^alpha)
            new_etas = torch._foreach_mul(grouped_state_steps, lambd)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_add_(new_etas, 1)
            torch._foreach_pow_(new_etas, alpha)
            torch._foreach_reciprocal_(new_etas)
            torch._foreach_mul_(new_etas, lr)
            torch._foreach_copy_(grouped_etas, new_etas)
        else:
            new_etas = [
                torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
                for step in grouped_state_steps
            ]
            new_mus = [
                torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
                for step in grouped_state_steps
            ]
            torch._foreach_copy_(grouped_etas, new_etas)
            torch._foreach_copy_(grouped_mus, new_mus)
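

# Dispatch note: when `foreach` is None, `asgd` below asks
# _default_to_fused_or_foreach whether the multi-tensor path is available for
# these params, and always falls back to the single-tensor loop when
# torch.jit scripting is active.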


@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
def asgd(
    params: List[Tensor],
    grads: List[Tensor],
    axs: List[Tensor],
    mus: List[Tensor],
    etas: List[Tensor],
    state_steps: List[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lambd: float,
    lr: float,
    t0: float,
    alpha: float,
    weight_decay: float,
):
    r"""Functional API that performs ASGD algorithm computation.

    See :class:`~torch.optim.ASGD` for details.
    """
    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_asgd
    else:
        func = _single_tensor_asgd

    func(
        params,
        grads,
        axs,
        mus,
        etas,
        state_steps,
        lambd=lambd,
        lr=lr,
        t0=t0,
        alpha=alpha,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        capturable=capturable,
        has_complex=has_complex,
    )
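

# Illustrative direct call of the functional API (a sketch; these tensors are
# placeholders, and real state normally comes from ASGD's per-parameter state):
#
#     params = [torch.randn(2)]
#     grads = [torch.randn(2)]
#     axs = [torch.zeros(2)]
#     mus = [torch.ones(())]
#     etas = [torch.full((), 1e-2)]
#     state_steps = [torch.zeros(())]
#     asgd(params, grads, axs, mus, etas, state_steps,
#          lambd=1e-4, lr=1e-2, t0=1e6, alpha=0.75, weight_decay=0.0)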