# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
"""Base optimizer."""
import functools
import warnings
from collections import defaultdict, OrderedDict
from copy import deepcopy
from itertools import chain
from typing import (
    Any,
    Callable,
    cast,
    DefaultDict,
    Dict,
    Hashable,
    Iterable,
    List,
    Optional,
    overload,
    Set,
    Tuple,
    TypeVar,
    Union,
)
from typing_extensions import ParamSpec, Self, TypeAlias

import torch
import torch.utils.hooks as hooks
from torch._utils import is_compiling
from torch.utils._foreach_utils import (
    _get_foreach_kernels_supported_devices,
    _get_fused_kernels_supported_devices,
    _group_tensors_by_device_and_dtype,
    Indices,
    TensorListList,
)
from torch.utils.hooks import RemovableHandle


Args: TypeAlias = Tuple[Any, ...]
Kwargs: TypeAlias = Dict[str, Any]
StateDict: TypeAlias = Dict[str, Any]
DeviceDict = Dict[Optional[torch.device], torch.Tensor]


GlobalOptimizerPreHook: TypeAlias = Callable[
    ["Optimizer", Args, Kwargs], Optional[Tuple[Args, Kwargs]]
]
GlobalOptimizerPostHook: TypeAlias = Callable[["Optimizer", Args, Kwargs], None]

__all__ = [
    "Optimizer",
    "register_optimizer_step_pre_hook",
    "register_optimizer_step_post_hook",
]
_global_optimizer_pre_hooks: Dict[int, GlobalOptimizerPreHook] = OrderedDict()
_global_optimizer_post_hooks: Dict[int, GlobalOptimizerPostHook] = OrderedDict()
_foreach_supported_types = [torch.Tensor, torch.nn.parameter.Parameter]


class _RequiredParameter:
    """Singleton class representing a required parameter for an Optimizer."""

    def __repr__(self) -> str:
        return "<required parameter>"


required = _RequiredParameter()


def _use_grad_for_differentiable(func):
    def _use_grad(self, *args, **kwargs):
        import torch._dynamo

        prev_grad = torch.is_grad_enabled()
        try:
            # Note on graph break below:
            # we need to graph break to ensure that aot respects the no_grad annotation.
            # This is important for perf because without this, functionalization will generate an epilogue
            # which updates the mutated parameters of the optimizer which is *not* visible to inductor, as a result,
            # inductor will allocate for every parameter in the model, which is horrible.
            # With this, aot correctly sees that this is an inference graph, and functionalization will generate
            # an epilogue which is appended to the graph, which *is* visible to inductor, as a result, inductor sees that
            # step is in place and is able to avoid the extra allocation.
            # In the future, we will either 1) continue to graph break on backward, so this graph break does not matter
            # or 2) have a fully fused forward and backward graph, which will have no_grad by default, and we can remove this
            # graph break to allow the fully fused fwd-bwd-optimizer graph to be compiled.
            # see https://github.com/pytorch/pytorch/issues/104053
            torch.set_grad_enabled(self.defaults["differentiable"])
            torch._dynamo.graph_break()
            ret = func(self, *args, **kwargs)
        finally:
            torch._dynamo.graph_break()
            torch.set_grad_enabled(prev_grad)
        return ret

    functools.update_wrapper(_use_grad, func)
    return _use_grad

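# A minimal usage sketch (illustrative): concrete optimizers in this package apply the
# decorator to their ``step`` method, so ``self.defaults["differentiable"]`` decides
# whether autograd records the parameter update:
#
#     class MySGD(Optimizer):
#         @_use_grad_for_differentiable
#         def step(self, closure=None):
#             ...  # the update runs under the chosen grad mode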

def _get_value(x):
    # item is significantly faster than a cpu tensor in eager mode
    if not torch.jit.is_scripting() and is_compiling():
        return x
    else:
        return x.item() if isinstance(x, torch.Tensor) else x


def _stack_if_compiling(x):
    if not torch.jit.is_scripting() and is_compiling():
        return torch.stack(x)
    else:
        return x


def _disable_dynamo_if_unsupported(single_tensor_fn=None):
    # workaround for torchscript BC
    # it requires all called functions to be in the
    # global environment at the site at which the
    # maybe_fallback closure is created
    if single_tensor_fn:
        globals()[single_tensor_fn.__name__] = single_tensor_fn

    def wrapper(func):
        import inspect

        disabled_func = torch._disable_dynamo(func)
        ps = inspect.signature(func).parameters
        has_state_steps = True
        try:
            state_steps_ind = list(ps.keys()).index("state_steps")
        except ValueError:
            has_state_steps = False

        # Today, there are cases where we stack state steps
        # and pass them as the value arg of foreach ops.
        # Having state steps on cuda as the value arg is not supported in eager,
        # but this only occurs in the rare case that the user explicitly deletes
        # the capturable flag. If capturable=True, this is not a problem.
        @functools.wraps(func)
        def maybe_fallback(*args, **kwargs):
            if is_compiling() and (
                not kwargs.get("capturable", False)
                and has_state_steps
                and (args[state_steps_ind] and args[state_steps_ind][0].is_cuda)
                or (
                    "state_steps" in kwargs
                    and kwargs["state_steps"]
                    and kwargs["state_steps"][0].is_cuda
                )
            ):
                return disabled_func(*args, **kwargs)
            else:
                return func(*args, **kwargs)

        return maybe_fallback

    return wrapper

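# A hedged usage sketch with hypothetical names (the functional optimizer entry points
# in this package pass their single-tensor implementation so it stays reachable for the
# TorchScript workaround above):
#
#     def _single_tensor_myopt(params, grads, state_steps, *, capturable=False):
#         ...
#
#     @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_myopt)
#     def myopt(params, grads, state_steps, *, capturable=False):
#         ...  # dispatches between single-tensor and multi-tensor implementations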

# For any optimizer with a faster implementation, we attempt to default to the
# fastest + stablest whenever possible. For foreach, the requirements are to have
# native params all on CUDA. For fused, there's currently the additional requirement
# that the tensors' dtypes must be floating point. Neither alternative supports
# torch.jit.script nor differentiable, so we fall back to the single tensor
# implementation in those cases.
def _default_to_fused_or_foreach(
    params: List[torch.Tensor], differentiable: bool, use_fused: bool = False
) -> Tuple[bool, bool]:
    if torch.jit.is_scripting() or differentiable:
        return False, False

    fused_supported_devices = _get_fused_kernels_supported_devices()
    foreach_supported_devices = _get_foreach_kernels_supported_devices()
    fused = use_fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in fused_supported_devices
            and torch.is_floating_point(p)
        )
        for p in params
    )
    foreach = not fused and all(
        p is None
        or (
            type(p) in _foreach_supported_types
            and p.device.type in foreach_supported_devices
        )
        for p in params
    )
    return fused, foreach

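# A minimal sketch of how an optimizer might consume the helper above (illustrative,
# not copied from any particular optimizer):
#
#     if foreach is None and fused is None:
#         fused, foreach = _default_to_fused_or_foreach(
#             params, differentiable=differentiable, use_fused=False
#         )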

def _device_dtype_check_for_fused(
    p: torch.Tensor, cuda_unsupported: bool = False
) -> None:
    fused_supported_devices = _get_fused_kernels_supported_devices()
    if cuda_unsupported:
        fused_supported_devices.remove("cuda")
    if not (p.device.type in fused_supported_devices and torch.is_floating_point(p)):
        raise RuntimeError(
            "`fused=True` requires all the params to be floating point Tensors of "
            f"supported devices: {fused_supported_devices} but {p.dtype} and {p.device.type}"
        )


def _view_as_real(params, *state_and_grads):
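    # Mutates ``params`` and each list in ``state_and_grads`` in place: complex
    # tensors are replaced with real views (``torch.view_as_real``), so downstream
    # kernels see a trailing real/imaginary dimension instead of complex dtypes.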
    for i, p in enumerate(params):
        if torch.is_complex(p):
            params[i] = torch.view_as_real(params[i])
            for s in state_and_grads:
                s[i] = torch.view_as_real(s[i])


def _get_scalar_dtype(is_fused=None):
    if is_fused:
        return torch.float32
    return (
        torch.float64 if torch.get_default_dtype() == torch.float64 else torch.float32
    )


def _get_capturable_supported_devices(supports_xla: bool = True) -> List[str]:
    r"""Return the device type list that supports capturable optimizer."""
    capturable_supported_devices = ["cuda", "xpu", "hpu"]
    if not torch.jit.is_scripting():
        capturable_supported_devices.append(torch._C._get_privateuse1_backend_name())
    if supports_xla:
        capturable_supported_devices.append("xla")
    return capturable_supported_devices


# Common doc strings among optimizers
_foreach_doc = r"""foreach (bool, optional): whether a foreach implementation of the optimizer
            is used. If unspecified by the user (so foreach is None), we will try to use
            foreach over the for-loop implementation on CUDA, since it is usually
            significantly more performant. Note that the foreach implementation uses
            ~ sizeof(params) more peak memory than the for-loop version due to the intermediates
            being a tensorlist vs just one tensor. If memory is prohibitive, batch fewer
            parameters through the optimizer at a time or switch this flag to False (default: None)"""

_fused_doc = r"""fused (bool, optional): whether the fused implementation is used.
            Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
            are supported. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
              single-tensor implementation, with fused being theoretically fastest with both
              vertical and horizontal fusion. As such, if the user has not specified either
              flag (i.e., when foreach = fused = None), we will attempt defaulting to the foreach
              implementation when the tensors are all on CUDA. Why not fused? Since the fused
              implementation is relatively new, we want to give it sufficient bake-in time.
              To specify fused, pass True for fused. To force running the for-loop
              implementation, pass False for either foreach or fused. """

_capturable_doc = r"""capturable (bool, optional): whether this instance is safe to
            capture in a CUDA graph. Passing True can impair ungraphed performance,
            so if you don't intend to graph capture this instance, leave it False
            (default: False)"""

_differentiable_doc = r"""differentiable (bool, optional): whether autograd should
            occur through the optimizer step in training. Otherwise, the step()
            function runs in a torch.no_grad() context. Setting to True can impair
            performance, so leave it False if you don't intend to run autograd
            through this instance (default: False)"""

_maximize_doc = r"""maximize (bool, optional): maximize the objective with respect to the
            params, instead of minimizing (default: False)"""


def register_optimizer_step_pre_hook(hook: GlobalOptimizerPreHook) -> RemovableHandle:
    r"""Register a pre hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None or modified args and kwargs

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_pre_hooks)
    _global_optimizer_pre_hooks[handle.id] = hook
    return handle


def register_optimizer_step_post_hook(hook: GlobalOptimizerPostHook) -> RemovableHandle:
    r"""Register a post hook common to all optimizers.

    The hook should have the following signature::

        hook(optimizer, args, kwargs) -> None

    Args:
        hook (Callable): A user defined hook which is registered on all optimizers.

    Returns:
        :class:`torch.utils.hooks.RemovableHandle`:
            a handle that can be used to remove the added hook by calling
            ``handle.remove()``
    """
    handle = hooks.RemovableHandle(_global_optimizer_post_hooks)
    _global_optimizer_post_hooks[handle.id] = hook
    return handle

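# A minimal usage sketch for the two global hooks above (illustrative names only):
#
#     def log_lr(optimizer, args, kwargs):
#         print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])
#
#     pre_handle = register_optimizer_step_pre_hook(log_lr)
#     post_handle = register_optimizer_step_post_hook(lambda opt, a, kw: None)
#     ...  # every optimizer's step() now triggers both hooks
#     pre_handle.remove()
#     post_handle.remove()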

ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

_P = ParamSpec("_P")
R = TypeVar("R")
T = TypeVar("T")


class Optimizer:
    r"""Base class for all optimizers.

    .. warning::
        Parameters need to be specified as collections that have a deterministic
        ordering that is consistent between runs. Examples of objects that don't
        satisfy those properties are sets and iterators over values of dictionaries.

    Args:
        params (iterable): an iterable of :class:`torch.Tensor` s or
            :class:`dict` s. Specifies what Tensors should be optimized.
        defaults (dict): a dict containing default values of optimization
            options (used when a parameter group doesn't specify them).
    """
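    # Example (illustrative, mirroring the documented per-parameter options pattern;
    # ``model.base`` / ``model.head`` are placeholder module names):
    #
    #     optimizer = torch.optim.SGD(
    #         [
    #             {"params": model.base.parameters(), "lr": 1e-2},
    #             {"params": model.head.parameters()},
    #         ],
    #         lr=1e-3,
    #         momentum=0.9,
    #     )
    #
    # Options left unspecified in a group fall back to the keyword defaults
    # (here, the second group uses lr=1e-3).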

    OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[Tuple[Args, Kwargs]]]  # type: ignore[misc]
    OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None]  # type: ignore[misc]

    _optimizer_step_pre_hooks: Dict[int, OptimizerPreHook]
    _optimizer_step_post_hooks: Dict[int, OptimizerPostHook]
    _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
    _optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
    _optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'

    def __init__(self, params: ParamsT, defaults: Dict[str, Any]) -> None:  # noqa: D107
        torch._C._log_api_usage_once("python.optimizer")
        self.defaults = defaults
        self._optimizer_step_pre_hooks = OrderedDict()
        self._optimizer_step_post_hooks = OrderedDict()
        self._optimizer_state_dict_pre_hooks = OrderedDict()
        self._optimizer_state_dict_post_hooks = OrderedDict()
        self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        self._optimizer_load_state_dict_post_hooks = OrderedDict()

        self._patch_step_function()

        if isinstance(params, torch.Tensor):
            raise TypeError(
                "params argument given to the optimizer should be "
                "an iterable of Tensors or dicts, but got " + torch.typename(params)
            )

        self.state: DefaultDict[torch.Tensor, Any] = defaultdict(dict)
        self.param_groups: List[Dict[str, Any]] = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{"params": param_groups}]

        for param_group in param_groups:
            self.add_param_group(cast(dict, param_group))

        # Allows _cuda_graph_capture_health_check to rig a poor man's TORCH_WARN_ONCE in python,
        # which I don't think exists
        # https://github.com/pytorch/pytorch/issues/72948
        self._warned_capturable_if_run_uncaptured = True

    def __getstate__(self) -> Dict[str, Any]:  # noqa: D105
        return {
            "defaults": self.defaults,
            "state": self.state,
            "param_groups": self.param_groups,
        }

    def __setstate__(self, state: Dict[str, Any]) -> None:  # noqa: D105
        self.__dict__.update(state)
        if "_optimizer_step_pre_hooks" not in self.__dict__:
            self._optimizer_step_pre_hooks = OrderedDict()
        if "_optimizer_step_post_hooks" not in self.__dict__:
            self._optimizer_step_post_hooks = OrderedDict()
        if "_optimizer_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_state_dict_post_hooks = OrderedDict()
        if "_optimizer_load_state_dict_pre_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_pre_hooks = OrderedDict()
        if "_optimizer_load_state_dict_post_hooks" not in self.__dict__:
            self._optimizer_load_state_dict_post_hooks = OrderedDict()
        self._patch_step_function()  # To support multiprocessing pickle/unpickle
        self.defaults.setdefault("differentiable", False)

    def __repr__(self) -> str:  # noqa: D105
        format_string = self.__class__.__name__ + " ("
        for i, group in enumerate(self.param_groups):
            format_string += "\n"
            format_string += f"Parameter Group {i}\n"
            for key in sorted(group.keys()):
                if key != "params":
                    format_string += f"    {key}: {group[key]}\n"
        format_string += ")"
        return format_string

    # Currently needed by Adam and AdamW
    def _cuda_graph_capture_health_check(self) -> None:
        # Note [torch.compile x capturable]
        # If we are compiling, we try to take the capturable path automatically by
        # setting the flag to True during tracing. Due to this, we skip all the checks
        # normally required for determining whether we can use CUDA graphs and
        # shunt the responsibility to torch.inductor. This saves time during tracing
        # since the checks are slow without sacrificing UX since inductor will warn
        # later if CUDA graphs cannot be enabled, e.g.,
        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
        # Thus, when compiling, inductor will determine if cudagraphs
        # can be enabled based on whether there is input mutation or CPU tensors.
        if (
            not is_compiling()
            and torch.backends.cuda.is_built()
            and torch.cuda.is_available()
        ):
            capturing = torch.cuda.is_current_stream_capturing()

            if capturing and not all(
                group["capturable"] for group in self.param_groups
            ):
                raise RuntimeError(
                    "Attempting CUDA graph capture of step() for an instance of "
                    + self.__class__.__name__
                    + " but param_groups' capturable is False."
                )

            if (
                (not getattr(self, "_warned_capturable_if_run_uncaptured", False))
                and all(group["capturable"] for group in self.param_groups)
                and (not capturing)
            ):
                warnings.warn(
                    "This instance was constructed with capturable=True or some of the param_groups came with capturable=True, "
                    "but step() is running without CUDA graph capture. If you never intend to graph-capture this "
                    "instance, capturable=True can impair performance, and you should set capturable=False."
                )
                self._warned_capturable_if_run_uncaptured = True

    def _optimizer_step_code(self) -> None:
        """Entry point for `torch.profiler`.

        When Python tracing is enabled, the profiler will hook into this
        function at the CPython level to inspect the optimizer's parameters and
        param groups. It is called after `step()` since many optimizers
        lazily initialize state.

        This is a workaround for the lack of a proper step hook on the optimizer,
        and will be removed once one exists.
        """

    @staticmethod
    def profile_hook_step(func: Callable[_P, R]) -> Callable[_P, R]:  # noqa: D102
        @functools.wraps(func)
        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> R:
            self, *_ = args
            self = cast(Optimizer, self)
            profile_name = f"Optimizer.step#{self.__class__.__name__}.step"
            with torch.autograd.profiler.record_function(profile_name):
                # call optimizer step pre hooks
                for pre_hook in chain(
                    _global_optimizer_pre_hooks.values(),
                    self._optimizer_step_pre_hooks.values(),
                ):
                    result = pre_hook(self, args, kwargs)
                    if result is not None:
                        if isinstance(result, tuple) and len(result) == 2:
                            args, kwargs = result  # type: ignore[assignment]
                        else:
                            raise RuntimeError(
                                f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
                            )

                out = func(*args, **kwargs)
                self._optimizer_step_code()

                # call optimizer step post hooks
                for post_hook in chain(
                    self._optimizer_step_post_hooks.values(),
                    _global_optimizer_post_hooks.values(),
                ):
                    post_hook(self, args, kwargs)

                return out

        return wrapper

    @staticmethod
    def _group_tensors_by_device_and_dtype(
        tensorlistlist: TensorListList,
        with_indices: bool = False,
    ) -> Union[
        Dict[Tuple[None, None], Tuple[TensorListList, Indices]],
        Dict[Tuple[torch.device, torch.dtype], Tuple[TensorListList, Indices]],
    ]:
        """Group a list of lists of tensors by device and dtype.

        Skips this step if we are compiling since this will occur during inductor lowering.
        """
        if is_compiling():
            return {(None, None): (tensorlistlist, list(range(len(tensorlistlist[0]))))}
        else:
            return _group_tensors_by_device_and_dtype(tensorlistlist, with_indices)  # type: ignore[return-value, arg-type]

    def _patch_step_function(self) -> None:
        self._zero_grad_profile_name = (
            f"Optimizer.zero_grad#{self.__class__.__name__}.zero_grad"
        )
        hooked = getattr(self.__class__.step, "hooked", None)
        if not hooked:
            self.__class__.step = self.profile_hook_step(self.__class__.step)  # type: ignore[assignment]
            self.__class__.step.hooked = True  # type: ignore[attr-defined]

    def register_step_pre_hook(self, hook: OptimizerPreHook) -> RemovableHandle:
        r"""Register an optimizer step pre hook which will be called before optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None or modified args and kwargs

        The ``optimizer`` argument is the optimizer instance being used. If
        args and kwargs are modified by the pre-hook, then the transformed
        values are returned as a tuple containing the new_args and new_kwargs.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_pre_hooks)
        self._optimizer_step_pre_hooks[handle.id] = hook
        return handle

    def register_step_post_hook(self, hook: OptimizerPostHook) -> RemovableHandle:
        r"""Register an optimizer step post hook which will be called after optimizer step.

        It should have the following signature::

            hook(optimizer, args, kwargs) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        Args:
            hook (Callable): The user defined hook to be registered.

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_step_post_hooks)
        self._optimizer_step_post_hooks[handle.id] = hook
        return handle
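    # A hedged per-instance sketch (hypothetical hook name): unlike the global hooks,
    # the two methods above affect only this optimizer, and a pre-hook may rewrite the
    # step arguments by returning ``(new_args, new_kwargs)``:
    #
    #     def drop_closure(optimizer, args, kwargs):
    #         return args, {k: v for k, v in kwargs.items() if k != "closure"}
    #
    #     handle = opt.register_step_pre_hook(drop_closure)
    #     opt.step()       # drop_closure runs first and rewrites kwargs
    #     handle.remove()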

    def register_state_dict_pre_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D101
        r"""Register a state dict pre-hook which will be called before :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.
        The hook will be called with argument ``self`` before calling ``state_dict`` on ``self``.
        The registered hook can be used to perform pre-processing before the ``state_dict``
        call is made.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_pre_hooks)
        self._optimizer_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

    def register_state_dict_post_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:
        r"""Register a state dict post-hook which will be called after :meth:`~torch.optim.Optimizer.state_dict` is called.

        It should have the following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The hook will be called with arguments ``self`` and ``state_dict`` after generating
        a ``state_dict`` on ``self``. The hook may modify the state_dict inplace or optionally
        return a new one. The registered hook can be used to perform post-processing
        on the ``state_dict`` before it is returned.

        Args:
            hook (Callable): The user defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
        """
        handle = hooks.RemovableHandle(self._optimizer_state_dict_post_hooks)
        self._optimizer_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_state_dict_post_hooks.move_to_end(handle.id, last=False)
        return handle
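    # A hedged sketch of a state-dict post-hook (hypothetical key name) that augments
    # the dict returned by ``state_dict()``:
    #
    #     def tag_state_dict(optimizer, state_dict):
    #         state_dict["_note"] = "saved by my training script"
    #         return state_dict
    #
    #     opt.register_state_dict_post_hook(tag_state_dict)
    #     sd = opt.state_dict()   # sd now carries the extra "_note" entry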
636*da0073e9SAndroid Build Coastguard Worker
    @torch._disable_dynamo
    def state_dict(self) -> StateDict:
        r"""Return the state of the optimizer as a :class:`dict`.

        It contains two entries:

        * ``state``: a Dict holding current optimization state. Its content
            differs between optimizer classes, but some common characteristics
            hold. For example, state is saved per parameter, and the parameter
            itself is NOT saved. ``state`` is a Dict mapping parameter IDs
            to a Dict with state corresponding to each parameter.
        * ``param_groups``: a List containing all parameter groups, where each
            parameter group is a Dict. Each parameter group contains metadata
            specific to the optimizer, such as learning rate and weight decay,
            as well as a List of parameter IDs of the parameters in the group.

        NOTE: The parameter IDs may look like indices, but they are just IDs
        associating state with param_group. When loading from a state_dict,
        the optimizer will zip the param_group ``params`` (int IDs) and the
        optimizer ``param_groups`` (actual ``nn.Parameter`` s) in order to
        match state WITHOUT additional verification.

        A returned state dict might look something like:

        .. code-block:: text

            {
                'state': {
                    0: {'momentum_buffer': tensor(...), ...},
                    1: {'momentum_buffer': tensor(...), ...},
                    2: {'momentum_buffer': tensor(...), ...},
                    3: {'momentum_buffer': tensor(...), ...}
                },
                'param_groups': [
                    {
                        'lr': 0.01,
                        'weight_decay': 0,
                        ...
                        'params': [0]
                    },
                    {
                        'lr': 0.001,
                        'weight_decay': 0.5,
                        ...
                        'params': [1, 2, 3]
                    }
                ]
            }

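        A minimal checkpointing sketch (file name and model are illustrative)::

            import torch

            model = torch.nn.Linear(4, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

            # The returned dict is an ordinary picklable dict, so it can be
            # saved alongside the model's own state dict.
            torch.save(
                {"model": model.state_dict(), "optim": optimizer.state_dict()},
                "checkpoint.pt",
            )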
        """
        for pre_hook in self._optimizer_state_dict_pre_hooks.values():
            pre_hook(self)

        # Save order indices instead of Tensors
        param_mappings: Dict[int, int] = {}
        start_index = 0

        def pack_group(group: Dict[str, Any]) -> Dict[str, Any]:
            nonlocal start_index
            packed = {k: v for k, v in group.items() if k != "params"}
            param_mappings.update(
                {
                    id(p): i
                    for i, p in enumerate(group["params"], start_index)
                    if id(p) not in param_mappings
                }
            )
            packed["params"] = [param_mappings[id(p)] for p in group["params"]]
            start_index += len(packed["params"])
            return packed

        param_groups = [pack_group(g) for g in self.param_groups]
        # Remap state to use order indices as keys
        packed_state = {
            (param_mappings[id(k)] if isinstance(k, torch.Tensor) else k): v
            for k, v in self.state.items()
        }

        state_dict = {
            "state": packed_state,
            "param_groups": param_groups,
        }

        for post_hook in self._optimizer_state_dict_post_hooks.values():
            hook_result = post_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result
        return state_dict

    @staticmethod
    def _process_value_according_to_param_policy(
        param: torch.Tensor,
        value: torch.Tensor,
        param_id: int,
        param_groups: List[Dict[Any, Any]],
        key: Hashable = None,
    ) -> torch.Tensor:
        # Floating-point types are a bit special here. They are the only ones
        # that are assumed to always match the type of params.
        # Make sure state['step'] is not cast; see https://github.com/pytorch/pytorch/issues/74424
        # UNLESS fused or capturable, see note [special device hosting for step]
        fused = False
        capturable = False
        assert param_groups is not None
        for pg in param_groups:
            if param_id in pg["params"]:
                fused = pg["fused"] if "fused" in pg else False
                capturable = pg["capturable"] if "capturable" in pg else False
                break
        if key == "step":
            if capturable or fused:
                return value.to(dtype=torch.float32, device=param.device)
            else:
                return value
        else:
            if param.is_floating_point():
                return value.to(dtype=param.dtype, device=param.device)
            else:
                return value.to(device=param.device)

    def register_load_state_dict_pre_hook(
        self,
        hook: Callable[["Optimizer", StateDict], Optional[StateDict]],
        prepend: bool = False,
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict pre-hook which will be called before
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer, state_dict) -> state_dict or None

        The ``optimizer`` argument is the optimizer instance being used and the
        ``state_dict`` argument is a shallow copy of the ``state_dict`` the user
        passed in to ``load_state_dict``. The hook may modify the state_dict in
        place or optionally return a new one. If a state_dict is returned, it
        will be loaded into the optimizer instead.

        The hook will be called with arguments ``self`` and ``state_dict`` before
        calling ``load_state_dict`` on ``self``. The registered hook can be used to
        perform pre-processing before the ``load_state_dict`` call is made.

        Args:
            hook (Callable): The user-defined hook to be registered.
            prepend (bool): If True, the provided pre ``hook`` will be fired before
                all the already registered pre-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                pre-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
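
        Example (a minimal sketch; the hook below simply overrides the learning
        rate of every loaded parameter group and is purely illustrative)::

            import torch

            model = torch.nn.Linear(2, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

            def lr_override_pre_hook(optimizer, state_dict):
                for group in state_dict["param_groups"]:
                    group["lr"] = 0.01
                return state_dict  # returning a dict makes it the one loaded

            optimizer.register_load_state_dict_pre_hook(lr_override_pre_hook)
            optimizer.load_state_dict(optimizer.state_dict())
            assert optimizer.param_groups[0]["lr"] == 0.01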
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_pre_hooks)
        self._optimizer_load_state_dict_pre_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_pre_hooks.move_to_end(handle.id, last=False)
        return handle

    def register_load_state_dict_post_hook(
        self, hook: Callable[["Optimizer"], None], prepend: bool = False
    ) -> RemovableHandle:  # noqa: D205 D400
        r"""Register a load_state_dict post-hook which will be called after
        :meth:`~torch.optim.Optimizer.load_state_dict` is called. It should have the
        following signature::

            hook(optimizer) -> None

        The ``optimizer`` argument is the optimizer instance being used.

        The hook will be called with argument ``self`` after calling
        ``load_state_dict`` on ``self``. The registered hook can be used to
        perform post-processing after ``load_state_dict`` has loaded the
        ``state_dict``.

        Args:
            hook (Callable): The user-defined hook to be registered.
            prepend (bool): If True, the provided post ``hook`` will be fired before
                all the already registered post-hooks on ``load_state_dict``. Otherwise,
                the provided ``hook`` will be fired after all the already registered
                post-hooks. (default: False)

        Returns:
            :class:`torch.utils.hooks.RemovableHandle`:
                a handle that can be used to remove the added hook by calling
                ``handle.remove()``
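
        Example (a minimal sketch; ``confirm_loaded`` is an illustrative name)::

            import torch

            model = torch.nn.Linear(2, 2)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

            def confirm_loaded(optimizer):
                # The optimizer has already been updated in place at this point.
                print(f"restored {len(optimizer.param_groups)} param group(s)")

            optimizer.register_load_state_dict_post_hook(confirm_loaded)
            optimizer.load_state_dict(optimizer.state_dict())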
        """
        handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
        self._optimizer_load_state_dict_post_hooks[handle.id] = hook
        if prepend:
            self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False)  # type: ignore[attr-defined]
        return handle

    @torch._disable_dynamo
    def load_state_dict(self, state_dict: StateDict) -> None:
        r"""Load the optimizer state.

        Args:
            state_dict (dict): optimizer state. Should be an object returned
                from a call to :meth:`state_dict`.
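
        A minimal save/restore sketch (the checkpoint path is illustrative)::

            import torch

            model = torch.nn.Linear(3, 1)
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            torch.save(optimizer.state_dict(), "optim.pt")

            # Later: rebuild an optimizer over the same parameters (in the
            # same order) and restore its state.
            optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
            optimizer.load_state_dict(torch.load("optim.pt"))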
        """
        # shallow copy, to be consistent with module API
        state_dict = state_dict.copy()

        for pre_hook in self._optimizer_load_state_dict_pre_hooks.values():
            hook_result = pre_hook(self, state_dict)
            if hook_result is not None:
                state_dict = hook_result

        # Validate the state_dict
        groups = self.param_groups

        # Deepcopy as we write into saved_groups later to update state
        saved_groups = deepcopy(state_dict["param_groups"])

        if len(groups) != len(saved_groups):
            raise ValueError(
                "loaded state dict has a different number of parameter groups"
            )
        param_lens = (len(g["params"]) for g in groups)
        saved_lens = (len(g["params"]) for g in saved_groups)
        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
            raise ValueError(
                "loaded state dict contains a parameter group "
                "that doesn't match the size of optimizer's group"
            )

        # Update the state
        id_map = dict(
            zip(
                chain.from_iterable(g["params"] for g in saved_groups),
                chain.from_iterable(g["params"] for g in groups),
            )
        )

        def _cast(param, value, param_id=None, param_groups=None, key=None):
            r"""Make a deep copy of value, casting all tensors to the device of param."""
            if isinstance(value, torch.Tensor):
                return Optimizer._process_value_according_to_param_policy(
                    param, value, param_id, param_groups, key
                )
            elif isinstance(value, dict):
                return {
                    k: _cast(
                        param, v, param_id=param_id, param_groups=param_groups, key=k
                    )
                    for k, v in value.items()
                }
            elif isinstance(value, Iterable):
                return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value)  # type: ignore[call-arg]
            else:
                return value

        # Copy state assigned to params (and cast tensors to appropriate types).
        # State that is not assigned to params is copied as is (needed for
        # backward compatibility).
        state: DefaultDict[torch.Tensor, Dict[Any, Any]] = defaultdict(dict)
        for k, v in state_dict["state"].items():
            if k in id_map:
                param = id_map[k]
                state[param] = _cast(
                    param, v, param_id=k, param_groups=state_dict["param_groups"]
                )
            else:
                state[k] = v

        # Update parameter groups, setting their 'params' value
        def update_group(
            group: Dict[str, Any], new_group: Dict[str, Any]
        ) -> Dict[str, Any]:
            new_group["params"] = group["params"]
            return new_group

        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
        self.__setstate__({"state": state, "param_groups": param_groups})

        for post_hook in self._optimizer_load_state_dict_post_hooks.values():
            post_hook(self)

    @torch._disable_dynamo
    def zero_grad(self, set_to_none: bool = True) -> None:
        r"""Reset the gradients of all optimized :class:`torch.Tensor` s.

        Args:
            set_to_none (bool): instead of setting to zero, set the grads to None.
                This will in general have a lower memory footprint, and can modestly improve performance.
                However, it changes certain behaviors. For example:
                1. When the user tries to access a gradient and perform manual ops on it,
                a None attribute or a Tensor full of 0s will behave differently.
                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                are guaranteed to be None for params that did not receive a gradient.
                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
                (in one case it does the step with a gradient of 0 and in the other it skips
                the step altogether).
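
        A typical training-step sketch (model, data, and loss are illustrative)::

            import torch

            model = torch.nn.Linear(10, 1)
            optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
            inputs, targets = torch.randn(8, 10), torch.randn(8, 1)

            optimizer.zero_grad()  # clear stale gradients (sets them to None)
            loss = torch.nn.functional.mse_loss(model(inputs), targets)
            loss.backward()        # populate fresh .grad values
            optimizer.step()       # apply the update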
        """
        foreach = self.defaults.get("foreach", False) or self.defaults.get(
            "fused", False
        )

        if not hasattr(self, "_zero_grad_profile_name"):
            self._patch_step_function()

        per_device_and_dtype_grads: Optional[
            DefaultDict[torch.device, DefaultDict[torch.dtype, List[torch.Tensor]]]
        ]
        if foreach:
            per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list))
        else:
            per_device_and_dtype_grads = None

        with torch.autograd.profiler.record_function(self._zero_grad_profile_name):
            for group in self.param_groups:
                for p in group["params"]:
                    if p.grad is not None:
                        if set_to_none:
                            p.grad = None
                        else:
                            if p.grad.grad_fn is not None:
                                p.grad.detach_()
                            else:
                                p.grad.requires_grad_(False)
                            if not foreach or p.grad.is_sparse:
                                p.grad.zero_()
                            else:
                                assert per_device_and_dtype_grads is not None
                                per_device_and_dtype_grads[p.grad.device][
                                    p.grad.dtype
                                ].append(p.grad)
            if foreach:
                assert per_device_and_dtype_grads is not None
                for per_dtype_grads in per_device_and_dtype_grads.values():
                    for grads in per_dtype_grads.values():
                        torch._foreach_zero_(grads)

    @overload
    def step(self, closure: None = ...) -> None:
        ...

    @overload
    def step(self, closure: Callable[[], float]) -> float:
        ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        r"""Perform a single optimization step to update parameters.

        Args:
            closure (Callable): A closure that reevaluates the model and
                returns the loss. Optional for most optimizers.

        .. note::
            Unless otherwise specified, this function should not modify the
            ``.grad`` field of the parameters.
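
        A sketch of the closure form, useful for optimizers such as
        :class:`~torch.optim.LBFGS` that may re-evaluate the objective several
        times per step (model and data are illustrative)::

            import torch

            model = torch.nn.Linear(5, 1)
            optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
            inputs, targets = torch.randn(16, 5), torch.randn(16, 1)

            def closure():
                optimizer.zero_grad()
                loss = torch.nn.functional.mse_loss(model(inputs), targets)
                loss.backward()
                return loss

            optimizer.step(closure)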
        """
        raise NotImplementedError

    @torch._disable_dynamo
    def add_param_group(self, param_group: Dict[str, Any]) -> None:
        r"""Add a param group to the :class:`Optimizer`\ 's ``param_groups``.

        This can be useful when fine-tuning a pre-trained network, as frozen layers can be made
        trainable and added to the :class:`Optimizer` as training progresses.

        Args:
            param_group (dict): Specifies what Tensors should be optimized along with group
                specific optimization options.
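
        A fine-tuning sketch (the two-module model is illustrative)::

            import torch

            backbone = torch.nn.Linear(8, 8)
            head = torch.nn.Linear(8, 2)

            # Start by optimizing only the head ...
            optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

            # ... and later unfreeze the backbone with its own learning rate.
            optimizer.add_param_group({"params": backbone.parameters(), "lr": 0.01})
            assert len(optimizer.param_groups) == 2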
        """
        if not isinstance(param_group, dict):
            raise TypeError(f"param_group must be a dict, but got {type(param_group)}")

        params = param_group["params"]
        if isinstance(params, torch.Tensor):
            param_group["params"] = [params]
        elif isinstance(params, set):
            raise TypeError(
                "optimizer parameters need to be organized in ordered collections, but "
                "the ordering of tensors in sets will change between runs. Please use a list instead."
            )
        else:
            param_group["params"] = list(params)

        for param in param_group["params"]:
            if not isinstance(param, torch.Tensor):
                raise TypeError(
                    "optimizer can only optimize Tensors, "
                    "but one of the params is " + torch.typename(param)
                )
            if not self.defaults.get("differentiable", None) and not (
                param.is_leaf or param.retains_grad
            ):
                raise ValueError("can't optimize a non-leaf Tensor")

        for name, default in self.defaults.items():
            if default is required and name not in param_group:
                raise ValueError(
                    f"parameter group didn't specify a value for required optimization parameter {name}"
                )
            else:
                param_group.setdefault(name, default)

        params = param_group["params"]
        if len(params) != len(set(params)):
            warnings.warn(
                "optimizer contains a parameter group with duplicate parameters; "
                "in the future, this will cause an error; "
                "see github.com/pytorch/pytorch/issues/40967 for more information",
                stacklevel=3,
            )

        param_set: Set[torch.Tensor] = set()
        for group in self.param_groups:
            param_set.update(set(group["params"]))

        if not param_set.isdisjoint(set(param_group["params"])):
            raise ValueError("some parameters appear in more than one parameter group")

        self.param_groups.append(param_group)