# mypy: allow-untyped-defs
from typing import Type

from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD


# Dict that maps a user-passed-in optimizer class to a functional
# optimizer class, if one is already defined inside the
# distributed.optim package. This hides the functional optimizers
# from the user while still providing the same API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


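# Illustrative lookup: the map is keyed by the standard ``torch.optim``
# class and yields the functional counterpart defined in this package.
#   >>> # xdoctest: +SKIP
#   >>> functional_optim_map[optim.SGD] is _FunctionalSGD
#   True

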
def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    ``key`` and ``optim`` are user defined; neither needs to be a
    :class:`torch.optim.Optimizer` subclass (e.g. for custom optimizers).
    If ``key`` is already registered, the existing entry is kept.

    Example::
        >>> # xdoctest: +SKIP
        >>> # import the new functional optimizer
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    if key not in functional_optim_map:
        functional_optim_map[key] = optim


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """Create a functional optimizer equivalent to ``optim_cls``.

    Looks up ``optim_cls`` in ``functional_optim_map`` and instantiates the
    matching functional optimizer with the given args. Raises ``ValueError``
    if no functional counterpart is registered for ``optim_cls``.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)


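# Example (illustrative): convert a standard optimizer class into its
# functional counterpart; the ``lr`` value here is arbitrary.
#   >>> # xdoctest: +SKIP
#   >>> fn_sgd = as_functional_optim(optim.SGD, lr=0.1)
#   >>> isinstance(fn_sgd, _FunctionalSGD)
#   True

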
def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    # Functional optimizers are constructed with an empty parameter list;
    # parameters are supplied later, one at a time, when steps are taken.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )
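

# End-to-end sketch (illustrative; assumes the functional optimizer exposes
# ``step_param``, as the implementations in this package do): the optimizer
# is built with an empty parameter list and updates parameters one at a time.
#   >>> # xdoctest: +SKIP
#   >>> import torch
#   >>> fn_adam = as_functional_optim(optim.Adam, lr=1e-3)
#   >>> param = torch.randn(4, requires_grad=True)
#   >>> grad = torch.randn(4)
#   >>> fn_adam.step_param(param, grad)  # updates ``param`` in place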