# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Base optimizer."""
import functools
import warnings
from collections import defaultdict, OrderedDict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    overload,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import ParamSpec, Self, TypeAlias

import torch
import torch.utils.hooks as hooks
from torch._utils import is_compiling
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
    Indices,
    TensorListList,
)
from torch.utils.hooks import RemovableHandle


Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
DeviceDict = Dict[Optional[torch.device], torch.Tensor]


GlobalOptimizerPreHook: TypeAlias = Callable[
    ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

__all__ = [
    "Optimizer",
    "register_optimizer_step_pre_hook",
    "register_optimizer_step_post_hook",
]
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]


class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""

    def __repr__(self) -> str:
        return "<required parameter>"


required = _RequiredParameter()


def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo

        prev_grad = torch.is_grad_enabled()
        try:
            # Note on graph break below:
            # we need to graph break to ensure that aot respects the no_grad annotation.
            # This is important for perf because without this, functionalization will generate an epilogue
            # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
            # inductor will allocate for every parameter in the model, which is horrible.
            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
            # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
            # step is in place and is able to avoid the extra allocation.
            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
            # see https://github.com/pytorch/pytorch/issues/104053
            torch.set_grad_enabled(self.defaults["differentiable"])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret

    functools.update_wrapper(_use_grad, func)
    return _use_grad


def _get_value(x):
    # item is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item() if isinstance(x, torch.Tensor) else x


def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x


def _disable_dynamo_if_unsupported(single_tensor_fn=None):
    # workaround for torchscript BC
    # it requires all called functions to be in the
    # global environment at the site at which the
    # maybe_fallback closure is created
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        import inspect

        disabled_func = torch._disable_dynamo(func)
        ps = inspect.signature(func).parameters
        has_state_steps = True
        try:
            state_steps_ind = list(ps.keys()).index("state_steps")
        except ValueError:
            has_state_steps = False

        # Today, there are cases where we stack state steps
        # and pass them as the value arg of foreach ops.
        # Having state steps on cuda as the value arg is not supported in eager,
        # but this only occurs in the rare case that the user explicitly deletes
        # the capturable flag. If capturable=True, this is not a problem.
        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            if is_compiling() and (
                not kwargs.get("capturable", False)
                and has_state_steps
                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
                or (
                    "state_steps" in kwargs
                    and kwargs["state_steps"]
                    and kwargs["state_steps"][0].is_cuda
                )
            ):
                return disabled_func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return maybe_fallback

    return wrapper


# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(
    params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> Tuple[bool, bool]:
    if torch.jit.is_scripting() or differentiable:
        return False, False

    fused_supported_devices = _get_fused_kernels_supported_devices()
    foreach_supported_devices = _get_foreach_kernels_supported_devices()
    fused = use_fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in fused_supported_devices
            and torch.is_floating_point(p)
        )
        for p in params
    )
    foreach = not fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in foreach_supported_devices
        )
        for p in params
    )
    return fused, foreach


def _device_dtype_check_for_fused(
    p: torch.Tensor, cuda_unsupported: bool = False
) -> None:
    fused_supported_devices = _get_fused_kernels_supported_devices()
    if cuda_unsupported:
        fused_supported_devices.remove("cuda")
    if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
        raise RuntimeError(
            "`fused=True` requires all the params to be floating point Tensors of "
            f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
        )


def _view_as_real(params, *state_and_grads):
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(params[i])
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])

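# Illustrative sketch (not part of the original file): how an optimizer implementation
# might consult _default_to_fused_or_foreach above when the user passed neither `fused`
# nor `foreach`. The local name `params` and the flag values are assumptions.
#
#     fused, foreach = _default_to_fused_or_foreach(
#         params, differentiable=False, use_fused=False
#     )
#     # For a list of floating-point CUDA parameters this typically yields
#     # (False, True): foreach is preferred by default while fused bakes in.
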
def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return (
        torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
    )


def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
    r"""Return the device type list that supports capturable optimizer."""
    capturable_supported_devices = ["cuda", "xpu", "hpu"]
    if not torch.jit.is_scripting():
        capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
    if supports_xla:
        capturable_supported_devices.append("xla")
    return capturable_supported_devices


# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether foreach implementation of optimizer
        is used. If unspecified by the user (so foreach is None), we will try to use
        foreach over the for-loop implementation on CUDA, since it is usually
        significantly more performant. Note that the foreach implementation uses
        ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
        being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
        parameters through the optimizer at a time or switch this flag to False (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
    Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
    are supported. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
        single-tensor implementation, with fused being theoretically fastest with both
        vertical and horizontal fusion. As such, if the user has not specified either
        flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
        implementation when the tensors are all on CUDA. Why not fused? Since the fused
        implementation is relatively new, we want to give it sufficient bake-in time.
    To specify fused, pass True for fused. To force running the for-loop
    implementation, pass False for either foreach or fused. """

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
        capture in a CUDA graph. Passing True can impair ungraphed performance,
        so if you don't intend to graph capture this instance, leave it False
        (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
        occur through the optimizer step in training. Otherwise, the step()
        function runs in a torch.no_grad() context. Setting to True can impair
        performance, so leave it False if you don't intend to run autograd
        through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
    params, instead of minimizing (default: False)"""


def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None or modified args and kwargs

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
    _global_optimizer_pre_hooks[handle.id] = hook
    return handle

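# Usage sketch for register_optimizer_step_pre_hook above (illustrative, not part of the
# original file): the hook fires before `step()` of *every* optimizer instance; returning
# None keeps args/kwargs as-is, returning a (new_args, new_kwargs) tuple replaces them.
#
#     def _log_step(optimizer, args, kwargs):
#         print(f"about to step {type(optimizer).__name__}")
#
#     handle = register_optimizer_step_pre_hook(_log_step)
#     ...                      # every optimizer.step() now triggers _log_step first
#     handle.remove()          # deregister via the returned RemovableHandle
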
def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
    r"""Register a post hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
    _global_optimizer_post_hooks[handle.id] = hook
    return handle


ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

_P = ParamSpec("_P")
R = TypeVar("R")
T = TypeVar("T")


class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """

    OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]]  # type: ignore[misc]
    OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc]

    _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:  # noqa: D107
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got " + torch.typename(params)
            )

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(cast(dict, param_group))

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self) -> Dict[str, Any]:  # noqa: D105
        return {
            "defaults": self.defaults,
            "state": self.state,
            "param_groups": self.param_groups,
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:  # noqa: D105
        self.__dict__.update(state)
        if "_optimizer_step_pre_hooks" not in self.__dict__:
            self._optimizer_step_pre_hooks = OrderedDict()
        if "_optimizer_step_post_hooks" not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        if "_optimizer_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_state_dict_post_hooks = OrderedDict()
        if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_load_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault("differentiable", False)

    def __repr__(self) -> str:  # noqa: D105
        format_string = self.__class__.__name__ + " ("
        for i, group in enumerate(self.param_groups):
            format_string += "\n"
            format_string += f"Parameter Group {i}\n"
            for key in sorted(group.keys()):
                if key != "params":
                    format_string += f"    {key}: {group[key]}\n"
        format_string += ")"
        return format_string

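    # Construction sketch (illustrative, not part of the original file): `params` may be a
    # flat iterable of Tensors or an iterable of dicts, where each dict becomes one
    # param_group and missing options are filled in from `defaults` by add_param_group.
    #
    #     model = torch.nn.Linear(4, 2)                      # assumed example module
    #     opt = torch.optim.SGD(
    #         [
    #             {"params": model.weight},                  # uses the default lr
    #             {"params": model.bias, "lr": 1e-3},        # per-group override
    #         ],
    #         lr=1e-2,
    #     )
    #     print(opt)   # __repr__ above prints each Parameter Group and its options
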
    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self) -> None:
        # Note [torch.compile x capturable]
        # If we are compiling, we try to take the capturable path automatically by
        # setting the flag to True during tracing. Due to this, we skip all the checks
        # normally required for determining whether we can use CUDA graphs and
        # shunt the responsibility to torch.inductor. This saves time during tracing
        # since the checks are slow without sacrificing UX since inductor will warn
        # later if CUDA graphs cannot be enabled, e.g.,
        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
        # Thus, when compiling, inductor will determine if cudagraphs
        # can be enabled based on whether there is input mutation or CPU tensors.
        if (
            not is_compiling()
            and torch.backends.cuda.is_built()
            and torch.cuda.is_available()
        ):
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not all(
                group["capturable"] for group in self.param_groups
            ):
                raise RuntimeError(
                    "Attempting CUDA graph capture of step() for an instance of "
                    + self.__class__.__name__
                    + " but param_groups' capturable is False."
                )

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and all(group["capturable"] for group in self.param_groups)
                and (not capturing)
            ):
                warnings.warn(
                    "This instance was constructed with capturable=True or some of the param_groups came with capturable=True, "
                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                    "instance, capturable=True can impair performance, and you should set capturable=False."
                )
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self) -> None:
        """Entry point for `torch.profiler`.

        When Python tracing is enabled the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround due to lack of a proper step hook on the optimizer,
        and will be removed once such a hook exists.
        """

    @staticmethod
    def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:  # noqa: D102
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
            self, *_ = args
            self = cast(Optimizer, self)
            profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
            with torch.autograd.profiler.record_function(profile_name):
                # call optimizer step pre hooks
                for pre_hook in chain(
                    _global_optimizer_pre_hooks.values(),
                    self._optimizer_step_pre_hooks.values(),
                ):
                    result = pre_hook(self, args, kwargs)
                    if result is not None:
                        if isinstance(result, tuple) and len(result) == 2:
                            args, kwargs = result  # type: ignore[assignment]
                        else:
                            raise RuntimeError(
                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                            )

                out = func(*args, **kwargs)
                self._optimizer_step_code()

                # call optimizer step post hooks
                for post_hook in chain(
                    self._optimizer_step_post_hooks.values(),
                    _global_optimizer_post_hooks.values(),
                ):
                    post_hook(self, args, kwargs)

                return out

        return wrapper

    @staticmethod
    def _group_tensors_by_device_and_dtype(
        tensorlistlist: TensorListList,
        with_indices: bool = False,
    ) -> Union[
        Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
        Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
    ]:
        """Group a list of lists of tensors by device and dtype.

        Skips this step if we are compiling since this will occur during inductor lowering.
        """
        if is_compiling():
            return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
        else:
            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]

    def _patch_step_function(self) -> None:
        self._zero_grad_profile_name = (
            f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
        )
        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = self.profile_hook_step(self.__class__.step)  # type: ignore[assignment]
            self.__class__.step.hooked = True  # type: ignore[attr-defined]

    def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
        r"""Register an optimizer step pre hook which will be called before optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None or modified args and kwargs

        The ``optimizer`` argument is the optimizer instance being used. If
        args and kwargs are modified by the pre-hook, then the transformed
        values are returned as a tuple containing the new_args and new_kwargs.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
        self._optimizer_step_pre_hooks[handle.id] = hook
        return handle

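    # Usage sketch for register_step_pre_hook above (illustrative, not part of the original
    # file): unlike the module-level register_optimizer_step_pre_hook, this hook is scoped
    # to a single optimizer instance. `model` is an assumed example module.
    #
    #     opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    #
    #     def _clip_before_step(optimizer, args, kwargs):
    #         torch.nn.utils.clip_grad_norm_(
    #             (p for g in optimizer.param_groups for p in g["params"]), max_norm=1.0
    #         )
    #         return None          # leave args/kwargs unchanged
    #
    #     handle = opt.register_step_pre_hook(_clip_before_step)
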
    def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
        r"""Register an optimizer step post hook which will be called after optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
        self._optimizer_step_post_hooks[handle.id] = hook
        return handle

    def register_state_dict_pre_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D101
        r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.
        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
        The registered hook can be used to perform pre-processing before the ``state_dict``
        call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
        self._optimizer_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

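    # Usage sketch for register_state_dict_pre_hook above (illustrative, not part of the
    # original file): the hook receives only the optimizer and returns nothing, so it is a
    # place for sanity checks or bookkeeping right before serialization.
    #
    #     def _check_before_saving(optimizer):
    #         assert len(optimizer.param_groups) > 0, "nothing to save"
    #
    #     handle = opt.register_state_dict_pre_hook(_check_before_saving)
    #     sd = opt.state_dict()    # _check_before_saving runs before sd is assembled
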
    def register_state_dict_post_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The hook will be called with arguments ``self`` and ``state_dict`` after generating
        a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
        return a new one. The registered hook can be used to perform post-processing
        on the ``state_dict`` before it is returned.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
        self._optimizer_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
        return handle

    @torch._disable_dynamo
    def state_dict(self) -> StateDict:
        r"""Return the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * ``state``: a Dict holding current optimization state. Its content
            differs between optimizer classes, but some common characteristics
            hold. For example, state is saved per parameter, and the parameter
            itself is NOT saved. ``state`` is a Dictionary mapping parameter ids
            to a Dict with state corresponding to each parameter.
        * ``param_groups``: a List containing all parameter groups where each
            parameter group is a Dict. Each parameter group contains metadata
            specific to the optimizer, such as learning rate and weight decay,
            as well as a List of parameter IDs of the parameters in the group.

        NOTE: The parameter IDs may look like indices but they are just IDs
        associating state with param_group. When loading from a state_dict,
        the optimizer will zip the param_group ``params`` (int IDs) and the
        optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
        match state WITHOUT additional verification.

        A returned state dict might look something like:

        .. code-block:: text

            {
                'state': {
                    0: {'momentum_buffer': tensor(...), ...},
                    1: {'momentum_buffer': tensor(...), ...},
                    2: {'momentum_buffer': tensor(...), ...},
                    3: {'momentum_buffer': tensor(...), ...}
                },
                'param_groups': [
                    {
                        'lr': 0.01,
                        'weight_decay': 0,
                        ...
                        'params': [0]
                    },
                    {
                        'lr': 0.001,
                        'weight_decay': 0.5,
                        ...
                        'params': [1, 2, 3]
                    }
                ]
            }

        """
        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)

        # Save order indices instead of Tensors
        param_mappings: Dict[int, int] = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {
                    id(p): i
                    for i, p in enumerate(group["params"], start_index)
                    if id(p) not in param_mappings
                }
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {
            (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }

        state_dict = {
            "state": packed_state,
            "param_groups": param_groups,
        }

        for post_hook in self._optimizer_state_dict_post_hooks.values():
            hook_result = post_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result
        return state_dict

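    # Round-trip sketch for state_dict above (illustrative, not part of the original file):
    # the returned dict is a plain, picklable structure, so the usual checkpoint pattern is
    #
    #     torch.save(opt.state_dict(), "opt.pt")                 # path is an assumption
    #     ...
    #     opt2 = torch.optim.SGD(model.parameters(), lr=1e-2)    # same parameter layout
    #     opt2.load_state_dict(torch.load("opt.pt"))
    #
    # Matching relies on parameter *order* within param_groups (the integer IDs described
    # above), not on tensor identity, so both optimizers must be built over the same
    # parameters in the same order.
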
    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        # Floating-point types are a bit special here. They are the only ones
        # that are assumed to always match the type of params.
        # Make sure state['step'] is not casted https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break
        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            if param.is_floating_point():
                return value.to(dtype=param.dtype, device=param.device)
            else:
                return value.to(device=param.device)


    def register_load_state_dict_pre_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict pre-hook which will be called
        before :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The ``optimizer`` argument is the optimizer instance being used and the
        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
        passed in to ``load_state_dict``. The hook may modify the state_dict in place
        or optionally return a new one. If a state_dict is returned, it will be
        loaded into the optimizer instead.

        The hook will be called with argument ``self`` and ``state_dict`` before
        calling ``load_state_dict`` on ``self``. The registered hook can be used to
        perform pre-processing before the ``load_state_dict`` call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
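
        Example (a minimal sketch; the hook below and the ``lr`` override it
        applies are purely illustrative)::

            >>> # xdoctest: +SKIP
            >>> def bump_lr(optimizer, state_dict):
            ...     # e.g. override the learning rate saved in the checkpoint
            ...     for group in state_dict["param_groups"]:
            ...         group["lr"] = 0.001
            ...     return state_dict
            >>> handle = optimizer.register_load_state_dict_pre_hook(bump_lr)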
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
        self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

    def register_load_state_dict_post_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        The hook will be called with argument ``self`` after calling
        ``load_state_dict`` on ``self``. The registered hook can be used to
        perform post-processing after ``load_state_dict`` has loaded the
        ``state_dict``.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
        self._optimizer_load_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

    @torch._disable_dynamo
    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Load the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
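
        Example (an illustrative round trip; the ``SGD`` setup is arbitrary)::

            >>> # xdoctest: +SKIP
            >>> model = torch.nn.Linear(2, 2)
            >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
            >>> checkpoint = optimizer.state_dict()
            >>> # ... later, possibly after re-creating the optimizer ...
            >>> optimizer.load_state_dict(checkpoint)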
        """
        # shallow copy, to be consistent with module API
        state_dict = state_dict.copy()

        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
            hook_result = pre_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result

        # Validate the state_dict
        groups = self.param_groups

        # Deepcopy as we write into saved_groups later to update state
        saved_groups = deepcopy(state_dict["param_groups"])

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of the optimizer's group"
            )

        # Update the state
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, param_id=None, param_groups=None, key=None):
            r"""Make a deep copy of value, casting all tensors to device of param."""
            if isinstance(value, torch.Tensor):
                return Optimizer._process_value_according_to_param_policy(
                    param, value, param_id, param_groups, key
                )
            elif isinstance(value, dict):
                return {
                    k: _cast(
                        param, v, param_id=param_id, param_groups=param_groups, key=k
                    )
                    for k, v in value.items()
                }
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = _cast(
                    param, v, param_id=k, param_groups=state_dict["param_groups"]
                )
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(
            group: Dict[str, Any], new_group: Dict[str, Any]
        ) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)

    @torch._disable_dynamo
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Reset the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
        """
        foreach = self.defaults.get("foreach", False) or self.defaults.get(
            "fused", False
        )

        if not hasattr(self, "_zero_grad_profile_name"):
            self._patch_step_function()

        per_device_and_dtype_grads: Optional[
            DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
        ]
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        else:
            per_device_and_dtype_grads = None

        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            if not foreach or p.grad.is_sparse:
                                p.grad.zero_()
                            else:
                                assert per_device_and_dtype_grads is not None
                                per_device_and_dtype_grads[p.grad.device][
                                    p.grad.dtype
                                ].append(p.grad)
            if foreach:
                assert per_device_and_dtype_grads is not None
                for per_dtype_grads in per_device_and_dtype_grads.values():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)
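
    # Hypothetical illustration of the ``set_to_none`` behaviors listed in the
    # docstring above (``opt``, ``loss`` and the untouched parameter are illustrative):
    #
    #     opt.zero_grad(set_to_none=True)
    #     loss.backward()   # a param untouched by this backward keeps p.grad is None
    #     opt.step()        # ...and its update is skipped entirely
    #
    #     opt.zero_grad(set_to_none=False)
    #     loss.backward()   # the same param keeps an all-zero p.grad
    #     opt.step()        # ...and a step is still taken for it with grad == 0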

    @overload
    def step(self, closure: None = ...) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        r"""Perform a single optimization step to update parameters.

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
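
        Example (an illustrative closure; ``model``, ``loss_fn``, ``input`` and
        ``target`` are assumed to exist)::

            >>> # xdoctest: +SKIP
            >>> def closure():
            ...     optimizer.zero_grad()
            ...     loss = loss_fn(model(input), target)
            ...     loss.backward()
            ...     return loss
            >>> optimizer.step(closure)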
        """
        raise NotImplementedError

    @torch._disable_dynamo
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        r"""Add a param group to the :class:`Optimizer`'s ``param_groups``.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
        """
        if not isinstance(param_group, dict):
            raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters need to be organized in ordered collections, but "
                "the ordering of tensors in sets will change between runs. Please use a list instead."
            )
        else:
            param_group["params"] = list(params)

        for param in param_group["params"]:
            if not isinstance(param, torch.Tensor):
                raise TypeError(
                    "optimizer can only optimize Tensors, "
                    "but one of the params is " + torch.typename(param)
                )
            if not self.defaults.get("differentiable", None) and not (
                param.is_leaf or param.retains_grad
            ):
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError(
                    f"parameter group didn't specify a value of required optimization parameter {name}"
                )
            else:
                param_group.setdefault(name, default)

        params = param_group["params"]
        if len(params) != len(set(params)):
            warnings.warn(
                "optimizer contains a parameter group with duplicate parameters; "
                "in the future, this will cause an error; "
                "see github.com/pytorch/pytorch/issues/40967 for more information",
                stacklevel=3,
            )

        param_set: Set[torch.Tensor] = set()
        for group in self.param_groups:
            param_set.update(set(group["params"]))

        if not param_set.isdisjoint(set(param_group["params"])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)
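
    # Illustrative use of ``add_param_group`` (names are hypothetical): unfreeze a
    # block mid-training and give it its own learning rate, e.g.
    #
    #     for p in model.layer4.parameters():
    #         p.requires_grad_(True)
    #     optimizer.add_param_group({"params": model.layer4.parameters(), "lr": 1e-4})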