# mypy: allow-untyped-defs
"""Functionality for Python <-> C++ frontend inter-op."""

from torch import nn


class OrderedDictWrapper:
    """A wrapper around a C++ OrderedDict.

    It dynamically evaluates the OrderedDict getter on a bound C++ module, such
    that new changes on the C++ side are picked up. Otherwise accessing e.g.
    ``cpp_module._parameters`` just once would get a frozen copy of the
    parameters at the time of access. ``torch.nn.Module`` accesses
    ``_parameters`` et al. via ``self.__dict__`` so using properties does not
    work.
    """

    def __init__(self, cpp_module, attr):
        """Remember the module and the attribute name to look up on it.

        Args:
            cpp_module: Object exposing the dict-like attribute named ``attr``.
            attr: Name of the attribute fetched from ``cpp_module`` on every
                access.
        """
        self.cpp_module = cpp_module
        self.attr = attr

    @property
    def cpp_dict(self):
        """Return ``getattr(cpp_module, attr)``, re-evaluated on each access."""
        return getattr(self.cpp_module, self.attr)

    # Magic methods cannot be assigned dynamically and bypass ``getattr``, so we
    # must manually override them.

    def items(self):
        """Return ``items()`` of the freshly fetched underlying dict."""
        return self.cpp_dict.items()

    def keys(self):
        """Return ``keys()`` of the freshly fetched underlying dict."""
        return self.cpp_dict.keys()

    def values(self):
        """Return ``values()`` of the freshly fetched underlying dict."""
        return self.cpp_dict.values()

    def __iter__(self):
        """Iterate over the freshly fetched underlying dict."""
        return iter(self.cpp_dict)

    def __len__(self):
        """Return the length of the freshly fetched underlying dict."""
        return len(self.cpp_dict)

    def __contains__(self, key):
        """Return whether ``key`` is in the freshly fetched underlying dict."""
        return key in self.cpp_dict

    def __getitem__(self, key):
        """Return the entry stored under ``key`` in the underlying dict."""
        return self.cpp_dict[key]


class ModuleWrapper(nn.Module):
    """A subclass of ``torch.nn.Module`` that wraps a C++ frontend module and delegates all access."""

    def __init__(self, cpp_module):
        """Wrap ``cpp_module`` and mirror its public attributes on ``self``.

        Args:
            cpp_module: The module object to delegate to. It must provide
                ``_parameters``, ``_buffers``, ``_modules``, ``training`` and
                ``train(mode)``, all of which are accessed by this class.
        """
        # Assign before the super class constructor so ``self.training`` can be
        # assigned to in the super class constructor.
        self.cpp_module = cpp_module
        super().__init__()
        self._parameters = OrderedDictWrapper(cpp_module, "_parameters")  # type: ignore[assignment]
        self._buffers: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_buffers")  # type: ignore[assignment]
        self._modules: OrderedDictWrapper = OrderedDictWrapper(cpp_module, "_modules")  # type: ignore[assignment]
        for attr in dir(cpp_module):
            # Skip magic methods and the three attributes above.
            if not attr.startswith("_"):
                setattr(self, attr, getattr(self.cpp_module, attr))

    def _apply(self, fn, recurse=True):
        """Apply ``fn`` in place to every parameter, gradient and buffer.

        Args:
            fn: Callable taking a tensor and returning the replacement tensor.
            recurse: Accepted but not used by this override; the loops below
                always go through ``self.parameters()`` and ``self.buffers()``.

        Returns:
            ``self``.
        """
        for param in self.parameters():
            # Tensors stored in modules are graph leaves, and we don't
            # want to create copy nodes, so we have to unpack the data.
            param.data = fn(param.data)
            if param._grad is not None:
                param._grad.data = fn(param._grad.data)

        for buf in self.buffers():
            buf.data = fn(buf.data)

        return self

    # nn.Module defines training as a boolean
    @property  # type: ignore[override]
    def training(self):
        """Return the ``training`` flag of the wrapped module."""
        return self.cpp_module.training

    @training.setter
    def training(self, mode):
        """Forward the new mode to the wrapped module's ``train(mode)``."""
        self.cpp_module.train(mode)

    def __repr__(self):
        """Return the ``repr`` of the wrapped module."""
        return repr(self.cpp_module)