# mypy: allow-untyped-defs
from typing import Dict, List, Optional

import torch
import torch.optim._functional as F
from torch import Tensor


__all__: List[str] = []


# Define a TorchScript compatible functional Adadelta optimizer
# where we use this optimizer in a functional way.
# Instead of reading gradients from `param.grad` when updating parameters,
# we explicitly let the distributed optimizer pass gradients to
# the `step` function. This separates the gradients from the parameters
# and allows a multithreaded trainer to update the parameters
# without data races from accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
# (An illustrative usage sketch appears at the bottom of this file.)
@torch.jit.script
class _FunctionalAdadelta:
    def __init__(
        self,
        params: List[Tensor],
        lr: float = 1.0,
        rho: float = 0.9,
        eps: float = 1e-6,
        weight_decay: float = 0.0,
        foreach: bool = False,
        maximize: bool = False,
        _allow_empty_param_list: bool = False,
    ):
        self.defaults = {
            "lr": lr,
            "rho": rho,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        self.foreach = foreach
        self.maximize = maximize

        if len(params) == 0 and not _allow_empty_param_list:
            raise ValueError("optimizer got an empty parameter list")

        # NOTE: we only have one param_group and don't allow the user to add
        # additional param groups, as it's not a common use case.
        self.param_group = {"params": params}

        self.state = torch.jit.annotate(
            Dict[torch.Tensor, Dict[str, torch.Tensor]], {}
        )

    def step(self, gradients: List[Optional[Tensor]]):
        params = self.param_group["params"]
        params_with_grad = []
        grads = []
        square_avgs = []
        acc_deltas = []
        state_steps = []
        lr = self.defaults["lr"]
        rho = self.defaults["rho"]
        eps = self.defaults["eps"]
        weight_decay = self.defaults["weight_decay"]

        if len(params) != len(gradients):
            raise ValueError(
                "The number of gradients passed in does not match the number of parameters! "
                + f"Params length: {len(params)}. "
                + f"Gradients length: {len(gradients)}"
            )
        has_complex = False
        # Collect the parameters that received gradients, along with their
        # corresponding optimizer state, so they can all be updated in one call.
        for param, gradient in zip(params, gradients):
            if gradient is not None:
                has_complex |= torch.is_complex(param)
                params_with_grad.append(param)
                grads.append(gradient)
                # Lazy state initialization
                if param not in self.state:
                    self.state[param] = {}
                    state = self.state[param]
                    state["step"] = torch.tensor(0.0)
                    state["square_avg"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )
                    state["acc_delta"] = torch.zeros_like(
                        param, memory_format=torch.preserve_format
                    )

                state = self.state[param]
                square_avgs.append(state["square_avg"])
                acc_deltas.append(state["acc_delta"])
                state_steps.append(state["step"])

        # Apply the Adadelta update to all collected parameters in a single
        # functional call.
        with torch.no_grad():
            F.adadelta(
                params_with_grad,
                grads,
                square_avgs,
                acc_deltas,
                state_steps,
                lr=lr,
                rho=rho,
                eps=eps,
                weight_decay=weight_decay,
                foreach=self.foreach,
                maximize=self.maximize,
                has_complex=has_complex,
            )
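

# The module comment above describes the key design point: gradients are passed
# explicitly to `step` rather than read from each parameter's `.grad`. The block
# below is a minimal, illustrative sketch of that calling pattern (roughly how a
# distributed optimizer might drive this class). The tensors, shapes, and variable
# names here are made up for illustration and are not part of this module's API.
if __name__ == "__main__":
    # Two toy parameters standing in for a model's parameters.
    example_params = [torch.randn(4, 4), torch.randn(4)]

    # The scripted class is constructed like a regular Python object.
    functional_optim = _FunctionalAdadelta(example_params, lr=1.0)

    # Gradients are produced elsewhere (e.g. collected by a distributed trainer)
    # and handed to `step` explicitly; an entry may be None for a parameter that
    # did not receive a gradient.
    example_grads: List[Optional[Tensor]] = [torch.randn(4, 4), None]

    # One optimizer step: only the first parameter is updated in place.
    functional_optim.step(example_grads)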