# mypy: allow-untyped-defs
from typing import cast, List, Optional, Union

import torch
from torch import Tensor

from .optimizer import (
    _default_to_fused_or_foreach,
    _device_dtype_check_for_fused,
    _differentiable_doc,
    _foreach_doc,
    _get_scalar_dtype,
    _get_value,
    _maximize_doc,
    _use_grad_for_differentiable,
    _view_as_real,
    Optimizer,
    ParamsT,
)


__all__ = ["Adagrad", "adagrad"]


class Adagrad(Optimizer):
    def __init__(
        self,
        params: ParamsT,
        lr: Union[float, Tensor] = 1e-2,
        lr_decay: float = 0,
        weight_decay: float = 0,
        initial_accumulator_value: float = 0,
        eps: float = 1e-10,
        foreach: Optional[bool] = None,
        *,
        maximize: bool = False,
        differentiable: bool = False,
        fused: Optional[bool] = None,
    ):
        if isinstance(lr, Tensor) and lr.numel() != 1:
            raise ValueError("Tensor lr must be 1-element")
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= lr_decay:
            raise ValueError(f"Invalid lr_decay value: {lr_decay}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= initial_accumulator_value:
            raise ValueError(
                f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
            )
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")

        defaults = dict(
            lr=lr,
            lr_decay=lr_decay,
            eps=eps,
            weight_decay=weight_decay,
            initial_accumulator_value=initial_accumulator_value,
            foreach=foreach,
            maximize=maximize,
            differentiable=differentiable,
            fused=fused,
        )
        super().__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            if foreach:
                raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
            self._need_device_dtype_check_for_fused = True

        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["step"] = (
                    torch.zeros(
                        (),
                        dtype=_get_scalar_dtype(is_fused=group["fused"]),
                        device=p.device,
                    )
                    if group["fused"]
                    else torch.tensor(0.0, dtype=_get_scalar_dtype())
                )
                init_value = (
                    complex(initial_accumulator_value, initial_accumulator_value)
                    if torch.is_complex(p)
                    else initial_accumulator_value
                )
                state["sum"] = torch.full_like(
                    p, init_value, memory_format=torch.preserve_format
                )

    def __setstate__(self, state):
        super().__setstate__(state)
        # define "fused" for
        # MYPY error: Name "fused" may be undefined
        fused = None
        for group in self.param_groups:
            group.setdefault("foreach", None)
            group.setdefault("maximize", False)
            group.setdefault("differentiable", False)
            fused = group.setdefault("fused", None)

        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(
            state_values[0]["step"]
        )
        if not step_is_tensor:
            for s in state_values:
                s["step"] = torch.tensor(
                    float(s["step"]), dtype=_get_scalar_dtype(is_fused=fused)
                )

    def share_memory(self):
        for group in self.param_groups:
            for p in group["params"]:
                state = self.state[p]
                state["sum"].share_memory_()

    def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
        has_sparse_grad, has_complex = False, False
        for p in group["params"]:
            if p.grad is not None:
                if group["fused"] and getattr(
                    self,
                    "_need_device_dtype_check_for_fused",
                    True,
                ):
                    _device_dtype_check_for_fused(p, cuda_unsupported=True)
                    self._need_device_dtype_check_for_fused = False
                has_sparse_grad |= p.grad.is_sparse
                has_complex |= torch.is_complex(p)
                params_with_grad.append(p)
                grads.append(p.grad)
                state = self.state[p]
                state_sums.append(state["sum"])
                state_steps.append(state["step"])

        return has_sparse_grad, has_complex

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
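
        Example:
            A sketch of the closure form of ``step``; ``model``, ``inputs``,
            ``targets``, and ``loss_fn`` are placeholders for user-defined objects,
            not names provided by this module:

            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(inputs), targets)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)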
        """
        loss = None

        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad: List[Tensor] = []
            grads: List[Tensor] = []
            state_sums: List[Tensor] = []
            state_steps: List[Tensor] = []

            has_sparse_grad, has_complex = self._init_group(
                group, params_with_grad, grads, state_sums, state_steps
            )

            adagrad(
                params_with_grad,
                grads,
                state_sums,
                state_steps,
                lr=group["lr"],
                weight_decay=group["weight_decay"],
                lr_decay=group["lr_decay"],
                eps=group["eps"],
                has_sparse_grad=has_sparse_grad,
                foreach=group["foreach"],
                maximize=group["maximize"],
                differentiable=group["differentiable"],
                has_complex=has_complex,
                fused=group["fused"],
                grad_scale=getattr(self, "grad_scale", None),
                found_inf=getattr(self, "found_inf", None),
            )

        return loss


Adagrad.__doc__ = (
    r"""Implements Adagrad algorithm.

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt} \\
            &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
            &\hspace{12mm} \tau \text{ (initial accumulator value)}, \: \eta \text{ (lr decay)} \\
            &\textbf{initialize} : state\_sum_0 \leftarrow \tau \\[-1.ex]
            &\rule{110mm}{0.4pt} \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
            &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
            &\hspace{5mm} \tilde{\gamma} \leftarrow \gamma / (1 + (t-1) \eta) \\
            &\hspace{5mm} \textbf{if} \: \lambda \neq 0 \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
            &\hspace{5mm}state\_sum_t \leftarrow state\_sum_{t-1} + g^2_t \\
            &\hspace{5mm}\theta_t \leftarrow
                \theta_{t-1} - \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t} + \epsilon} \\
            &\rule{110mm}{0.4pt} \\[-1.ex]
            &\bf{return} \: \theta_t \\[-1.ex]
            &\rule{110mm}{0.4pt} \\[-1.ex]
       \end{aligned}

    For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization`_.
    """
    + rf"""
    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, Tensor, optional): learning rate (default: 1e-2)
        lr_decay (float, optional): learning rate decay (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        initial_accumulator_value (float, optional): initial value of the
            sum of squares of gradients (default: 0)
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-10)
        {_foreach_doc}
        {_maximize_doc}
        {_differentiable_doc}
        fused (bool, optional): whether the fused implementation (CPU only) is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None). Please note that the fused implementation does not
            support sparse or complex gradients.
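
    Example:
        A minimal usage sketch; ``model``, ``inputs``, ``targets``, and ``loss_fn``
        are placeholders for user-defined objects, not names provided by this module:

        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.Adagrad(model.parameters(), lr=1e-2)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(inputs), targets).backward()
        >>> optimizer.step()
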
    .. _Adaptive Subgradient Methods for Online Learning and Stochastic
        Optimization: http://jmlr.org/papers/v12/duchi11a.html

    """
)


def adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
    has_sparse_grad: bool = False,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adagrad algorithm computation.

    See :class:`~torch.optim.Adagrad` for details.
    """
    if not all(isinstance(t, torch.Tensor) for t in state_steps):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    if fused and not torch.jit.is_scripting():
        func = _fused_adagrad
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adagrad
    else:
        func = _single_tensor_adagrad

    func(
        params,
        grads,
        state_sums,
        state_steps,
        lr=lr,
        weight_decay=weight_decay,
        lr_decay=lr_decay,
        eps=eps,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        grad_scale=grad_scale,
        found_inf=found_inf,
    )


def _make_sparse(grad, grad_indices, values):
    size = grad.size()
    return torch.sparse_coo_tensor(grad_indices, values, size)
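

# Single-tensor (reference) path: updates each parameter in a plain Python loop. This is
# the only path that supports sparse gradients, and it updates complex parameters through
# their real views. The effective step size is decayed as lr / (1 + (step - 1) * lr_decay).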
def _single_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert grad_scale is None and found_inf is None
    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
        # update step
        step_t += 1
        step = _get_value(step_t)
        grad = grad if not maximize else -grad

        if weight_decay != 0:
            if grad.is_sparse:
                raise RuntimeError(
                    "weight_decay option is not compatible with sparse gradients"
                )
            grad = grad.add(param, alpha=weight_decay)

        clr = lr / (1 + (step - 1) * lr_decay)

        if grad.is_sparse:
            grad = grad.coalesce()  # the update is non-linear so indices must be unique
            grad_indices = grad._indices()
            grad_values = grad._values()

            state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
            std = state_sum.sparse_mask(grad)
            std_values = std._values().sqrt_().add_(eps)
            param.add_(
                _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
            )
        else:
            is_complex = torch.is_complex(param)
            if is_complex:
                grad = torch.view_as_real(grad)
                state_sum = torch.view_as_real(state_sum)
                param = torch.view_as_real(param)
            state_sum.addcmul_(grad, grad, value=1)
            if differentiable:
                std = state_sum.sqrt() + eps
            else:
                std = state_sum.sqrt().add_(eps)
            param.addcdiv_(grad, std, value=-clr)
            if is_complex:
                param = torch.view_as_complex(param)
                state_sum = torch.view_as_complex(state_sum)
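

# foreach path: groups tensors by device and dtype and applies the update with batched
# torch._foreach_* ops. Device groups that contain sparse gradients fall back to the
# single-tensor path above.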
def _multi_tensor_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
):
    assert not differentiable, "_foreach ops don't support autograd"
    assert grad_scale is None and found_inf is None

    # Foreach functions will throw errors if given empty lists
    if len(params) == 0:
        return

    grouped_tensorlists = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (
        device_params_,
        device_grads_,
        device_state_sums_,
        device_state_steps_,
    ), _ in grouped_tensorlists.values():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_has_sparse_grad = has_sparse_grad and any(
            grad.is_sparse for grad in device_grads
        )

        if device_has_sparse_grad:
            _single_tensor_adagrad(
                device_params,
                device_grads,
                device_state_sums,
                device_state_steps,
                lr=lr,
                weight_decay=weight_decay,
                lr_decay=lr_decay,
                eps=eps,
                has_sparse_grad=True,
                maximize=maximize,
                differentiable=differentiable,
                has_complex=has_complex,
                grad_scale=grad_scale,
                found_inf=found_inf,
            )
            continue

        # Handle complex parameters
        if has_complex:
            _view_as_real(device_params, device_grads, device_state_sums)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch._utils.is_compiling() and device_state_steps[0].is_cpu:
            torch._foreach_add_(
                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            # Re-use the intermediate memory (device_grads) already allocated for maximize
            if maximize:
                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
            else:
                device_grads = torch._foreach_add(  # type: ignore[assignment]
                    device_grads, device_params, alpha=weight_decay
                )

        minus_clr = [
            -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in device_state_steps
        ]

        torch._foreach_addcmul_(device_state_sums, device_grads, device_grads, value=1)

        std = torch._foreach_sqrt(device_state_sums)
        torch._foreach_add_(std, eps)

        if weight_decay != 0 or maximize:
            # Again, re-use the intermediate memory (device_grads) already allocated
            torch._foreach_mul_(device_grads, minus_clr)
            numerator = device_grads
        else:
            numerator = torch._foreach_mul(device_grads, minus_clr)  # type: ignore[assignment]

        torch._foreach_addcdiv_(device_params, numerator, std)
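

# Fused path: dispatches each device/dtype group to the torch._fused_adagrad_ kernel.
# grad_scale and found_inf are typically supplied by a gradient scaler; when an inf/NaN
# was found, the step increment is rolled back after the kernel call.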
def _fused_adagrad(
    params: List[Tensor],
    grads: List[Tensor],
    state_sums: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[Tensor],
    found_inf: Optional[Tensor],
    *,
    lr: float,
    weight_decay: float,
    lr_decay: float,
    eps: float,
    has_sparse_grad: bool,
    maximize: bool,
    differentiable: bool,
    has_complex: bool,
) -> None:
    if not params:
        return
    if has_sparse_grad or has_complex:
        raise RuntimeError("`fused` does not support sparse grad or complex param")

    if differentiable:
        raise RuntimeError(
            "adagrad with fused=True does not support differentiable=True"
        )

    grad_scale_dict = (
        {grad_scale.device: grad_scale} if grad_scale is not None else None
    )
    found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, state_sums, state_steps]  # type: ignore[list-item]
    )
    for (device, _), (
        (
            device_params_,
            device_grads_,
            device_state_sums_,
            device_state_steps_,
        ),
        _,
    ) in grouped_tensors.items():
        device_params = cast(List[Tensor], device_params_)
        device_grads = cast(List[Tensor], device_grads_)
        device_state_sums = cast(List[Tensor], device_state_sums_)
        device_state_steps = cast(List[Tensor], device_state_steps_)

        device_grad_scale, device_found_inf = None, None
        if grad_scale is not None and grad_scale_dict is not None:
            if device not in grad_scale_dict:
                grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
            device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
        if found_inf is not None and found_inf_dict is not None:
            if device not in found_inf_dict:
                found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
            device_found_inf = found_inf_dict[device]  # type: ignore[index]
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adagrad_(
            device_params,
            device_grads,
            device_state_sums,
            device_state_steps,
            lr=lr,
            lr_decay=lr_decay,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            torch._foreach_sub_(
                device_state_steps, [device_found_inf] * len(device_state_steps)
            )