from typing import cast, Dict, Optional

import torch.nn as nn


class _State:
    """Base type for the per-module state objects tracked by this file."""


# Registry associating a module with the ``_State`` registered for it.
_module_state_mapping: Dict[nn.Module, _State] = {}


def _insert_module_state(module: nn.Module, state: _State) -> None:
    """
    Register ``state`` as the ``_State`` associated with ``module``.

    Args:
        module: The module to use as the registry key.
        state: The state object to store for ``module``.

    Raises:
        AssertionError: If ``module`` already has a registered state.
    """
    # Raise explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; the exception type and message are unchanged.
    if module in _module_state_mapping:
        raise AssertionError(f"Inserting {module} more than once.")
    _module_state_mapping[module] = state


def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Return the ``_State`` associated with ``module``.

    Given a ``module``, this API finds out if the module is also a ``_State``
    instance or if the module has a state registered through
    ``_insert_module_state``. If the module is itself a ``_State``, ``module``
    is cast to ``_State`` and returned. If a state was registered for it, that
    ``_State`` is returned. Otherwise ``None`` is returned.
    """
    if isinstance(module, _State):
        return cast(_State, module)

    # https://github.com/pytorch/pytorch/issues/107054
    # NOTE(review): the explicit membership test followed by indexing (rather
    # than ``dict.get``) appears to be deliberate given the issue linked
    # above -- confirm before simplifying.
    if module in _module_state_mapping:
        return _module_state_mapping[module]
    return None