xref: /aosp_15_r20/external/pytorch/torch/_VF.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker"""
2*da0073e9SAndroid Build Coastguard WorkerThis makes the functions in torch._C._VariableFunctions available as
3*da0073e9SAndroid Build Coastguard Worker    torch._VF.<funcname>
4*da0073e9SAndroid Build Coastguard Workerwithout mypy being able to find them.
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard WorkerA subset of those functions are mapped to ATen functions in
7*da0073e9SAndroid Build Coastguard Workertorch/jit/_builtins.py
8*da0073e9SAndroid Build Coastguard Worker
See https://github.com/pytorch/pytorch/issues/21478 for why torch._VF was
introduced.
11*da0073e9SAndroid Build Coastguard Worker
12*da0073e9SAndroid Build Coastguard Worker"""
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Workerimport sys
15*da0073e9SAndroid Build Coastguard Workerimport types
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Workerimport torch
18*da0073e9SAndroid Build Coastguard Worker
19*da0073e9SAndroid Build Coastguard Worker
class VFModule(types.ModuleType):
    """Module object that forwards attribute lookups to ``torch._C._VariableFunctions``.

    Any attribute not found on the module itself is looked up on the backing
    ``vf`` object, so ``<this module>.<funcname>`` resolves to
    ``torch._C._VariableFunctions.<funcname>`` at runtime.
    """

    # Backing object that unknown attribute lookups are forwarded to; assigned
    # in __init__.
    # NOTE(review): annotated as ModuleType — confirm this matches the concrete
    # type of torch._C._VariableFunctions.
    vf: types.ModuleType

    def __init__(self, name: str):
        """Create the module *name*, backed by ``torch._C._VariableFunctions``."""
        super().__init__(name)
        self.vf = torch._C._VariableFunctions

    def __getattr__(self, name: str) -> object:
        """Return attribute *name* of the backing ``vf`` object.

        Python only calls this hook after normal module attribute lookup fails.

        Raises:
            AttributeError: if the backing object has no attribute *name*, or
                if ``vf`` itself was never assigned (e.g. ``__init__`` did not run).
        """
        if name == "vf":
            # ``vf`` normally lives in the module __dict__, so reaching this
            # hook for it means it was never assigned. Without this guard,
            # ``self.vf`` below would re-enter __getattr__("vf") forever and
            # raise RecursionError instead of AttributeError, which breaks
            # hasattr() and getattr(..., default). The message deliberately
            # avoids reading other attributes of self (e.g. __name__), which
            # could also be missing and would recurse back into this hook.
            raise AttributeError(
                f"{type(self).__name__} instance has no attribute 'vf'"
            )
        return getattr(self.vf, name)
29*da0073e9SAndroid Build Coastguard Worker
30*da0073e9SAndroid Build Coastguard Worker
# Replace this module's entry in sys.modules with a VFModule instance. Python's
# import system returns the sys.modules entry after the module body runs, so
# importers receive the forwarding object, whose unknown attribute lookups fall
# through to torch._C._VariableFunctions.
sys.modules[__name__] = VFModule(__name__)
32