# mypy: allow-untyped-defs
from typing import Type

from torch import optim

from .functional_adadelta import _FunctionalAdadelta
from .functional_adagrad import _FunctionalAdagrad
from .functional_adam import _FunctionalAdam
from .functional_adamax import _FunctionalAdamax
from .functional_adamw import _FunctionalAdamW
from .functional_rmsprop import _FunctionalRMSprop
from .functional_rprop import _FunctionalRprop
from .functional_sgd import _FunctionalSGD


# Maps a user-supplied optimizer class to the functional optimizer class
# defined inside the distributed.optim package (when one exists). This keeps
# the functional optimizers hidden from users while still providing the same
# API.
functional_optim_map = {
    optim.Adagrad: _FunctionalAdagrad,
    optim.Adam: _FunctionalAdam,
    optim.AdamW: _FunctionalAdamW,
    optim.SGD: _FunctionalSGD,
    optim.Adadelta: _FunctionalAdadelta,
    optim.RMSprop: _FunctionalRMSprop,
    optim.Rprop: _FunctionalRprop,
    optim.Adamax: _FunctionalAdamax,
}


def register_functional_optim(key, optim):
    """
    Interface to insert a new functional optimizer into ``functional_optim_map``.

    Both ``key`` and ``optim`` are user defined. Neither needs to be a
    :class:`torch.optim.Optimizer` (e.g. for custom optimizers).

    If ``key`` is already present in ``functional_optim_map`` the existing
    entry is kept and this call does nothing.

    Args:
        key: hashable object used to look the functional optimizer up later
            (the value passed as ``optim_cls`` to ``as_functional_optim``).
        optim: the functional optimizer to associate with ``key``.

    Example::
        >>> # import the new functional optimizer
        >>> # xdoctest: +SKIP
        >>> from xyz import fn_optimizer
        >>> from torch.distributed.optim.utils import register_functional_optim
        >>> fn_optim_key = "XYZ_optim"
        >>> register_functional_optim(fn_optim_key, fn_optimizer)
    """
    # NOTE: the ``optim`` parameter shadows the module-level ``torch.optim``
    # import inside this function. The name is part of the public signature,
    # so it is kept as is.
    # ``setdefault`` only inserts when ``key`` is absent, so an existing
    # registration is never overwritten.
    functional_optim_map.setdefault(key, optim)


def as_functional_optim(optim_cls: Type, *args, **kwargs):
    """
    Instantiate the functional optimizer registered for ``optim_cls``.

    Args:
        optim_cls: key to look up in ``functional_optim_map``.
        *args: positional arguments forwarded to the functional optimizer
            constructor (after the empty parameter list).
        **kwargs: keyword arguments forwarded to the functional optimizer
            constructor.

    Returns:
        The object produced by calling the registered functional optimizer.

    Raises:
        ValueError: if ``optim_cls`` has no entry in ``functional_optim_map``.
            The original ``KeyError`` is attached as the cause.
    """
    try:
        functional_cls = functional_optim_map[optim_cls]
    except KeyError as e:
        raise ValueError(
            f"Optimizer {optim_cls} does not have a functional counterpart!"
        ) from e

    return _create_functional_optim(functional_cls, *args, **kwargs)


def _create_functional_optim(functional_optim_cls: Type, *args, **kwargs):
    """
    Call ``functional_optim_cls`` with an empty parameter list.

    The first positional argument is always ``[]`` and
    ``_allow_empty_param_list=True`` is always passed; ``*args`` and
    ``**kwargs`` are forwarded unchanged in between.
    """
    # NOTE(review): presumably the real parameters are supplied to the
    # functional optimizer later by the caller -- confirm against usage.
    return functional_optim_cls(
        [],
        *args,
        **kwargs,
        _allow_empty_param_list=True,
    )