xref: /aosp_15_r20/external/pytorch/torch/nn/cpp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Worker"""Functionality for Python <-> C++ frontend inter-op."""
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Workerfrom torch import nn
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Worker
class OrderedDictWrapper:
    """Live, read-only view onto an ``OrderedDict``-valued attribute of a bound C++ module.

    Every operation re-reads ``getattr(cpp_module, attr)``, so changes made on the
    C++ side after the wrapper is created are still visible through it. Caching
    e.g. ``cpp_module._parameters`` once would instead freeze a snapshot taken at
    access time. A property cannot be used for this, because ``torch.nn.Module``
    reads ``_parameters`` et al. straight out of ``self.__dict__``.
    """

    def __init__(self, cpp_module, attr):
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """The mapping currently stored on ``cpp_module`` under ``attr``, fetched fresh."""
        return getattr(self.cpp_module, self.attr)

    # Special methods are looked up on the type rather than the instance, so they
    # cannot be forwarded by generic attribute delegation; each one is spelled out.

    def items(self):
        current = self.cpp_dict
        return current.items()

    def keys(self):
        current = self.cpp_dict
        return current.keys()

    def values(self):
        current = self.cpp_dict
        return current.values()

    def __iter__(self):
        return iter(self.cpp_dict)

    def __len__(self):
        return len(self.cpp_dict)

    def __contains__(self, key):
        return key in self.cpp_dict

    def __getitem__(self, key):
        return self.cpp_dict[key]
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access.

    Parameters, buffers and submodules are exposed through ``OrderedDictWrapper``
    views that re-query ``cpp_module`` on every access. Every public
    (non-underscore) attribute of ``cpp_module`` is copied onto the wrapper once,
    at construction time. ``training`` and ``repr()`` are forwarded to
    ``cpp_module``.
    """

    def __init__(self, cpp_module):
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor: the ``training`` property
        # below forwards to ``self.cpp_module``.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods, private names and the three attributes above
            # (all of them start with an underscore).
            if not attr.startswith("_"):
                # The value is read once here; this is not a live delegation.
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to the data of parameters, their gradients and buffers.

        Args:
            fn: callable taking a tensor and returning the replacement tensor
                (e.g. a device or dtype conversion).
            recurse: if ``True`` (the default), also visit the parameters and
                buffers of submodules. If ``False``, only visit those owned
                directly by this module, as ``torch.nn.Module._apply`` does
                (used e.g. by ``to_empty(recurse=False)``).

        Returns:
            ``self``.
        """
        for param in self.parameters(recurse=recurse):
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers(recurse=recurse):
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Training flag, read from the wrapped C++ module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        # Route through the C++ module's ``train`` so the flag lives on the C++ side.
        self.cpp_module.train(mode)

    def __repr__(self):
        return self.cpp_module.__repr__()
90