xref: /aosp_15_r20/external/pytorch/torch/distributed/_composable_state.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from typing import cast, Dict, Optional
2
3import torch.nn as nn
4
5
6class _State:
7    pass
8
9
# Registry mapping each managed module to its composable-API state object.
# Entries are added by ``_insert_module_state``; nothing in this file removes
# them. NOTE(review): a plain dict holds strong references to both keys and
# values, so registered modules and states stay alive unless removed
# elsewhere — confirm this lifetime is intended.
_module_state_mapping: Dict[nn.Module, _State] = dict()
11
12
def _insert_module_state(module: nn.Module, state: _State) -> None:
    """Register ``state`` as the composable-API state for ``module``.

    Args:
        module: The module whose state is being registered.
        state: The state object to associate with ``module``.

    Raises:
        AssertionError: If ``module`` already has a registered state.
    """
    # Explicit check instead of a bare ``assert``: asserts are stripped under
    # ``python -O``, which would let a duplicate insert silently overwrite the
    # existing state. AssertionError is kept so callers see the same type.
    # No ``global`` is needed: the dict is mutated, never rebound.
    if module in _module_state_mapping:
        raise AssertionError(f"Inserting {module} more than once.")
    _module_state_mapping[module] = state
17
18
def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Return the ``_State`` associated with ``module``.

    Given a ``module``, this API checks whether the module is itself a
    ``_State`` instance or whether it is managed by a composable API. If the
    module is also a ``_State``, ``module`` is cast to ``_State`` and
    returned. If it is managed by a composable API (i.e. it has an entry in
    the module-level state mapping), the corresponding ``_State`` is
    returned. Otherwise, ``None`` is returned.
    """
    if isinstance(module, _State):
        # ``cast`` is a runtime no-op; it only informs static type checkers.
        return cast(_State, module)
    # https://github.com/pytorch/pytorch/issues/107054
    # NOTE(review): the explicit membership test followed by indexing (rather
    # than ``dict.get``) appears to be deliberate and tied to the issue
    # above — confirm before simplifying.
    if module in _module_state_mapping:
        return _module_state_mapping[module]
    return None
38