1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-decorators 2*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs 3*da0073e9SAndroid Build Coastguard Workerfrom typing import cast, List, Optional, Tuple, Union 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerimport torch 6*da0073e9SAndroid Build Coastguard Workerfrom torch import Tensor 7*da0073e9SAndroid Build Coastguard Worker 8*da0073e9SAndroid Build Coastguard Workerfrom .optimizer import ( 9*da0073e9SAndroid Build Coastguard Worker _capturable_doc, 10*da0073e9SAndroid Build Coastguard Worker _default_to_fused_or_foreach, 11*da0073e9SAndroid Build Coastguard Worker _device_dtype_check_for_fused, 12*da0073e9SAndroid Build Coastguard Worker _differentiable_doc, 13*da0073e9SAndroid Build Coastguard Worker _disable_dynamo_if_unsupported, 14*da0073e9SAndroid Build Coastguard Worker _foreach_doc, 15*da0073e9SAndroid Build Coastguard Worker _fused_doc, 16*da0073e9SAndroid Build Coastguard Worker _get_capturable_supported_devices, 17*da0073e9SAndroid Build Coastguard Worker _get_scalar_dtype, 18*da0073e9SAndroid Build Coastguard Worker _get_value, 19*da0073e9SAndroid Build Coastguard Worker _maximize_doc, 20*da0073e9SAndroid Build Coastguard Worker _stack_if_compiling, 21*da0073e9SAndroid Build Coastguard Worker _use_grad_for_differentiable, 22*da0073e9SAndroid Build Coastguard Worker _view_as_real, 23*da0073e9SAndroid Build Coastguard Worker DeviceDict, 24*da0073e9SAndroid Build Coastguard Worker Optimizer, 25*da0073e9SAndroid Build Coastguard Worker ParamsT, 26*da0073e9SAndroid Build Coastguard Worker) 27*da0073e9SAndroid Build Coastguard Worker 28*da0073e9SAndroid Build Coastguard Worker 29*da0073e9SAndroid Build Coastguard Worker__all__ = ["AdamW", "adamw"] 30*da0073e9SAndroid Build Coastguard Worker 31*da0073e9SAndroid Build Coastguard Worker 32*da0073e9SAndroid Build Coastguard Workerclass AdamW(Optimizer): 33*da0073e9SAndroid Build Coastguard Worker def __init__( 34*da0073e9SAndroid Build Coastguard Worker self, 35*da0073e9SAndroid Build Coastguard Worker params: ParamsT, 36*da0073e9SAndroid Build Coastguard Worker lr: Union[float, Tensor] = 1e-3, 37*da0073e9SAndroid Build Coastguard Worker betas: Tuple[float, float] = (0.9, 0.999), 38*da0073e9SAndroid Build Coastguard Worker eps: float = 1e-8, 39*da0073e9SAndroid Build Coastguard Worker weight_decay: float = 1e-2, 40*da0073e9SAndroid Build Coastguard Worker amsgrad: bool = False, 41*da0073e9SAndroid Build Coastguard Worker *, 42*da0073e9SAndroid Build Coastguard Worker maximize: bool = False, 43*da0073e9SAndroid Build Coastguard Worker foreach: Optional[bool] = None, 44*da0073e9SAndroid Build Coastguard Worker capturable: bool = False, 45*da0073e9SAndroid Build Coastguard Worker differentiable: bool = False, 46*da0073e9SAndroid Build Coastguard Worker fused: Optional[bool] = None, 47*da0073e9SAndroid Build Coastguard Worker ): 48*da0073e9SAndroid Build Coastguard Worker if isinstance(lr, Tensor): 49*da0073e9SAndroid Build Coastguard Worker if foreach and not capturable: 50*da0073e9SAndroid Build Coastguard Worker raise ValueError( 51*da0073e9SAndroid Build Coastguard Worker "lr as a Tensor is not supported for capturable=False and foreach=True" 52*da0073e9SAndroid Build Coastguard Worker ) 53*da0073e9SAndroid Build Coastguard Worker if lr.numel() != 1: 54*da0073e9SAndroid Build Coastguard Worker raise ValueError("Tensor lr must be 1-element") 55*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= lr: 56*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid learning rate: {lr}") 57*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= eps: 58*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid epsilon value: {eps}") 59*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= betas[0] < 1.0: 60*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") 61*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= betas[1] < 1.0: 62*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") 63*da0073e9SAndroid Build Coastguard Worker if not 0.0 <= weight_decay: 64*da0073e9SAndroid Build Coastguard Worker raise ValueError(f"Invalid weight_decay value: {weight_decay}") 65*da0073e9SAndroid Build Coastguard Worker defaults = dict( 66*da0073e9SAndroid Build Coastguard Worker lr=lr, 67*da0073e9SAndroid Build Coastguard Worker betas=betas, 68*da0073e9SAndroid Build Coastguard Worker eps=eps, 69*da0073e9SAndroid Build Coastguard Worker weight_decay=weight_decay, 70*da0073e9SAndroid Build Coastguard Worker amsgrad=amsgrad, 71*da0073e9SAndroid Build Coastguard Worker foreach=foreach, 72*da0073e9SAndroid Build Coastguard Worker maximize=maximize, 73*da0073e9SAndroid Build Coastguard Worker capturable=capturable, 74*da0073e9SAndroid Build Coastguard Worker differentiable=differentiable, 75*da0073e9SAndroid Build Coastguard Worker fused=fused, 76*da0073e9SAndroid Build Coastguard Worker ) 77*da0073e9SAndroid Build Coastguard Worker super().__init__(params, defaults) 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker if fused: 80*da0073e9SAndroid Build Coastguard Worker if differentiable: 81*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("`fused` does not support `differentiable`") 82*da0073e9SAndroid Build Coastguard Worker self._step_supports_amp_scaling = True 83*da0073e9SAndroid Build Coastguard Worker if foreach: 84*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("`fused` and `foreach` cannot be `True` together.") 85*da0073e9SAndroid Build Coastguard Worker 86*da0073e9SAndroid Build Coastguard Worker def __setstate__(self, state): 87*da0073e9SAndroid Build Coastguard Worker super().__setstate__(state) 88*da0073e9SAndroid Build Coastguard Worker for group in self.param_groups: 89*da0073e9SAndroid Build Coastguard Worker group.setdefault("amsgrad", False) 90*da0073e9SAndroid Build Coastguard Worker group.setdefault("maximize", False) 91*da0073e9SAndroid Build Coastguard Worker group.setdefault("foreach", None) 92*da0073e9SAndroid Build Coastguard Worker group.setdefault("capturable", False) 93*da0073e9SAndroid Build Coastguard Worker group.setdefault("differentiable", False) 94*da0073e9SAndroid Build Coastguard Worker fused = group.setdefault("fused", None) 95*da0073e9SAndroid Build Coastguard Worker for p in group["params"]: 96*da0073e9SAndroid Build Coastguard Worker p_state = self.state.get(p, []) 97*da0073e9SAndroid Build Coastguard Worker if len(p_state) != 0 and not torch.is_tensor(p_state["step"]): 98*da0073e9SAndroid Build Coastguard Worker step_val = float(p_state["step"]) 99*da0073e9SAndroid Build Coastguard Worker p_state["step"] = ( 100*da0073e9SAndroid Build Coastguard Worker torch.tensor( 101*da0073e9SAndroid Build Coastguard Worker step_val, 102*da0073e9SAndroid Build Coastguard Worker dtype=_get_scalar_dtype(is_fused=fused), 103*da0073e9SAndroid Build Coastguard Worker device=p.device, 104*da0073e9SAndroid Build Coastguard Worker ) 105*da0073e9SAndroid Build Coastguard Worker if group["capturable"] or group["fused"] 106*da0073e9SAndroid Build Coastguard Worker else torch.tensor(step_val, dtype=_get_scalar_dtype()) 107*da0073e9SAndroid Build Coastguard Worker ) 108*da0073e9SAndroid Build Coastguard Worker 109*da0073e9SAndroid Build Coastguard Worker def _init_group( 110*da0073e9SAndroid Build Coastguard Worker self, 111*da0073e9SAndroid Build Coastguard Worker group, 112*da0073e9SAndroid Build Coastguard Worker params_with_grad, 113*da0073e9SAndroid Build Coastguard Worker grads, 114*da0073e9SAndroid Build Coastguard Worker amsgrad, 115*da0073e9SAndroid Build Coastguard Worker exp_avgs, 116*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 117*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs, 118*da0073e9SAndroid Build Coastguard Worker state_steps, 119*da0073e9SAndroid Build Coastguard Worker ): 120*da0073e9SAndroid Build Coastguard Worker has_complex = False 121*da0073e9SAndroid Build Coastguard Worker for p in group["params"]: 122*da0073e9SAndroid Build Coastguard Worker if p.grad is None: 123*da0073e9SAndroid Build Coastguard Worker continue 124*da0073e9SAndroid Build Coastguard Worker has_complex |= torch.is_complex(p) 125*da0073e9SAndroid Build Coastguard Worker params_with_grad.append(p) 126*da0073e9SAndroid Build Coastguard Worker if p.grad.is_sparse: 127*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("AdamW does not support sparse gradients") 128*da0073e9SAndroid Build Coastguard Worker grads.append(p.grad) 129*da0073e9SAndroid Build Coastguard Worker 130*da0073e9SAndroid Build Coastguard Worker state = self.state[p] 131*da0073e9SAndroid Build Coastguard Worker 132*da0073e9SAndroid Build Coastguard Worker # State initialization 133*da0073e9SAndroid Build Coastguard Worker if len(state) == 0: 134*da0073e9SAndroid Build Coastguard Worker if group["fused"]: 135*da0073e9SAndroid Build Coastguard Worker _device_dtype_check_for_fused(p) 136*da0073e9SAndroid Build Coastguard Worker # note(crcrpar): Deliberately host `step` on CPU if both capturable and fused are off. 137*da0073e9SAndroid Build Coastguard Worker # This is because kernel launches are costly on CUDA and XLA. 138*da0073e9SAndroid Build Coastguard Worker state["step"] = ( 139*da0073e9SAndroid Build Coastguard Worker torch.zeros( 140*da0073e9SAndroid Build Coastguard Worker (), 141*da0073e9SAndroid Build Coastguard Worker dtype=_get_scalar_dtype(is_fused=group["fused"]), 142*da0073e9SAndroid Build Coastguard Worker device=p.device, 143*da0073e9SAndroid Build Coastguard Worker ) 144*da0073e9SAndroid Build Coastguard Worker if group["capturable"] or group["fused"] 145*da0073e9SAndroid Build Coastguard Worker else torch.tensor(0.0, dtype=_get_scalar_dtype()) 146*da0073e9SAndroid Build Coastguard Worker ) 147*da0073e9SAndroid Build Coastguard Worker # Exponential moving average of gradient values 148*da0073e9SAndroid Build Coastguard Worker state["exp_avg"] = torch.zeros_like( 149*da0073e9SAndroid Build Coastguard Worker p, memory_format=torch.preserve_format 150*da0073e9SAndroid Build Coastguard Worker ) 151*da0073e9SAndroid Build Coastguard Worker # Exponential moving average of squared gradient values 152*da0073e9SAndroid Build Coastguard Worker state["exp_avg_sq"] = torch.zeros_like( 153*da0073e9SAndroid Build Coastguard Worker p, memory_format=torch.preserve_format 154*da0073e9SAndroid Build Coastguard Worker ) 155*da0073e9SAndroid Build Coastguard Worker if amsgrad: 156*da0073e9SAndroid Build Coastguard Worker # Maintains max of all exp. moving avg. of sq. grad. values 157*da0073e9SAndroid Build Coastguard Worker state["max_exp_avg_sq"] = torch.zeros_like( 158*da0073e9SAndroid Build Coastguard Worker p, memory_format=torch.preserve_format 159*da0073e9SAndroid Build Coastguard Worker ) 160*da0073e9SAndroid Build Coastguard Worker 161*da0073e9SAndroid Build Coastguard Worker exp_avgs.append(state["exp_avg"]) 162*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs.append(state["exp_avg_sq"]) 163*da0073e9SAndroid Build Coastguard Worker 164*da0073e9SAndroid Build Coastguard Worker if group["amsgrad"]: 165*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs.append(state["max_exp_avg_sq"]) 166*da0073e9SAndroid Build Coastguard Worker if group["differentiable"] and state["step"].requires_grad: 167*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 168*da0073e9SAndroid Build Coastguard Worker "`requires_grad` is not supported for `step` in differentiable mode" 169*da0073e9SAndroid Build Coastguard Worker ) 170*da0073e9SAndroid Build Coastguard Worker 171*da0073e9SAndroid Build Coastguard Worker # Foreach without capturable does not support a tensor lr 172*da0073e9SAndroid Build Coastguard Worker if ( 173*da0073e9SAndroid Build Coastguard Worker group["foreach"] 174*da0073e9SAndroid Build Coastguard Worker and isinstance(group["lr"], Tensor) 175*da0073e9SAndroid Build Coastguard Worker and not group["capturable"] 176*da0073e9SAndroid Build Coastguard Worker ): 177*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 178*da0073e9SAndroid Build Coastguard Worker "lr as a Tensor is not supported for capturable=False and foreach=True" 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Worker state_steps.append(state["step"]) 182*da0073e9SAndroid Build Coastguard Worker return has_complex 183*da0073e9SAndroid Build Coastguard Worker 184*da0073e9SAndroid Build Coastguard Worker @_use_grad_for_differentiable 185*da0073e9SAndroid Build Coastguard Worker def step(self, closure=None): 186*da0073e9SAndroid Build Coastguard Worker """Perform a single optimization step. 187*da0073e9SAndroid Build Coastguard Worker 188*da0073e9SAndroid Build Coastguard Worker Args: 189*da0073e9SAndroid Build Coastguard Worker closure (Callable, optional): A closure that reevaluates the model 190*da0073e9SAndroid Build Coastguard Worker and returns the loss. 191*da0073e9SAndroid Build Coastguard Worker """ 192*da0073e9SAndroid Build Coastguard Worker self._cuda_graph_capture_health_check() 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker loss = None 195*da0073e9SAndroid Build Coastguard Worker if closure is not None: 196*da0073e9SAndroid Build Coastguard Worker with torch.enable_grad(): 197*da0073e9SAndroid Build Coastguard Worker loss = closure() 198*da0073e9SAndroid Build Coastguard Worker 199*da0073e9SAndroid Build Coastguard Worker for group in self.param_groups: 200*da0073e9SAndroid Build Coastguard Worker params_with_grad: List[Tensor] = [] 201*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor] = [] 202*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor] = [] 203*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor] = [] 204*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs: List[Tensor] = [] 205*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor] = [] 206*da0073e9SAndroid Build Coastguard Worker amsgrad: bool = group["amsgrad"] 207*da0073e9SAndroid Build Coastguard Worker beta1, beta2 = cast(Tuple[float, float], group["betas"]) 208*da0073e9SAndroid Build Coastguard Worker 209*da0073e9SAndroid Build Coastguard Worker has_complex = self._init_group( 210*da0073e9SAndroid Build Coastguard Worker group, 211*da0073e9SAndroid Build Coastguard Worker params_with_grad, 212*da0073e9SAndroid Build Coastguard Worker grads, 213*da0073e9SAndroid Build Coastguard Worker amsgrad, 214*da0073e9SAndroid Build Coastguard Worker exp_avgs, 215*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 216*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs, 217*da0073e9SAndroid Build Coastguard Worker state_steps, 218*da0073e9SAndroid Build Coastguard Worker ) 219*da0073e9SAndroid Build Coastguard Worker 220*da0073e9SAndroid Build Coastguard Worker adamw( 221*da0073e9SAndroid Build Coastguard Worker params_with_grad, 222*da0073e9SAndroid Build Coastguard Worker grads, 223*da0073e9SAndroid Build Coastguard Worker exp_avgs, 224*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 225*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs, 226*da0073e9SAndroid Build Coastguard Worker state_steps, 227*da0073e9SAndroid Build Coastguard Worker amsgrad=amsgrad, 228*da0073e9SAndroid Build Coastguard Worker beta1=beta1, 229*da0073e9SAndroid Build Coastguard Worker beta2=beta2, 230*da0073e9SAndroid Build Coastguard Worker lr=group["lr"], 231*da0073e9SAndroid Build Coastguard Worker weight_decay=group["weight_decay"], 232*da0073e9SAndroid Build Coastguard Worker eps=group["eps"], 233*da0073e9SAndroid Build Coastguard Worker maximize=group["maximize"], 234*da0073e9SAndroid Build Coastguard Worker foreach=group["foreach"], 235*da0073e9SAndroid Build Coastguard Worker capturable=group["capturable"], 236*da0073e9SAndroid Build Coastguard Worker differentiable=group["differentiable"], 237*da0073e9SAndroid Build Coastguard Worker fused=group["fused"], 238*da0073e9SAndroid Build Coastguard Worker grad_scale=getattr(self, "grad_scale", None), 239*da0073e9SAndroid Build Coastguard Worker found_inf=getattr(self, "found_inf", None), 240*da0073e9SAndroid Build Coastguard Worker has_complex=has_complex, 241*da0073e9SAndroid Build Coastguard Worker ) 242*da0073e9SAndroid Build Coastguard Worker 243*da0073e9SAndroid Build Coastguard Worker return loss 244*da0073e9SAndroid Build Coastguard Worker 245*da0073e9SAndroid Build Coastguard Worker 246*da0073e9SAndroid Build Coastguard WorkerAdamW.__doc__ = ( 247*da0073e9SAndroid Build Coastguard Worker r"""Implements AdamW algorithm. 248*da0073e9SAndroid Build Coastguard Worker 249*da0073e9SAndroid Build Coastguard Worker .. math:: 250*da0073e9SAndroid Build Coastguard Worker \begin{aligned} 251*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\ 252*da0073e9SAndroid Build Coastguard Worker &\textbf{input} : \gamma \text{(lr)}, \: \beta_1, \beta_2 253*da0073e9SAndroid Build Coastguard Worker \text{(betas)}, \: \theta_0 \text{(params)}, \: f(\theta) \text{(objective)}, 254*da0073e9SAndroid Build Coastguard Worker \: \epsilon \text{ (epsilon)} \\ 255*da0073e9SAndroid Build Coastguard Worker &\hspace{13mm} \lambda \text{(weight decay)}, \: \textit{amsgrad}, 256*da0073e9SAndroid Build Coastguard Worker \: \textit{maximize} \\ 257*da0073e9SAndroid Build Coastguard Worker &\textbf{initialize} : m_0 \leftarrow 0 \text{ (first moment)}, v_0 \leftarrow 0 258*da0073e9SAndroid Build Coastguard Worker \text{ ( second moment)}, \: \widehat{v_0}^{max}\leftarrow 0 \\[-1.ex] 259*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\ 260*da0073e9SAndroid Build Coastguard Worker &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ 261*da0073e9SAndroid Build Coastguard Worker 262*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ 263*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ 264*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{else} \\ 265*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ 266*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma \lambda \theta_{t-1} \\ 267*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ 268*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ 269*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ 270*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ 271*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{if} \: amsgrad \\ 272*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, 273*da0073e9SAndroid Build Coastguard Worker \widehat{v_t}) \\ 274*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ 275*da0073e9SAndroid Build Coastguard Worker \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ 276*da0073e9SAndroid Build Coastguard Worker &\hspace{5mm}\textbf{else} \\ 277*da0073e9SAndroid Build Coastguard Worker &\hspace{10mm}\theta_t \leftarrow \theta_t - \gamma \widehat{m_t}/ 278*da0073e9SAndroid Build Coastguard Worker \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ 279*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\[-1.ex] 280*da0073e9SAndroid Build Coastguard Worker &\bf{return} \: \theta_t \\[-1.ex] 281*da0073e9SAndroid Build Coastguard Worker &\rule{110mm}{0.4pt} \\[-1.ex] 282*da0073e9SAndroid Build Coastguard Worker \end{aligned} 283*da0073e9SAndroid Build Coastguard Worker 284*da0073e9SAndroid Build Coastguard Worker For further details regarding the algorithm we refer to `Decoupled Weight Decay Regularization`_. 285*da0073e9SAndroid Build Coastguard Worker """ 286*da0073e9SAndroid Build Coastguard Worker + rf""" 287*da0073e9SAndroid Build Coastguard Worker Args: 288*da0073e9SAndroid Build Coastguard Worker params (iterable): iterable of parameters to optimize or dicts defining 289*da0073e9SAndroid Build Coastguard Worker parameter groups 290*da0073e9SAndroid Build Coastguard Worker lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR 291*da0073e9SAndroid Build Coastguard Worker is not yet supported for all our implementations. Please use a float 292*da0073e9SAndroid Build Coastguard Worker LR if you are not also specifying fused=True or capturable=True. 293*da0073e9SAndroid Build Coastguard Worker betas (Tuple[float, float], optional): coefficients used for computing 294*da0073e9SAndroid Build Coastguard Worker running averages of gradient and its square (default: (0.9, 0.999)) 295*da0073e9SAndroid Build Coastguard Worker eps (float, optional): term added to the denominator to improve 296*da0073e9SAndroid Build Coastguard Worker numerical stability (default: 1e-8) 297*da0073e9SAndroid Build Coastguard Worker weight_decay (float, optional): weight decay coefficient (default: 1e-2) 298*da0073e9SAndroid Build Coastguard Worker amsgrad (bool, optional): whether to use the AMSGrad variant of this 299*da0073e9SAndroid Build Coastguard Worker algorithm from the paper `On the Convergence of Adam and Beyond`_ 300*da0073e9SAndroid Build Coastguard Worker (default: False) 301*da0073e9SAndroid Build Coastguard Worker {_maximize_doc} 302*da0073e9SAndroid Build Coastguard Worker {_foreach_doc} 303*da0073e9SAndroid Build Coastguard Worker {_capturable_doc} 304*da0073e9SAndroid Build Coastguard Worker {_differentiable_doc} 305*da0073e9SAndroid Build Coastguard Worker {_fused_doc} 306*da0073e9SAndroid Build Coastguard Worker .. Note:: 307*da0073e9SAndroid Build Coastguard Worker A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`. 308*da0073e9SAndroid Build Coastguard Worker .. _Decoupled Weight Decay Regularization: 309*da0073e9SAndroid Build Coastguard Worker https://arxiv.org/abs/1711.05101 310*da0073e9SAndroid Build Coastguard Worker .. _On the Convergence of Adam and Beyond: 311*da0073e9SAndroid Build Coastguard Worker https://openreview.net/forum?id=ryQu7f-RZ 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker """ 314*da0073e9SAndroid Build Coastguard Worker) 315*da0073e9SAndroid Build Coastguard Worker 316*da0073e9SAndroid Build Coastguard Worker 317*da0073e9SAndroid Build Coastguard Workerdef _single_tensor_adamw( 318*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 319*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 320*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 321*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 322*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs: List[Tensor], 323*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 324*da0073e9SAndroid Build Coastguard Worker grad_scale: Optional[Tensor], 325*da0073e9SAndroid Build Coastguard Worker found_inf: Optional[Tensor], 326*da0073e9SAndroid Build Coastguard Worker *, 327*da0073e9SAndroid Build Coastguard Worker amsgrad: bool, 328*da0073e9SAndroid Build Coastguard Worker beta1: float, 329*da0073e9SAndroid Build Coastguard Worker beta2: float, 330*da0073e9SAndroid Build Coastguard Worker lr: Union[Tensor, float], 331*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 332*da0073e9SAndroid Build Coastguard Worker eps: float, 333*da0073e9SAndroid Build Coastguard Worker maximize: bool, 334*da0073e9SAndroid Build Coastguard Worker capturable: bool, 335*da0073e9SAndroid Build Coastguard Worker differentiable: bool, 336*da0073e9SAndroid Build Coastguard Worker has_complex: bool, 337*da0073e9SAndroid Build Coastguard Worker): 338*da0073e9SAndroid Build Coastguard Worker assert grad_scale is None and found_inf is None 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker if torch.jit.is_scripting(): 341*da0073e9SAndroid Build Coastguard Worker # this assert is due to JIT being dumb and not realizing that the ops below 342*da0073e9SAndroid Build Coastguard Worker # have overloads to handle both float and Tensor lrs, so we just assert it's 343*da0073e9SAndroid Build Coastguard Worker # a float since most people using JIT are using floats 344*da0073e9SAndroid Build Coastguard Worker assert isinstance(lr, float) 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker for i, param in enumerate(params): 347*da0073e9SAndroid Build Coastguard Worker grad = grads[i] if not maximize else -grads[i] 348*da0073e9SAndroid Build Coastguard Worker exp_avg = exp_avgs[i] 349*da0073e9SAndroid Build Coastguard Worker exp_avg_sq = exp_avg_sqs[i] 350*da0073e9SAndroid Build Coastguard Worker step_t = state_steps[i] 351*da0073e9SAndroid Build Coastguard Worker 352*da0073e9SAndroid Build Coastguard Worker # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 353*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and capturable: 354*da0073e9SAndroid Build Coastguard Worker capturable_supported_devices = _get_capturable_supported_devices() 355*da0073e9SAndroid Build Coastguard Worker assert ( 356*da0073e9SAndroid Build Coastguard Worker param.device.type == step_t.device.type 357*da0073e9SAndroid Build Coastguard Worker and param.device.type in capturable_supported_devices 358*da0073e9SAndroid Build Coastguard Worker ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 359*da0073e9SAndroid Build Coastguard Worker 360*da0073e9SAndroid Build Coastguard Worker if torch.is_complex(param): 361*da0073e9SAndroid Build Coastguard Worker grad = torch.view_as_real(grad) 362*da0073e9SAndroid Build Coastguard Worker exp_avg = torch.view_as_real(exp_avg) 363*da0073e9SAndroid Build Coastguard Worker exp_avg_sq = torch.view_as_real(exp_avg_sq) 364*da0073e9SAndroid Build Coastguard Worker if amsgrad: 365*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i]) 366*da0073e9SAndroid Build Coastguard Worker param = torch.view_as_real(param) 367*da0073e9SAndroid Build Coastguard Worker 368*da0073e9SAndroid Build Coastguard Worker # update step 369*da0073e9SAndroid Build Coastguard Worker step_t += 1 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker # Perform stepweight decay 372*da0073e9SAndroid Build Coastguard Worker param.mul_(1 - lr * weight_decay) 373*da0073e9SAndroid Build Coastguard Worker 374*da0073e9SAndroid Build Coastguard Worker # Decay the first and second moment running average coefficient 375*da0073e9SAndroid Build Coastguard Worker exp_avg.lerp_(grad, 1 - beta1) 376*da0073e9SAndroid Build Coastguard Worker exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 377*da0073e9SAndroid Build Coastguard Worker 378*da0073e9SAndroid Build Coastguard Worker if capturable or differentiable: 379*da0073e9SAndroid Build Coastguard Worker step = step_t 380*da0073e9SAndroid Build Coastguard Worker 381*da0073e9SAndroid Build Coastguard Worker bias_correction1 = 1 - beta1**step 382*da0073e9SAndroid Build Coastguard Worker bias_correction2 = 1 - beta2**step 383*da0073e9SAndroid Build Coastguard Worker 384*da0073e9SAndroid Build Coastguard Worker step_size = lr / bias_correction1 385*da0073e9SAndroid Build Coastguard Worker step_size_neg = step_size.neg() 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker bias_correction2_sqrt = bias_correction2.sqrt() 388*da0073e9SAndroid Build Coastguard Worker 389*da0073e9SAndroid Build Coastguard Worker if amsgrad: 390*da0073e9SAndroid Build Coastguard Worker # Maintains the maximum of all 2nd moment running avg. till now 391*da0073e9SAndroid Build Coastguard Worker if differentiable: 392*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sq = max_exp_avg_sqs[i].clone() 393*da0073e9SAndroid Build Coastguard Worker else: 394*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sq = max_exp_avg_sqs[i] 395*da0073e9SAndroid Build Coastguard Worker 396*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq)) 397*da0073e9SAndroid Build Coastguard Worker 398*da0073e9SAndroid Build Coastguard Worker # Uses the max. for normalizing running avg. of gradient 399*da0073e9SAndroid Build Coastguard Worker # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write 400*da0073e9SAndroid Build Coastguard Worker # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor) 401*da0073e9SAndroid Build Coastguard Worker denom = ( 402*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg) 403*da0073e9SAndroid Build Coastguard Worker ).add_(eps / step_size_neg) 404*da0073e9SAndroid Build Coastguard Worker else: 405*da0073e9SAndroid Build Coastguard Worker denom = ( 406*da0073e9SAndroid Build Coastguard Worker exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg) 407*da0073e9SAndroid Build Coastguard Worker ).add_(eps / step_size_neg) 408*da0073e9SAndroid Build Coastguard Worker 409*da0073e9SAndroid Build Coastguard Worker param.addcdiv_(exp_avg, denom) 410*da0073e9SAndroid Build Coastguard Worker else: 411*da0073e9SAndroid Build Coastguard Worker step = _get_value(step_t) 412*da0073e9SAndroid Build Coastguard Worker 413*da0073e9SAndroid Build Coastguard Worker bias_correction1 = 1 - beta1**step 414*da0073e9SAndroid Build Coastguard Worker bias_correction2 = 1 - beta2**step 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker step_size = lr / bias_correction1 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker bias_correction2_sqrt = bias_correction2**0.5 419*da0073e9SAndroid Build Coastguard Worker 420*da0073e9SAndroid Build Coastguard Worker if amsgrad: 421*da0073e9SAndroid Build Coastguard Worker # Maintains the maximum of all 2nd moment running avg. till now 422*da0073e9SAndroid Build Coastguard Worker torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i]) 423*da0073e9SAndroid Build Coastguard Worker 424*da0073e9SAndroid Build Coastguard Worker # Use the max. for normalizing running avg. of gradient 425*da0073e9SAndroid Build Coastguard Worker denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps) 426*da0073e9SAndroid Build Coastguard Worker else: 427*da0073e9SAndroid Build Coastguard Worker denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps) 428*da0073e9SAndroid Build Coastguard Worker 429*da0073e9SAndroid Build Coastguard Worker param.addcdiv_(exp_avg, denom, value=-step_size) 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker # Lastly, switch back to complex view 432*da0073e9SAndroid Build Coastguard Worker if amsgrad and torch.is_complex(params[i]): 433*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i]) 434*da0073e9SAndroid Build Coastguard Worker 435*da0073e9SAndroid Build Coastguard Worker 436*da0073e9SAndroid Build Coastguard Workerdef _multi_tensor_adamw( 437*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 438*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 439*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 440*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 441*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs: List[Tensor], 442*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 443*da0073e9SAndroid Build Coastguard Worker grad_scale: Optional[Tensor], 444*da0073e9SAndroid Build Coastguard Worker found_inf: Optional[Tensor], 445*da0073e9SAndroid Build Coastguard Worker *, 446*da0073e9SAndroid Build Coastguard Worker amsgrad: bool, 447*da0073e9SAndroid Build Coastguard Worker beta1: float, 448*da0073e9SAndroid Build Coastguard Worker beta2: float, 449*da0073e9SAndroid Build Coastguard Worker lr: Union[Tensor, float], 450*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 451*da0073e9SAndroid Build Coastguard Worker eps: float, 452*da0073e9SAndroid Build Coastguard Worker maximize: bool, 453*da0073e9SAndroid Build Coastguard Worker capturable: bool, 454*da0073e9SAndroid Build Coastguard Worker differentiable: bool, 455*da0073e9SAndroid Build Coastguard Worker has_complex: bool, 456*da0073e9SAndroid Build Coastguard Worker): 457*da0073e9SAndroid Build Coastguard Worker if len(params) == 0: 458*da0073e9SAndroid Build Coastguard Worker return 459*da0073e9SAndroid Build Coastguard Worker 460*da0073e9SAndroid Build Coastguard Worker if isinstance(lr, Tensor) and not capturable: 461*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 462*da0073e9SAndroid Build Coastguard Worker "lr as a Tensor is not supported for capturable=False and foreach=True" 463*da0073e9SAndroid Build Coastguard Worker ) 464*da0073e9SAndroid Build Coastguard Worker 465*da0073e9SAndroid Build Coastguard Worker # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable] 466*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and capturable: 467*da0073e9SAndroid Build Coastguard Worker capturable_supported_devices = _get_capturable_supported_devices( 468*da0073e9SAndroid Build Coastguard Worker supports_xla=False 469*da0073e9SAndroid Build Coastguard Worker ) 470*da0073e9SAndroid Build Coastguard Worker assert all( 471*da0073e9SAndroid Build Coastguard Worker p.device.type == step.device.type 472*da0073e9SAndroid Build Coastguard Worker and p.device.type in capturable_supported_devices 473*da0073e9SAndroid Build Coastguard Worker for p, step in zip(params, state_steps) 474*da0073e9SAndroid Build Coastguard Worker ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}." 475*da0073e9SAndroid Build Coastguard Worker 476*da0073e9SAndroid Build Coastguard Worker assert not differentiable, "_foreach ops don't support autograd" 477*da0073e9SAndroid Build Coastguard Worker 478*da0073e9SAndroid Build Coastguard Worker assert grad_scale is None and found_inf is None 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 481*da0073e9SAndroid Build Coastguard Worker [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] 482*da0073e9SAndroid Build Coastguard Worker ) 483*da0073e9SAndroid Build Coastguard Worker for ( 484*da0073e9SAndroid Build Coastguard Worker device_params_, 485*da0073e9SAndroid Build Coastguard Worker device_grads_, 486*da0073e9SAndroid Build Coastguard Worker device_exp_avgs_, 487*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs_, 488*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs_, 489*da0073e9SAndroid Build Coastguard Worker device_state_steps_, 490*da0073e9SAndroid Build Coastguard Worker ), _ in grouped_tensors.values(): 491*da0073e9SAndroid Build Coastguard Worker device_params = cast(List[Tensor], device_params_) 492*da0073e9SAndroid Build Coastguard Worker device_grads = cast(List[Tensor], device_grads_) 493*da0073e9SAndroid Build Coastguard Worker device_exp_avgs = cast(List[Tensor], device_exp_avgs_) 494*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) 495*da0073e9SAndroid Build Coastguard Worker device_state_steps = cast(List[Tensor], device_state_steps_) 496*da0073e9SAndroid Build Coastguard Worker 497*da0073e9SAndroid Build Coastguard Worker if has_complex: 498*da0073e9SAndroid Build Coastguard Worker if amsgrad: 499*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 500*da0073e9SAndroid Build Coastguard Worker _view_as_real( 501*da0073e9SAndroid Build Coastguard Worker device_params, 502*da0073e9SAndroid Build Coastguard Worker device_grads, 503*da0073e9SAndroid Build Coastguard Worker device_exp_avgs, 504*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs, 505*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs, 506*da0073e9SAndroid Build Coastguard Worker ) 507*da0073e9SAndroid Build Coastguard Worker else: 508*da0073e9SAndroid Build Coastguard Worker _view_as_real( 509*da0073e9SAndroid Build Coastguard Worker device_params, device_grads, device_exp_avgs, device_exp_avg_sqs 510*da0073e9SAndroid Build Coastguard Worker ) 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker if maximize: 513*da0073e9SAndroid Build Coastguard Worker device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment] 514*da0073e9SAndroid Build Coastguard Worker 515*da0073e9SAndroid Build Coastguard Worker # Update steps 516*da0073e9SAndroid Build Coastguard Worker # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over 517*da0073e9SAndroid Build Coastguard Worker # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just 518*da0073e9SAndroid Build Coastguard Worker # wrapped it once now. The alpha is required to assure we go to the right overload. 519*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and device_state_steps[0].is_cpu: 520*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_( 521*da0073e9SAndroid Build Coastguard Worker device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0 522*da0073e9SAndroid Build Coastguard Worker ) 523*da0073e9SAndroid Build Coastguard Worker else: 524*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(device_state_steps, 1) 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker # Perform stepweight decay 527*da0073e9SAndroid Build Coastguard Worker if weight_decay != 0: 528*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(device_params, 1 - lr * weight_decay) 529*da0073e9SAndroid Build Coastguard Worker 530*da0073e9SAndroid Build Coastguard Worker # Decay the first and second moment running average coefficient 531*da0073e9SAndroid Build Coastguard Worker torch._foreach_lerp_(device_exp_avgs, device_grads, 1 - beta1) 532*da0073e9SAndroid Build Coastguard Worker 533*da0073e9SAndroid Build Coastguard Worker torch._foreach_mul_(device_exp_avg_sqs, beta2) 534*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcmul_( 535*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs, device_grads, device_grads, 1 - beta2 536*da0073e9SAndroid Build Coastguard Worker ) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker # Delete the local intermediate since it won't be used anymore to save on peak memory 539*da0073e9SAndroid Build Coastguard Worker del device_grads 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker bias_correction1: Union[Tuple[Tensor, ...], List[Tensor]] 542*da0073e9SAndroid Build Coastguard Worker bias_correction2: Union[Tuple[Tensor, ...], List[Tensor]] 543*da0073e9SAndroid Build Coastguard Worker bias_correction2_sqrt: Union[Tuple[Tensor, ...], List[Tensor]] 544*da0073e9SAndroid Build Coastguard Worker 545*da0073e9SAndroid Build Coastguard Worker if capturable: 546*da0073e9SAndroid Build Coastguard Worker bias_correction1 = torch._foreach_pow(beta1, device_state_steps) 547*da0073e9SAndroid Build Coastguard Worker bias_correction2 = torch._foreach_pow(beta2, device_state_steps) 548*da0073e9SAndroid Build Coastguard Worker # foreach_sub doesn't allow a scalar as the first arg 549*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_(bias_correction1, 1) 550*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_(bias_correction2, 1) 551*da0073e9SAndroid Build Coastguard Worker # we do not negate bias_correction1 as it'll need to be negated later anyway 552*da0073e9SAndroid Build Coastguard Worker torch._foreach_neg_(bias_correction2) 553*da0073e9SAndroid Build Coastguard Worker 554*da0073e9SAndroid Build Coastguard Worker # foreach_div doesn't allow a scalar as the first arg 555*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(bias_correction1, lr) 556*da0073e9SAndroid Build Coastguard Worker torch._foreach_reciprocal_(bias_correction1) 557*da0073e9SAndroid Build Coastguard Worker 558*da0073e9SAndroid Build Coastguard Worker torch._foreach_sqrt_(bias_correction2) 559*da0073e9SAndroid Build Coastguard Worker 560*da0073e9SAndroid Build Coastguard Worker # Re-assign for clarity as we maintain minimal intermediates: we'll have 561*da0073e9SAndroid Build Coastguard Worker # step_size = - lr / (1 - beta1 ^ t) where t = num_steps 562*da0073e9SAndroid Build Coastguard Worker # bias_correction2_sqrt = sqrt(1 - beta2 ^ t) 563*da0073e9SAndroid Build Coastguard Worker step_size = bias_correction1 564*da0073e9SAndroid Build Coastguard Worker bias_correction2_sqrt = bias_correction2 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker if amsgrad: 567*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 568*da0073e9SAndroid Build Coastguard Worker 569*da0073e9SAndroid Build Coastguard Worker # Maintains the maximum of all 2nd moment running avg. till now 570*da0073e9SAndroid Build Coastguard Worker torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) 571*da0073e9SAndroid Build Coastguard Worker 572*da0073e9SAndroid Build Coastguard Worker # Use the max. for normalizing running avg. of gradient 573*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) 574*da0073e9SAndroid Build Coastguard Worker else: 575*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) 576*da0073e9SAndroid Build Coastguard Worker 577*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) 578*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(exp_avg_sq_sqrt, eps) 579*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(exp_avg_sq_sqrt, step_size) 580*da0073e9SAndroid Build Coastguard Worker 581*da0073e9SAndroid Build Coastguard Worker # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr 582*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt) 583*da0073e9SAndroid Build Coastguard Worker else: 584*da0073e9SAndroid Build Coastguard Worker bias_correction1 = [ 585*da0073e9SAndroid Build Coastguard Worker 1 - beta1 ** _get_value(step) for step in device_state_steps 586*da0073e9SAndroid Build Coastguard Worker ] 587*da0073e9SAndroid Build Coastguard Worker bias_correction2 = [ 588*da0073e9SAndroid Build Coastguard Worker 1 - beta2 ** _get_value(step) for step in device_state_steps 589*da0073e9SAndroid Build Coastguard Worker ] 590*da0073e9SAndroid Build Coastguard Worker 591*da0073e9SAndroid Build Coastguard Worker step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1]) 592*da0073e9SAndroid Build Coastguard Worker 593*da0073e9SAndroid Build Coastguard Worker bias_correction2_sqrt = [ 594*da0073e9SAndroid Build Coastguard Worker bc**0.5 for bc in bias_correction2 # type: ignore[arg-type] 595*da0073e9SAndroid Build Coastguard Worker ] 596*da0073e9SAndroid Build Coastguard Worker 597*da0073e9SAndroid Build Coastguard Worker if amsgrad: 598*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs = cast(List[Tensor], device_max_exp_avg_sqs_) 599*da0073e9SAndroid Build Coastguard Worker 600*da0073e9SAndroid Build Coastguard Worker # Maintains the maximum of all 2nd moment running avg. till now 601*da0073e9SAndroid Build Coastguard Worker torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) 602*da0073e9SAndroid Build Coastguard Worker 603*da0073e9SAndroid Build Coastguard Worker # Use the max. for normalizing running avg. of gradient 604*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs) 605*da0073e9SAndroid Build Coastguard Worker else: 606*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs) 607*da0073e9SAndroid Build Coastguard Worker 608*da0073e9SAndroid Build Coastguard Worker torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt) 609*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(exp_avg_sq_sqrt, eps) 610*da0073e9SAndroid Build Coastguard Worker torch._foreach_addcdiv_( 611*da0073e9SAndroid Build Coastguard Worker device_params, 612*da0073e9SAndroid Build Coastguard Worker device_exp_avgs, 613*da0073e9SAndroid Build Coastguard Worker exp_avg_sq_sqrt, 614*da0073e9SAndroid Build Coastguard Worker step_size, # type: ignore[arg-type] 615*da0073e9SAndroid Build Coastguard Worker ) 616*da0073e9SAndroid Build Coastguard Worker 617*da0073e9SAndroid Build Coastguard Worker 618*da0073e9SAndroid Build Coastguard Workerdef _fused_adamw( 619*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 620*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 621*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 622*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 623*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs: List[Tensor], 624*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 625*da0073e9SAndroid Build Coastguard Worker grad_scale: Optional[Tensor], 626*da0073e9SAndroid Build Coastguard Worker found_inf: Optional[Tensor], 627*da0073e9SAndroid Build Coastguard Worker *, 628*da0073e9SAndroid Build Coastguard Worker amsgrad: bool, 629*da0073e9SAndroid Build Coastguard Worker beta1: float, 630*da0073e9SAndroid Build Coastguard Worker beta2: float, 631*da0073e9SAndroid Build Coastguard Worker lr: Union[Tensor, float], 632*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 633*da0073e9SAndroid Build Coastguard Worker eps: float, 634*da0073e9SAndroid Build Coastguard Worker maximize: bool, 635*da0073e9SAndroid Build Coastguard Worker capturable: bool, # Needed for consistency. 636*da0073e9SAndroid Build Coastguard Worker differentiable: bool, 637*da0073e9SAndroid Build Coastguard Worker has_complex: bool, # Needed for consistency. 638*da0073e9SAndroid Build Coastguard Worker) -> None: 639*da0073e9SAndroid Build Coastguard Worker if not params: 640*da0073e9SAndroid Build Coastguard Worker return 641*da0073e9SAndroid Build Coastguard Worker if differentiable: 642*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("Adam with fused=True does not support differentiable=True") 643*da0073e9SAndroid Build Coastguard Worker 644*da0073e9SAndroid Build Coastguard Worker grad_scale_dict: DeviceDict = ( 645*da0073e9SAndroid Build Coastguard Worker {grad_scale.device: grad_scale} if grad_scale is not None else {} 646*da0073e9SAndroid Build Coastguard Worker ) 647*da0073e9SAndroid Build Coastguard Worker found_inf_dict: DeviceDict = ( 648*da0073e9SAndroid Build Coastguard Worker {found_inf.device: found_inf} if found_inf is not None else {} 649*da0073e9SAndroid Build Coastguard Worker ) 650*da0073e9SAndroid Build Coastguard Worker 651*da0073e9SAndroid Build Coastguard Worker # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer 652*da0073e9SAndroid Build Coastguard Worker # treating it as a scalar. 653*da0073e9SAndroid Build Coastguard Worker lr_dict: Optional[DeviceDict] = ( 654*da0073e9SAndroid Build Coastguard Worker {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None 655*da0073e9SAndroid Build Coastguard Worker ) 656*da0073e9SAndroid Build Coastguard Worker 657*da0073e9SAndroid Build Coastguard Worker grouped_tensors = Optimizer._group_tensors_by_device_and_dtype( 658*da0073e9SAndroid Build Coastguard Worker [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item] 659*da0073e9SAndroid Build Coastguard Worker ) 660*da0073e9SAndroid Build Coastguard Worker for (device, _), ( 661*da0073e9SAndroid Build Coastguard Worker ( 662*da0073e9SAndroid Build Coastguard Worker device_params_, 663*da0073e9SAndroid Build Coastguard Worker device_grads_, 664*da0073e9SAndroid Build Coastguard Worker device_exp_avgs_, 665*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs_, 666*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs, 667*da0073e9SAndroid Build Coastguard Worker device_state_steps_, 668*da0073e9SAndroid Build Coastguard Worker ), 669*da0073e9SAndroid Build Coastguard Worker _, 670*da0073e9SAndroid Build Coastguard Worker ) in grouped_tensors.items(): 671*da0073e9SAndroid Build Coastguard Worker device_params = cast(List[Tensor], device_params_) 672*da0073e9SAndroid Build Coastguard Worker device_grads = cast(List[Tensor], device_grads_) 673*da0073e9SAndroid Build Coastguard Worker device_exp_avgs = cast(List[Tensor], device_exp_avgs_) 674*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs = cast(List[Tensor], device_exp_avg_sqs_) 675*da0073e9SAndroid Build Coastguard Worker device_state_steps = cast(List[Tensor], device_state_steps_) 676*da0073e9SAndroid Build Coastguard Worker 677*da0073e9SAndroid Build Coastguard Worker if device.type == "mps": # type: ignore[union-attr] 678*da0073e9SAndroid Build Coastguard Worker assert found_inf is None and grad_scale is None 679*da0073e9SAndroid Build Coastguard Worker 680*da0073e9SAndroid Build Coastguard Worker device_grad_scale, device_found_inf = None, None 681*da0073e9SAndroid Build Coastguard Worker if grad_scale is not None: 682*da0073e9SAndroid Build Coastguard Worker device_grad_scale = grad_scale_dict.setdefault( 683*da0073e9SAndroid Build Coastguard Worker device, grad_scale.to(device, non_blocking=True) 684*da0073e9SAndroid Build Coastguard Worker ) 685*da0073e9SAndroid Build Coastguard Worker if found_inf is not None: 686*da0073e9SAndroid Build Coastguard Worker device_found_inf = found_inf_dict.setdefault( 687*da0073e9SAndroid Build Coastguard Worker device, found_inf.to(device, non_blocking=True) 688*da0073e9SAndroid Build Coastguard Worker ) 689*da0073e9SAndroid Build Coastguard Worker if lr_dict is not None and device not in lr_dict: 690*da0073e9SAndroid Build Coastguard Worker lr = lr_dict.setdefault( 691*da0073e9SAndroid Build Coastguard Worker device, lr.to(device=device, non_blocking=True) # type: ignore[union-attr] 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker torch._foreach_add_(device_state_steps, 1) 694*da0073e9SAndroid Build Coastguard Worker torch._fused_adamw_( 695*da0073e9SAndroid Build Coastguard Worker device_params, 696*da0073e9SAndroid Build Coastguard Worker device_grads, 697*da0073e9SAndroid Build Coastguard Worker device_exp_avgs, 698*da0073e9SAndroid Build Coastguard Worker device_exp_avg_sqs, 699*da0073e9SAndroid Build Coastguard Worker device_max_exp_avg_sqs, # type: ignore[arg-type] 700*da0073e9SAndroid Build Coastguard Worker device_state_steps, 701*da0073e9SAndroid Build Coastguard Worker amsgrad=amsgrad, 702*da0073e9SAndroid Build Coastguard Worker lr=lr, # type: ignore[arg-type] 703*da0073e9SAndroid Build Coastguard Worker beta1=beta1, 704*da0073e9SAndroid Build Coastguard Worker beta2=beta2, 705*da0073e9SAndroid Build Coastguard Worker weight_decay=weight_decay, 706*da0073e9SAndroid Build Coastguard Worker eps=eps, 707*da0073e9SAndroid Build Coastguard Worker maximize=maximize, 708*da0073e9SAndroid Build Coastguard Worker grad_scale=device_grad_scale, 709*da0073e9SAndroid Build Coastguard Worker found_inf=device_found_inf, 710*da0073e9SAndroid Build Coastguard Worker ) 711*da0073e9SAndroid Build Coastguard Worker if device_found_inf is not None: 712*da0073e9SAndroid Build Coastguard Worker torch._foreach_sub_( 713*da0073e9SAndroid Build Coastguard Worker device_state_steps, [device_found_inf] * len(device_state_steps) 714*da0073e9SAndroid Build Coastguard Worker ) 715*da0073e9SAndroid Build Coastguard Worker 716*da0073e9SAndroid Build Coastguard Worker 717*da0073e9SAndroid Build Coastguard Worker@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamw) 718*da0073e9SAndroid Build Coastguard Workerdef adamw( 719*da0073e9SAndroid Build Coastguard Worker params: List[Tensor], 720*da0073e9SAndroid Build Coastguard Worker grads: List[Tensor], 721*da0073e9SAndroid Build Coastguard Worker exp_avgs: List[Tensor], 722*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs: List[Tensor], 723*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs: List[Tensor], 724*da0073e9SAndroid Build Coastguard Worker state_steps: List[Tensor], 725*da0073e9SAndroid Build Coastguard Worker # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627 726*da0073e9SAndroid Build Coastguard Worker # setting this as kwarg for now as functional API is compiled by torch/distributed/optim 727*da0073e9SAndroid Build Coastguard Worker foreach: Optional[bool] = None, 728*da0073e9SAndroid Build Coastguard Worker capturable: bool = False, 729*da0073e9SAndroid Build Coastguard Worker differentiable: bool = False, 730*da0073e9SAndroid Build Coastguard Worker fused: Optional[bool] = None, 731*da0073e9SAndroid Build Coastguard Worker grad_scale: Optional[Tensor] = None, 732*da0073e9SAndroid Build Coastguard Worker found_inf: Optional[Tensor] = None, 733*da0073e9SAndroid Build Coastguard Worker has_complex: bool = False, 734*da0073e9SAndroid Build Coastguard Worker *, 735*da0073e9SAndroid Build Coastguard Worker amsgrad: bool, 736*da0073e9SAndroid Build Coastguard Worker beta1: float, 737*da0073e9SAndroid Build Coastguard Worker beta2: float, 738*da0073e9SAndroid Build Coastguard Worker lr: Union[float, Tensor], 739*da0073e9SAndroid Build Coastguard Worker weight_decay: float, 740*da0073e9SAndroid Build Coastguard Worker eps: float, 741*da0073e9SAndroid Build Coastguard Worker maximize: bool, 742*da0073e9SAndroid Build Coastguard Worker): 743*da0073e9SAndroid Build Coastguard Worker r"""Functional API that performs AdamW algorithm computation. 744*da0073e9SAndroid Build Coastguard Worker 745*da0073e9SAndroid Build Coastguard Worker See :class:`~torch.optim.AdamW` for details. 746*da0073e9SAndroid Build Coastguard Worker """ 747*da0073e9SAndroid Build Coastguard Worker if not torch._utils.is_compiling() and not all( 748*da0073e9SAndroid Build Coastguard Worker isinstance(t, torch.Tensor) for t in state_steps 749*da0073e9SAndroid Build Coastguard Worker ): 750*da0073e9SAndroid Build Coastguard Worker raise RuntimeError( 751*da0073e9SAndroid Build Coastguard Worker "API has changed, `state_steps` argument must contain a list of singleton tensors" 752*da0073e9SAndroid Build Coastguard Worker ) 753*da0073e9SAndroid Build Coastguard Worker 754*da0073e9SAndroid Build Coastguard Worker # Respect when the user inputs False/True for foreach or fused. We only want to change 755*da0073e9SAndroid Build Coastguard Worker # the default when neither have been user-specified. Note that we default to foreach 756*da0073e9SAndroid Build Coastguard Worker # and pass False to use_fused. This is not a mistake--we want to give the fused impl 757*da0073e9SAndroid Build Coastguard Worker # bake-in time before making it the default, even if it is typically faster. 758*da0073e9SAndroid Build Coastguard Worker if fused is None and foreach is None: 759*da0073e9SAndroid Build Coastguard Worker _, foreach = _default_to_fused_or_foreach( 760*da0073e9SAndroid Build Coastguard Worker params, differentiable, use_fused=False 761*da0073e9SAndroid Build Coastguard Worker ) 762*da0073e9SAndroid Build Coastguard Worker # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False. 763*da0073e9SAndroid Build Coastguard Worker if foreach and isinstance(lr, Tensor) and not capturable: 764*da0073e9SAndroid Build Coastguard Worker foreach = False 765*da0073e9SAndroid Build Coastguard Worker if fused is None: 766*da0073e9SAndroid Build Coastguard Worker fused = False 767*da0073e9SAndroid Build Coastguard Worker if foreach is None: 768*da0073e9SAndroid Build Coastguard Worker foreach = False 769*da0073e9SAndroid Build Coastguard Worker 770*da0073e9SAndroid Build Coastguard Worker if foreach and torch.jit.is_scripting(): 771*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("torch.jit.script not supported with foreach optimizers") 772*da0073e9SAndroid Build Coastguard Worker if fused and torch.jit.is_scripting(): 773*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("torch.jit.script not supported with fused optimizers") 774*da0073e9SAndroid Build Coastguard Worker 775*da0073e9SAndroid Build Coastguard Worker if fused and not torch.jit.is_scripting(): 776*da0073e9SAndroid Build Coastguard Worker func = _fused_adamw 777*da0073e9SAndroid Build Coastguard Worker elif foreach and not torch.jit.is_scripting(): 778*da0073e9SAndroid Build Coastguard Worker func = _multi_tensor_adamw 779*da0073e9SAndroid Build Coastguard Worker else: 780*da0073e9SAndroid Build Coastguard Worker func = _single_tensor_adamw 781*da0073e9SAndroid Build Coastguard Worker 782*da0073e9SAndroid Build Coastguard Worker func( 783*da0073e9SAndroid Build Coastguard Worker params, 784*da0073e9SAndroid Build Coastguard Worker grads, 785*da0073e9SAndroid Build Coastguard Worker exp_avgs, 786*da0073e9SAndroid Build Coastguard Worker exp_avg_sqs, 787*da0073e9SAndroid Build Coastguard Worker max_exp_avg_sqs, 788*da0073e9SAndroid Build Coastguard Worker state_steps, 789*da0073e9SAndroid Build Coastguard Worker amsgrad=amsgrad, 790*da0073e9SAndroid Build Coastguard Worker beta1=beta1, 791*da0073e9SAndroid Build Coastguard Worker beta2=beta2, 792*da0073e9SAndroid Build Coastguard Worker lr=lr, 793*da0073e9SAndroid Build Coastguard Worker weight_decay=weight_decay, 794*da0073e9SAndroid Build Coastguard Worker eps=eps, 795*da0073e9SAndroid Build Coastguard Worker maximize=maximize, 796*da0073e9SAndroid Build Coastguard Worker capturable=capturable, 797*da0073e9SAndroid Build Coastguard Worker differentiable=differentiable, 798*da0073e9SAndroid Build Coastguard Worker grad_scale=grad_scale, 799*da0073e9SAndroid Build Coastguard Worker found_inf=found_inf, 800*da0073e9SAndroid Build Coastguard Worker has_complex=has_complex, 801*da0073e9SAndroid Build Coastguard Worker ) 802