1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp 4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import Binding, CppSignatureGroup, CType 5*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 6*da0073e9SAndroid Build Coastguard Worker Argument, 7*da0073e9SAndroid Build Coastguard Worker BaseTy, 8*da0073e9SAndroid Build Coastguard Worker BaseType, 9*da0073e9SAndroid Build Coastguard Worker ListType, 10*da0073e9SAndroid Build Coastguard Worker NativeFunction, 11*da0073e9SAndroid Build Coastguard Worker OptionalType, 12*da0073e9SAndroid Build Coastguard Worker Type, 13*da0073e9SAndroid Build Coastguard Worker) 14*da0073e9SAndroid Build Coastguard Worker 15*da0073e9SAndroid Build Coastguard Worker 16*da0073e9SAndroid Build Coastguard Worker# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the 17*da0073e9SAndroid Build Coastguard Worker# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is 18*da0073e9SAndroid Build Coastguard Worker# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the 19*da0073e9SAndroid Build Coastguard Worker# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register 20*da0073e9SAndroid Build Coastguard Worker# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase. 21*da0073e9SAndroid Build Coastguard Worker# 22*da0073e9SAndroid Build Coastguard Worker# Here's an example on how the codegen works: 23*da0073e9SAndroid Build Coastguard Worker# 24*da0073e9SAndroid Build Coastguard Worker# - Function Schema (source of truth) 25*da0073e9SAndroid Build Coastguard Worker# 26*da0073e9SAndroid Build Coastguard Worker# aten::empty.names(int[] size, *, Dimname[]? names, 27*da0073e9SAndroid Build Coastguard Worker# ScalarType? dtype=None, Layout? layout=None, 28*da0073e9SAndroid Build Coastguard Worker# Device? device=None, bool? pin_memory=None, 29*da0073e9SAndroid Build Coastguard Worker# MemoryFormat? memory_format=None) -> Tensor 30*da0073e9SAndroid Build Coastguard Worker# - Argument Conversion 31*da0073e9SAndroid Build Coastguard Worker# Generates C++ code to convert an ivalue (from stack) to its underlying C++ type. 32*da0073e9SAndroid Build Coastguard Worker# - int[] size 33*da0073e9SAndroid Build Coastguard Worker# ```cpp 34*da0073e9SAndroid Build Coastguard Worker# const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList(); 35*da0073e9SAndroid Build Coastguard Worker# 36*da0073e9SAndroid Build Coastguard Worker# std::vector<int64_t> size_vec; 37*da0073e9SAndroid Build Coastguard Worker# for (c10::IValue size_elem: size_list_in) { 38*da0073e9SAndroid Build Coastguard Worker# int64_t size_base = size_elem.to<int64_t>(); 39*da0073e9SAndroid Build Coastguard Worker# size_vec.push_back(size_base); 40*da0073e9SAndroid Build Coastguard Worker# } 41*da0073e9SAndroid Build Coastguard Worker# at::ArrayRef<int64_t> size_list_out(size_vec); 42*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack. 43*da0073e9SAndroid Build Coastguard Worker# Will be passed to unboxed kernel. 44*da0073e9SAndroid Build Coastguard Worker# ``` 45*da0073e9SAndroid Build Coastguard Worker# - Dimname[]? names 46*da0073e9SAndroid Build Coastguard Worker# ```cpp 47*da0073e9SAndroid Build Coastguard Worker# ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>(); 48*da0073e9SAndroid Build Coastguard Worker# ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out; 49*da0073e9SAndroid Build Coastguard Worker# if (names_opt.has_value()) { 50*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~ <-- Unwrapping optional shell 51*da0073e9SAndroid Build Coastguard Worker# const c10::IValue names_opt_in = names_opt.value(); 52*da0073e9SAndroid Build Coastguard Worker# const c10::List<c10::IValue> names_list_in = names_opt_in.toList(); 53*da0073e9SAndroid Build Coastguard Worker# 54*da0073e9SAndroid Build Coastguard Worker# std::vector<at::Dimname> names_vec; 55*da0073e9SAndroid Build Coastguard Worker# for (c10::IValue names_elem: names_list_in) { 56*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one. 57*da0073e9SAndroid Build Coastguard Worker# at::Dimname names_base = names_elem.to<at::Dimname>(); 58*da0073e9SAndroid Build Coastguard Worker# names_vec.push_back(names_base); 59*da0073e9SAndroid Build Coastguard Worker# } 60*da0073e9SAndroid Build Coastguard Worker# at::ArrayRef<at::Dimname> names_list_out(names_vec); 61*da0073e9SAndroid Build Coastguard Worker# 62*da0073e9SAndroid Build Coastguard Worker# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out); 63*da0073e9SAndroid Build Coastguard Worker# } else { 64*da0073e9SAndroid Build Coastguard Worker# names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(); 65*da0073e9SAndroid Build Coastguard Worker# } 66*da0073e9SAndroid Build Coastguard Worker# ``` 67*da0073e9SAndroid Build Coastguard Worker# - ScalarType? dtype (similarly for the rest of the arguments) 68*da0073e9SAndroid Build Coastguard Worker# ```cpp 69*da0073e9SAndroid Build Coastguard Worker# ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>(); 70*da0073e9SAndroid Build Coastguard Worker# ::std::optional<at::ScalarType> dtype_opt_out; 71*da0073e9SAndroid Build Coastguard Worker# if (dtype_opt.has_value()) { 72*da0073e9SAndroid Build Coastguard Worker# const c10::IValue dtype_opt_in = dtype_opt.value(); 73*da0073e9SAndroid Build Coastguard Worker# at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>(); 74*da0073e9SAndroid Build Coastguard Worker# ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it 75*da0073e9SAndroid Build Coastguard Worker# directly using ".to<T>()" API. 76*da0073e9SAndroid Build Coastguard Worker# dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base); 77*da0073e9SAndroid Build Coastguard Worker# } else { 78*da0073e9SAndroid Build Coastguard Worker# dtype_opt_out = ::std::optional<at::ScalarType>(); 79*da0073e9SAndroid Build Coastguard Worker# } 80*da0073e9SAndroid Build Coastguard Worker# ``` 81*da0073e9SAndroid Build Coastguard Worker# 82*da0073e9SAndroid Build Coastguard Worker# - Unboxed Kernel Call 83*da0073e9SAndroid Build Coastguard Worker# ```cpp 84*da0073e9SAndroid Build Coastguard Worker# auto result_ = torch::empty( 85*da0073e9SAndroid Build Coastguard Worker# size_list_out, 86*da0073e9SAndroid Build Coastguard Worker# names_opt_out, 87*da0073e9SAndroid Build Coastguard Worker# options, 88*da0073e9SAndroid Build Coastguard Worker# memory_format_opt_out 89*da0073e9SAndroid Build Coastguard Worker# ); 90*da0073e9SAndroid Build Coastguard Worker# ``` 91*da0073e9SAndroid Build Coastguard Worker# 92*da0073e9SAndroid Build Coastguard Worker# - Push Result Back to Stack 93*da0073e9SAndroid Build Coastguard Worker# ```cpp 94*da0073e9SAndroid Build Coastguard Worker# drop(stack, 7); 95*da0073e9SAndroid Build Coastguard Worker# pack(stack, std::move(result_)); 96*da0073e9SAndroid Build Coastguard Worker# ``` 97*da0073e9SAndroid Build Coastguard Workerconnector = "\n\t" 98*da0073e9SAndroid Build Coastguard Worker 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker# Return unboxing function name for a NativeFunction 101*da0073e9SAndroid Build Coastguard Workerdef name(f: NativeFunction) -> str: 102*da0073e9SAndroid Build Coastguard Worker return f.func.name.unambiguous_name() 103*da0073e9SAndroid Build Coastguard Worker 104*da0073e9SAndroid Build Coastguard Worker 105*da0073e9SAndroid Build Coastguard Worker# Convert all the arguments in a NativeFunction to C++ code 106*da0073e9SAndroid Build Coastguard Workerdef convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]: 107*da0073e9SAndroid Build Coastguard Worker # we need the 'self' argument so method needs to be False 108*da0073e9SAndroid Build Coastguard Worker args = ( 109*da0073e9SAndroid Build Coastguard Worker CppSignatureGroup.from_native_function(f, method=False) 110*da0073e9SAndroid Build Coastguard Worker .most_faithful_signature() 111*da0073e9SAndroid Build Coastguard Worker .arguments() 112*da0073e9SAndroid Build Coastguard Worker ) 113*da0073e9SAndroid Build Coastguard Worker code_list = [ 114*da0073e9SAndroid Build Coastguard Worker f"c10::IValue {args[i].name} = std::move(peek(stack, {i}, {len(args)}));" 115*da0073e9SAndroid Build Coastguard Worker for i in range(len(args)) 116*da0073e9SAndroid Build Coastguard Worker ] + [""] 117*da0073e9SAndroid Build Coastguard Worker binding_list = [] 118*da0073e9SAndroid Build Coastguard Worker for arg in args: 119*da0073e9SAndroid Build Coastguard Worker # expecting only Argument 120*da0073e9SAndroid Build Coastguard Worker if not isinstance(arg.argument, Argument): 121*da0073e9SAndroid Build Coastguard Worker raise Exception( # noqa: TRY002 122*da0073e9SAndroid Build Coastguard Worker f"Unexpected argument type, expecting `Argument` but got {arg}" 123*da0073e9SAndroid Build Coastguard Worker ) 124*da0073e9SAndroid Build Coastguard Worker argument: Argument = arg.argument 125*da0073e9SAndroid Build Coastguard Worker unboxed_name, _, code, decl = argumenttype_ivalue_convert( 126*da0073e9SAndroid Build Coastguard Worker argument.type, 127*da0073e9SAndroid Build Coastguard Worker argument.name, 128*da0073e9SAndroid Build Coastguard Worker mutable=argument.is_write, 129*da0073e9SAndroid Build Coastguard Worker ) 130*da0073e9SAndroid Build Coastguard Worker code_list.extend(decl) 131*da0073e9SAndroid Build Coastguard Worker code_list.extend(code) 132*da0073e9SAndroid Build Coastguard Worker binding_list.append(arg.with_name(unboxed_name)) 133*da0073e9SAndroid Build Coastguard Worker return binding_list, code_list 134*da0073e9SAndroid Build Coastguard Worker 135*da0073e9SAndroid Build Coastguard Worker 136*da0073e9SAndroid Build Coastguard Worker# Takes in the type, name and mutability corresponding to an argument, and generates a tuple of: 137*da0073e9SAndroid Build Coastguard Worker# (1) the C++ code necessary to unbox the argument 138*da0073e9SAndroid Build Coastguard Worker# (2) A Binding corresponding to the newly created unboxed variable, including variable name and its CType 139*da0073e9SAndroid Build Coastguard Workerdef argumenttype_ivalue_convert( 140*da0073e9SAndroid Build Coastguard Worker t: Type, arg_name: str, *, mutable: bool = False 141*da0073e9SAndroid Build Coastguard Worker) -> tuple[str, CType, list[str], list[str]]: 142*da0073e9SAndroid Build Coastguard Worker # Unboxing is for mobile, which doesn't care about SymInts 143*da0073e9SAndroid Build Coastguard Worker ctype = cpp.argumenttype_type( 144*da0073e9SAndroid Build Coastguard Worker t=t, mutable=mutable, binds=arg_name, symint=False 145*da0073e9SAndroid Build Coastguard Worker ).type 146*da0073e9SAndroid Build Coastguard Worker 147*da0073e9SAndroid Build Coastguard Worker if isinstance(t, BaseType): 148*da0073e9SAndroid Build Coastguard Worker out_name = f"{arg_name}_base" 149*da0073e9SAndroid Build Coastguard Worker code, decl = _gen_code_base_type( 150*da0073e9SAndroid Build Coastguard Worker arg_name=arg_name, out_name=out_name, ctype=ctype 151*da0073e9SAndroid Build Coastguard Worker ) 152*da0073e9SAndroid Build Coastguard Worker elif isinstance(t, OptionalType): 153*da0073e9SAndroid Build Coastguard Worker out_name = f"{arg_name}_opt_out" 154*da0073e9SAndroid Build Coastguard Worker code, decl = _gen_code_optional_type( 155*da0073e9SAndroid Build Coastguard Worker arg_name=arg_name, 156*da0073e9SAndroid Build Coastguard Worker out_name=out_name, 157*da0073e9SAndroid Build Coastguard Worker t=t, 158*da0073e9SAndroid Build Coastguard Worker ctype=ctype, 159*da0073e9SAndroid Build Coastguard Worker ) 160*da0073e9SAndroid Build Coastguard Worker elif isinstance(t, ListType): 161*da0073e9SAndroid Build Coastguard Worker out_name = f"{arg_name}_list_out" 162*da0073e9SAndroid Build Coastguard Worker code, decl = _gen_code_list_type( 163*da0073e9SAndroid Build Coastguard Worker arg_name=arg_name, 164*da0073e9SAndroid Build Coastguard Worker out_name=out_name, 165*da0073e9SAndroid Build Coastguard Worker t=t, 166*da0073e9SAndroid Build Coastguard Worker ctype=ctype, 167*da0073e9SAndroid Build Coastguard Worker ) 168*da0073e9SAndroid Build Coastguard Worker else: 169*da0073e9SAndroid Build Coastguard Worker raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}") # noqa: TRY002 170*da0073e9SAndroid Build Coastguard Worker return out_name, ctype, code, decl 171*da0073e9SAndroid Build Coastguard Worker 172*da0073e9SAndroid Build Coastguard Worker 173*da0073e9SAndroid Build Coastguard Workerdef _gen_code_base_type( 174*da0073e9SAndroid Build Coastguard Worker arg_name: str, out_name: str, ctype: CType 175*da0073e9SAndroid Build Coastguard Worker) -> tuple[list[str], list[str]]: 176*da0073e9SAndroid Build Coastguard Worker return [ 177*da0073e9SAndroid Build Coastguard Worker f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();" 178*da0073e9SAndroid Build Coastguard Worker ], [] 179*da0073e9SAndroid Build Coastguard Worker 180*da0073e9SAndroid Build Coastguard Worker 181*da0073e9SAndroid Build Coastguard Workerdef _gen_code_optional_type( 182*da0073e9SAndroid Build Coastguard Worker arg_name: str, out_name: str, t: OptionalType, ctype: CType 183*da0073e9SAndroid Build Coastguard Worker) -> tuple[list[str], list[str]]: 184*da0073e9SAndroid Build Coastguard Worker in_name = f"{arg_name}_opt_in" 185*da0073e9SAndroid Build Coastguard Worker res_name, _, res_code, decl = argumenttype_ivalue_convert(t.elem, in_name) 186*da0073e9SAndroid Build Coastguard Worker return ( 187*da0073e9SAndroid Build Coastguard Worker f""" 188*da0073e9SAndroid Build Coastguard Workerauto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>(); 189*da0073e9SAndroid Build Coastguard Worker{ctype.cpp_type(strip_ref=True)} {out_name}; 190*da0073e9SAndroid Build Coastguard Workerif ({arg_name}_opt.has_value()) {{ 191*da0073e9SAndroid Build Coastguard Worker const c10::IValue {in_name} = {arg_name}_opt.value(); 192*da0073e9SAndroid Build Coastguard Worker {connector.join(res_code)} 193*da0073e9SAndroid Build Coastguard Worker {out_name} = {ctype.cpp_type(strip_ref=True)}({res_name}); 194*da0073e9SAndroid Build Coastguard Worker}} else {{ 195*da0073e9SAndroid Build Coastguard Worker {out_name} = {ctype.cpp_type(strip_ref=True)}(); 196*da0073e9SAndroid Build Coastguard Worker}} 197*da0073e9SAndroid Build Coastguard Worker """.split( 198*da0073e9SAndroid Build Coastguard Worker "\n" 199*da0073e9SAndroid Build Coastguard Worker ), 200*da0073e9SAndroid Build Coastguard Worker decl, 201*da0073e9SAndroid Build Coastguard Worker ) 202*da0073e9SAndroid Build Coastguard Worker 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Workerdef _gen_code_list_type( 205*da0073e9SAndroid Build Coastguard Worker arg_name: str, out_name: str, t: ListType, ctype: CType 206*da0073e9SAndroid Build Coastguard Worker) -> tuple[list[str], list[str]]: 207*da0073e9SAndroid Build Coastguard Worker in_name = f"{arg_name}_list_in" 208*da0073e9SAndroid Build Coastguard Worker elem_name = f"{arg_name}_elem" 209*da0073e9SAndroid Build Coastguard Worker code = [f"const c10::List<c10::IValue> {in_name} = {arg_name}.toList();"] 210*da0073e9SAndroid Build Coastguard Worker res_name, res_ctype, res_code, decl = argumenttype_ivalue_convert(t.elem, elem_name) 211*da0073e9SAndroid Build Coastguard Worker # handle list type with size, e.g., bool[4] 212*da0073e9SAndroid Build Coastguard Worker if isinstance(t.elem, BaseType) and t.elem.name == BaseTy.bool and t.size: 213*da0073e9SAndroid Build Coastguard Worker code.extend( 214*da0073e9SAndroid Build Coastguard Worker f""" 215*da0073e9SAndroid Build Coastguard Worker{ctype.cpp_type(strip_ref=True)} {out_name} = as_array<{res_ctype.cpp_type(strip_ref=True)}, {t.size}>({in_name}); 216*da0073e9SAndroid Build Coastguard Worker """.split( 217*da0073e9SAndroid Build Coastguard Worker "\n" 218*da0073e9SAndroid Build Coastguard Worker ) 219*da0073e9SAndroid Build Coastguard Worker ) 220*da0073e9SAndroid Build Coastguard Worker # we have to use c10::List for optional element. e.g., Tensor?[] -> c10::List<::std::optional<at::Tensor>> 221*da0073e9SAndroid Build Coastguard Worker elif isinstance(t.elem, OptionalType): 222*da0073e9SAndroid Build Coastguard Worker code.extend( 223*da0073e9SAndroid Build Coastguard Worker f""" 224*da0073e9SAndroid Build Coastguard Worker{ctype.cpp_type(strip_ref=True)} {out_name}; 225*da0073e9SAndroid Build Coastguard Workerfor (c10::IValue {elem_name}: {in_name}) {{ 226*da0073e9SAndroid Build Coastguard Worker {connector.join(res_code)} 227*da0073e9SAndroid Build Coastguard Worker {out_name}.push_back({res_name}); 228*da0073e9SAndroid Build Coastguard Worker}} 229*da0073e9SAndroid Build Coastguard Worker """.split( 230*da0073e9SAndroid Build Coastguard Worker "\n" 231*da0073e9SAndroid Build Coastguard Worker ) 232*da0073e9SAndroid Build Coastguard Worker ) 233*da0073e9SAndroid Build Coastguard Worker else: 234*da0073e9SAndroid Build Coastguard Worker # use ArrayRef as default. 235*da0073e9SAndroid Build Coastguard Worker vec_name = arg_name + "_vec" 236*da0073e9SAndroid Build Coastguard Worker # need to bring vector instantiation out of scope so that ArrayRef has valid data 237*da0073e9SAndroid Build Coastguard Worker decl.append(f"std::vector<{res_ctype.cpp_type(strip_ref=True)}> {vec_name};") 238*da0073e9SAndroid Build Coastguard Worker code.extend( 239*da0073e9SAndroid Build Coastguard Worker f""" 240*da0073e9SAndroid Build Coastguard Workerfor (c10::IValue {elem_name}: {in_name}) {{ 241*da0073e9SAndroid Build Coastguard Worker {connector.join(res_code)} 242*da0073e9SAndroid Build Coastguard Worker {vec_name}.push_back({res_name}); 243*da0073e9SAndroid Build Coastguard Worker}} 244*da0073e9SAndroid Build Coastguard Worker{ctype.cpp_type(strip_ref=True)} {out_name}({vec_name}); 245*da0073e9SAndroid Build Coastguard Worker """.split( 246*da0073e9SAndroid Build Coastguard Worker "\n" 247*da0073e9SAndroid Build Coastguard Worker ) 248*da0073e9SAndroid Build Coastguard Worker ) 249*da0073e9SAndroid Build Coastguard Worker return code, decl 250