"""Context helpers for torchgen code generators.

The decorators in this module run a function inside ``native_function_manager``,
which installs an error context pointing at the operator's entry in
``native_functions.yaml`` and parametrizes ``torchgen.local`` for that operator.
"""

from __future__ import annotations

import contextlib
import functools
from typing import Any, Callable, Iterator, List, Optional, Tuple, TypeVar, Union

import torchgen.local as local
from torchgen.model import (
    BackendIndex,
    DispatchKey,
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
)
from torchgen.utils import context, S, T


# Helper functions for defining generators on things in the model

# The "thing in the model" that a decorated function receives as its primary
# argument; every option here is accepted by native_function_manager.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# Type of an extra argument that is forwarded untouched (it never supplies
# the context).
F2 = TypeVar(
    "F2",
    NativeFunction,
    NativeFunctionsGroup,
    Optional[NativeFunction],
    bool,
    str,
)

# A container whose first element is the NativeFunction that supplies the
# context.
F3 = TypeVar(
    "F3",
    Tuple[NativeFunction, Any],
    List[NativeFunction],
)


@contextlib.contextmanager
def native_function_manager(
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
    """Set up the per-operator context for the body of a ``with`` statement.

    While the ``with`` body runs, two things are active:

    * a ``torchgen.utils.context`` whose message names the operator's
      ``native_functions.yaml`` location (``loc``) and schema (``func``).
      Errors are associated with this operator.
    * ``torchgen.local.parametrize`` configured from that operator's
      ``use_const_ref_for_mutable_tensors`` and ``part_of_structured_group``
      attributes.

    The operator used depends on the kind of ``g``:

    * ``NativeFunctionsGroup``: its ``out`` variant. This is only a default
      for structured groups. When a more specific location is useful, call
      ``native_function_manager`` again on the inner function.
    * ``NativeFunctionsViewGroup``: its ``view`` operator.
    * ``NativeFunction``: ``g`` itself.
    """
    anchor = (
        g.out
        if isinstance(g, NativeFunctionsGroup)
        else g.view if isinstance(g, NativeFunctionsViewGroup) else g
    )

    def describe_location() -> str:
        # Built lazily, only when the context message is actually needed.
        return f"in native_functions.yaml line {anchor.loc}:\n {anchor.func}"

    # A multi-item `with` is equivalent to nesting: the error context is
    # entered first, then the parametrization arguments are read and applied.
    with context(describe_location), local.parametrize(
        use_const_ref_for_mutable_tensors=anchor.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=anchor.part_of_structured_group,
    ):
        yield


# Given a function that operates on a NativeFunction, wrap it into a new
# function that sets up the appropriate context managers for that native
# function. YOU MUST WRAP FUNCTIONS IN THIS for calls to api modules to be
# sound (you will get an error if the local variables are accessed without
# having been set).
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Make ``func(f)`` run inside ``native_function_manager(f)``.

    The returned wrapper takes the same single argument as ``func``.
    ``functools.wraps`` copies ``func``'s metadata (name, docstring, ...) onto it.
    """

    @functools.wraps(func)
    def wrapped(f: F) -> T:
        return _call_in_native_function_context(f, func, f)

    return wrapped


def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Two-argument form of ``with_native_function``.

    The first argument is assumed to carry the appropriate context. ``f2`` is
    forwarded to ``func`` unchanged.
    """

    @functools.wraps(func)
    def wrapped(f: F, f2: F2) -> T:
        return _call_in_native_function_context(f, func, f, f2)

    return wrapped


def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method form of ``with_native_function``.

    The receiver ``slf`` is forwarded as-is. The context comes from ``f``.
    """

    @functools.wraps(func)
    def wrapped(slf: S, f: F) -> T:
        return _call_in_native_function_context(f, func, slf, f)

    return wrapped


def method_with_nested_native_function(
    func: Callable[[S, F3], T]
) -> Callable[[S, F3], T]:
    """Method decorator for an argument that *contains* a native function.

    ``f`` is a tuple or list (see ``F3``). Its first element supplies the
    context, and the whole ``f`` is still passed to ``func``.
    """

    @functools.wraps(func)
    def wrapped(slf: S, f: F3) -> T:
        return _call_in_native_function_context(f[0], func, slf, f)

    return wrapped


def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T]
) -> Callable[[F, BackendIndex], T]:
    """Decorate a function that takes a ``BackendIndex`` argument.

    This is a convenience for functions that receive the index explicitly,
    rather than capturing it in a closure. ``backend_index`` is forwarded
    unchanged.
    """

    @functools.wraps(func)
    def wrapped(f: F, backend_index: BackendIndex) -> T:
        return _call_in_native_function_context(f, func, f, backend_index)

    return wrapped


def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T]
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Decorate a function that takes a dict of ``BackendIndex`` objects.

    This is a convenience for functions that receive ``backend_indices``
    (keyed by ``DispatchKey``) explicitly. The dict is forwarded unchanged.
    """

    @functools.wraps(func)
    def wrapped(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
        return _call_in_native_function_context(f, func, f, backend_indices)

    return wrapped


def _call_in_native_function_context(
    anchor: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
    func: Callable[..., T],
    *args: Any,
) -> T:
    """Call ``func(*args)`` while ``native_function_manager(anchor)`` is active."""
    with native_function_manager(anchor):
        return func(*args)