xref: /aosp_15_r20/external/pytorch/torchgen/context.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport contextlib
4*da0073e9SAndroid Build Coastguard Workerimport functools
5*da0073e9SAndroid Build Coastguard Workerfrom typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union
6*da0073e9SAndroid Build Coastguard Worker
7*da0073e9SAndroid Build Coastguard Workerimport torchgen.local as local
8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
9*da0073e9SAndroid Build Coastguard Worker    BackendIndex,
10*da0073e9SAndroid Build Coastguard Worker    DispatchKey,
11*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
12*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsGroup,
13*da0073e9SAndroid Build Coastguard Worker    NativeFunctionsViewGroup,
14*da0073e9SAndroid Build Coastguard Worker)
15*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import context, S, T
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
# Helper functions for defining generators on things in the model

# Constrained type variable for the "model object" a generator operates on.
# Every constraint is (or is a union of) a type that native_function_manager
# dispatches on via isinstance to pick the NativeFunction used for context.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# Constrained type variable for the second argument of functions decorated
# with with_native_function_and.  In this module that wrapper forwards the
# value untouched and never uses it to select the context.
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)

# Constrained type variable for "nested" arguments: sequences whose element at
# index 0 is the NativeFunction that method_with_nested_native_function uses
# for context.
F3 = TypeVar(
    "F3",
    Tuple[NativeFunction, Any],
    List[NativeFunction],
)
39*da0073e9SAndroid Build Coastguard Worker
40*da0073e9SAndroid Build Coastguard Worker
@contextlib.contextmanager
def native_function_manager(
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
    """Enter the error context and torchgen.local settings for a model object.

    The NativeFunction that anchors the context is chosen as follows:

    * NativeFunctionsGroup: the out variant (``g.out``).  By default, errors
      for structured native functions are associated with the out variant.
      If a more specific place is wanted, call native_function_manager
      again on the inside.
    * NativeFunctionsViewGroup: the view operator (``g.view``).
    * anything else: ``g`` itself, treated as a NativeFunction.

    While the manager is active, a message naming the anchor's
    native_functions.yaml location and schema is registered with
    torchgen.utils.context. torchgen.local is also parametrized from the
    anchor's ``use_const_ref_for_mutable_tensors`` and
    ``part_of_structured_group`` flags.
    """
    if isinstance(g, NativeFunctionsGroup):
        anchor = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        anchor = g.view
    else:
        anchor = g

    def describe() -> str:
        # Passed as a callable so the message is only built when context()
        # asks for it.
        return f"in native_functions.yaml line {anchor.loc}:\n  {anchor.func}"

    # A multi-item ``with`` enters the managers left to right, exactly like
    # the nested form: the error context wraps the local parametrization.
    with context(describe), local.parametrize(
        use_const_ref_for_mutable_tensors=anchor.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=anchor.part_of_structured_group,
    ):
        yield
62*da0073e9SAndroid Build Coastguard Worker
63*da0073e9SAndroid Build Coastguard Worker
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Wrap a one-argument generator so it runs inside native_function_manager.

    Given a function that operates on a NativeFunction (or a group of them),
    return a new function that sets the appropriate context managers for that
    native function before calling it.

    YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be sound (you
    will get an error if we try to access the local variables without having
    set them).

    The returned wrapper copies func's metadata (``__name__``, ``__doc__``,
    ``__wrapped__``, ...) via functools.update_wrapper.
    """

    def wrapper(f: F) -> T:
        # The context stays active for the entire call to func.
        with native_function_manager(f):
            return func(f)

    return functools.update_wrapper(wrapper, func)
76*da0073e9SAndroid Build Coastguard Worker
77*da0073e9SAndroid Build Coastguard Worker
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Wrap a two-argument generator so it runs inside native_function_manager.

    Only the first argument ``f`` selects the context. The second argument
    ``f2`` is forwarded to func unchanged.
    """

    @functools.wraps(func)
    def call_in_context(f: F, f2: F2) -> T:
        # The first native_function is assumed to be the one with the
        # appropriate context.
        with native_function_manager(f):
            return func(f, f2)

    return call_in_context
86*da0073e9SAndroid Build Coastguard Worker
87*da0073e9SAndroid Build Coastguard Worker
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of with_native_function.

    The receiver ``slf`` is forwarded untouched. The context is selected from
    the second argument ``f``.
    """

    def wrapper(slf: S, f: F) -> T:
        with native_function_manager(f):
            return func(slf, f)

    return functools.update_wrapper(wrapper, func)
95*da0073e9SAndroid Build Coastguard Worker
96*da0073e9SAndroid Build Coastguard Worker
def method_with_nested_native_function(
    func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
    """Method wrapper for arguments that carry a NativeFunction at index 0.

    ``f`` is a tuple ``(NativeFunction, Any)`` or a list of NativeFunction.
    Its first element selects the context, while the whole of ``f`` is passed
    on to func. An empty ``f`` raises IndexError from the ``f[0]`` lookup
    before func is called.
    """

    @functools.wraps(func)
    def call_in_context(slf: S, f: F3) -> T:
        anchor = f[0]
        with native_function_manager(anchor):
            return func(slf, f)

    return call_in_context
106*da0073e9SAndroid Build Coastguard Worker
107*da0073e9SAndroid Build Coastguard Worker
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Convenience decorator for functions that explicitly take a BackendIndex.

    This is for generators that receive the BackendIndex as an argument
    instead of indirectly capturing one in a closure. The context is
    selected from ``f``, and ``backend_index`` is forwarded unchanged.
    """

    @functools.wraps(func)
    def call_in_context(f: F, backend_index: BackendIndex) -> T:
        with native_function_manager(f):
            return func(f, backend_index)

    return call_in_context
119*da0073e9SAndroid Build Coastguard Worker
120*da0073e9SAndroid Build Coastguard Worker
def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Convenience decorator for functions that explicitly take a dict of BackendIndices.

    The context is selected from ``f``. The ``backend_indices`` mapping
    (DispatchKey -> BackendIndex) is forwarded to func unchanged.
    """

    def wrapper(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
        with native_function_manager(f):
            return func(f, backend_indices)

    return functools.update_wrapper(wrapper, func)
131