1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence 4*da0073e9SAndroid Build Coastguard Worker 5*da0073e9SAndroid Build Coastguard Workerfrom torchgen import local 6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ( 7*da0073e9SAndroid Build Coastguard Worker ArgName, 8*da0073e9SAndroid Build Coastguard Worker ArrayCType, 9*da0073e9SAndroid Build Coastguard Worker ArrayRefCType, 10*da0073e9SAndroid Build Coastguard Worker BaseCType, 11*da0073e9SAndroid Build Coastguard Worker BaseTypeToCppMapping, 12*da0073e9SAndroid Build Coastguard Worker Binding, 13*da0073e9SAndroid Build Coastguard Worker boolT, 14*da0073e9SAndroid Build Coastguard Worker ConstRefCType, 15*da0073e9SAndroid Build Coastguard Worker CType, 16*da0073e9SAndroid Build Coastguard Worker dimnameListT, 17*da0073e9SAndroid Build Coastguard Worker intArrayRefT, 18*da0073e9SAndroid Build Coastguard Worker iTensorListRefT, 19*da0073e9SAndroid Build Coastguard Worker ListCType, 20*da0073e9SAndroid Build Coastguard Worker longT, 21*da0073e9SAndroid Build Coastguard Worker MutRefCType, 22*da0073e9SAndroid Build Coastguard Worker NamedCType, 23*da0073e9SAndroid Build Coastguard Worker OptionalCType, 24*da0073e9SAndroid Build Coastguard Worker optionalIntArrayRefT, 25*da0073e9SAndroid Build Coastguard Worker optionalSymIntArrayRefT, 26*da0073e9SAndroid Build Coastguard Worker scalarT, 27*da0073e9SAndroid Build Coastguard Worker SpecialArgName, 28*da0073e9SAndroid Build Coastguard Worker symIntArrayRefT, 29*da0073e9SAndroid Build Coastguard Worker SymIntT, 30*da0073e9SAndroid Build Coastguard Worker tensorListT, 31*da0073e9SAndroid Build Coastguard Worker tensorOptionsT, 32*da0073e9SAndroid Build Coastguard Worker tensorT, 33*da0073e9SAndroid Build Coastguard Worker TupleCType, 34*da0073e9SAndroid Build Coastguard Worker VectorCType, 35*da0073e9SAndroid 
Build Coastguard Worker voidT, 36*da0073e9SAndroid Build Coastguard Worker) 37*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import ( 38*da0073e9SAndroid Build Coastguard Worker Argument, 39*da0073e9SAndroid Build Coastguard Worker Arguments, 40*da0073e9SAndroid Build Coastguard Worker BaseTy, 41*da0073e9SAndroid Build Coastguard Worker BaseType, 42*da0073e9SAndroid Build Coastguard Worker FunctionSchema, 43*da0073e9SAndroid Build Coastguard Worker ListType, 44*da0073e9SAndroid Build Coastguard Worker NativeFunction, 45*da0073e9SAndroid Build Coastguard Worker OptionalType, 46*da0073e9SAndroid Build Coastguard Worker Return, 47*da0073e9SAndroid Build Coastguard Worker SelfArgument, 48*da0073e9SAndroid Build Coastguard Worker TensorOptionsArguments, 49*da0073e9SAndroid Build Coastguard Worker Type, 50*da0073e9SAndroid Build Coastguard Worker) 51*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker 54*da0073e9SAndroid Build Coastguard Worker# This file describes the translation of JIT schema to the public C++ 55*da0073e9SAndroid Build Coastguard Worker# API, which is what people use when they call functions like at::add. 
#
# Prominent characteristics of the C++ API:
#
# - dtype, layout, device and pin_memory are collected into
#   a single C++ type TensorOptions (the native functions API
#   also has this, but tensor options is really most relevant
#   for the C++ API; it makes calling kwarg factory functions
#   pleasant)
#
# - defaulting lives here (in fact, the dispatcher is completely
#   oblivious to defaults!)
#
# BTW: policy on name collisions: we try not to have types with
# collisions, but functions are fair game to collide


def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Return the C++ API function name for the schema ``func``.

    The schema's base name is suffixed with ``_symint`` when
    ``symint_overload`` is set.  Out= overloads additionally receive
    ``_outf`` (when ``faithful_name_for_out_overloads`` is set) or ``_out``.
    """
    base = str(func.name.name)
    suffix = "_symint" if symint_overload else ""
    if func.is_out_fn():
        suffix += "_outf" if faithful_name_for_out_overloads else "_out"
    return base + suffix


# Translation of "value types" in JIT schema to C++ API type. Value
# types look the same no matter if they are argument types or return
# types. Returns None if the type in question is not a value type.
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    mutable: bool = True,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType | None:
    """Translate a JIT value type into a ``NamedCType`` bound to ``binds``.

    Args:
        t: the JIT schema type.
        binds: the name the resulting ``NamedCType`` binds.
        mutable: not consulted here; only forwarded when recursing into an
            optional's element type.
        remove_non_owning_ref_types: when set, a ``str`` type raises instead
            of mapping to its (non-owning) C++ counterpart.
        symint: map ``SymInt`` to ``SymIntT`` instead of ``longT``.

    Returns:
        The translated type, or ``None`` when ``t`` is not a value type:
        Tensor, Scalar, non-``bool`` lists, and optionals of any of those.

    Raises:
        AssertionError: for ``str`` with ``remove_non_owning_ref_types``, for a
            ``bool`` list without a fixed size, or for an unrecognized type.
    """
    if isinstance(t, BaseType):
        # Tensor and Scalar are reference-like; callers translate them.
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            return None
        if str(t) == "SymInt":
            return NamedCType(binds, BaseCType(SymIntT if symint else longT))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError("string ref->value conversion: not implemented yet")
        # Every remaining BaseType has a direct entry in the C++ mapping table.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        inner = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
        return None if inner is None else NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # Only bool lists count as value types; they must carry a fixed size,
        # which becomes the ArrayCType length.
        if str(t.elem) != "bool":
            return None
        assert t.size is not None
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")


# Translation of types occurring in JIT arguments to a C++ argument type.
# If remove_non_owning_ref_types is set, we'll guarantee that the output CType is not a non-owning reference type.
# For example, we'll return std::vector<int> instead of IntArrayRef.
# NOTE(review): argumenttype_type only honors this for int[]/SymInt[] lists
# (and for str, via valuetype_type); optional and other list types still
# produce reference types.
# See Note [translation from C++ reference to value types]
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate a JIT argument type into its C++ API argument type.

    Value types are delegated to ``valuetype_type``.  Tensor, Scalar, their
    optionals and lists are handled here.

    Args:
        t: the JIT schema type of the argument.
        mutable: whether the argument is written to.  Tensor and ``Tensor?``
            arguments become mutable references when this is set and
            ``local.use_const_ref_for_mutable_tensors()`` is false.
        binds: the name the resulting ``NamedCType`` binds.
        remove_non_owning_ref_types: produce owning ``VectorCType`` for
            ``int[]`` / ``SymInt[]`` instead of array refs; also forwarded to
            ``valuetype_type``.  NOTE(review): optional and other list element
            types still yield reference types when this is set.
        symint: use SymInt C++ types for ``SymInt`` (otherwise ``longT``).

    Raises:
        AssertionError: if ``t`` is not a recognized type.
    """
    # If it's a value type, do the value type translation
    r = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if r is not None:
        return r

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable and not local.use_const_ref_for_mutable_tensors():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        elif t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        else:
            raise AssertionError(f"base type should have been value type {t}")
    elif isinstance(t, OptionalType):
        elem_str = str(t.elem)
        if elem_str == "Tensor":
            if mutable and not local.use_const_ref_for_mutable_tensors():
                # TODO: fix this discrepancy -- a mutable Tensor? is passed as
                # a plain mutable Tensor reference, not as an optional.
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            else:
                return NamedCType(
                    binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
                )
        elif elem_str == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        elif isinstance(t.elem, ListType):
            list_elem_str = str(t.elem.elem)
            if list_elem_str == "int":
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
            elif list_elem_str == "SymInt":
                if symint:
                    return NamedCType(binds, BaseCType(optionalSymIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Generic case: wrap the element's argument type in an optional.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(elem.type))
    elif isinstance(t, ListType):
        elem_str = str(t.elem)
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if elem_str == "int":
            if remove_non_owning_ref_types:
                return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                return NamedCType(binds, BaseCType(intArrayRefT))
        if elem_str == "SymInt":
            if remove_non_owning_ref_types:
                if symint:
                    return NamedCType(binds, VectorCType(BaseCType(SymIntT)))
                else:
                    return NamedCType(binds, VectorCType(BaseCType(longT)))
            else:
                if symint:
                    return NamedCType(binds, BaseCType(symIntArrayRefT))
                else:
                    return NamedCType(binds, BaseCType(intArrayRefT))
        if elem_str == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            else:
                return NamedCType(binds, BaseCType(tensorListT))
        elif elem_str == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        elif elem_str == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        elif elem_str == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        # Generic case: an ArrayRef over the element's argument type.
        elem = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(elem.type))
    else:
        raise AssertionError(f"unrecognized type {repr(t)}")


# Translate a JIT argument into its C++ type
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    """Return the C++ argument type of ``a``; ``a.is_write`` selects mutability."""
    return argumenttype_type(a.type, mutable=a.is_write, symint=symint, binds=binds)


# Translation of a (non-multi) return type from JIT to C++
# N.B: returntype_type returns a CType, not a NamedCType.
# This is mostly because of the mismatch between return types and return names.
# e.g. a function with a return type of 'void' has 0 return names,
# and a function with a return type of 'std::tuple' has >1 return name.
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    """Translate a single JIT return type into a C++ ``CType``.

    Args:
        t: the JIT schema return type.
        mutable: whether the return aliases a written argument; Tensor returns
            then become references instead of copies.
        symint: ignored -- SymInt is always emitted for return types.

    Raises:
        AssertionError: for mutable list returns, fixed-size list returns, or
            unsupported types (including optionals of anything but Tensor
            that are not value types).
    """
    # placeholder is ignored
    # NB: symint is ALWAYS respected for return types. So symint argument
    # here is IGNORED
    r = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
    if r is not None:
        return r.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if mutable:
                if local.use_const_ref_for_mutable_tensors():
                    return ConstRefCType(BaseCType(tensorT))
                else:
                    return MutRefCType(BaseCType(tensorT))
            else:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
        elif t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        elem = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(elem)
    elif isinstance(t, OptionalType):
        elem = returntype_type(t.elem, mutable=mutable)
        if str(t.elem) == "Tensor":
            return OptionalCType(elem)

    raise AssertionError(f"unrecognized return type {t}")


# Translation of a single return to its C++ type
def return_type(r: Return, *, symint: bool = False) -> CType:
    """Return the C++ type of a single return; ``r.is_write`` selects mutability."""
    return returntype_type(r.type, mutable=r.is_write, symint=symint)


# Translation of a full (possibly multi) return from JIT to its C++ type
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    """Return ``void`` for no returns, the single type for one, else a tuple."""
    if len(rs) == 0:
        return BaseCType(voidT)
    elif len(rs) == 1:
        return return_type(rs[0], symint=symint)
    else:
        return TupleCType([return_type(r, symint=symint) for r in rs])


def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Return the C++ variable name used for each of ``f``'s returns.

    Inplace functions return ``self``.  Out functions reuse the name of the
    corresponding out argument.  Explicitly named returns keep their name, or
    get a ``_return`` suffix if it collides with an argument name.  Unnamed
    returns use ``fallback_name`` (single return) or ``fallback_name`` plus a
    zero-based index (multiple returns).
    """
    returns: list[str] = []
    for i, r in enumerate(f.func.returns):
        # If we have an inplace function, the return argument is
        # implicitly named self.
        # TODO: Consider incorporating this into the data model
        if f.func.name.name.inplace:
            assert i == 0, "illegal inplace function with multiple returns"
            name = "self"
        # If we are out function, the name is the name of the
        # corresponding output function (r.name will get recorded
        # in field_name later.)
        elif f.func.is_out_fn():
            name = f.func.arguments.out[i].name
        # If the return argument is explicitly named...
        elif r.name:
            # Out functions were handled above, so a collision with an
            # argument name always needs disambiguation here.
            name_conflict = any(
                r.name == a.name for a in f.func.schema_order_arguments()
            )
            name = f"{r.name}_return" if name_conflict else r.name
        # If there is no explicit name, we use fallback_name (default "result")
        # for a single return, or fallback_name0, fallback_name1, ... (zero-indexed)
        # for a multi-return.
        else:
            name = fallback_name if len(f.func.returns) == 1 else f"{fallback_name}{i}"
        returns.append(name)
    return returns


# Direct spellings of JIT schema default values in C++.
JIT_TO_CPP_DEFAULT = {
    "False": "false",
    "True": "true",
    "None": "::std::nullopt",  # UGH this one is type directed
    "Mean": "at::Reduction::Mean",
    "[]": "{}",
    "contiguous_format": "c10::MemoryFormat::Contiguous",
    "long": "at::kLong",
}


# Convert a JIT default into C++ expression representing the default
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Return the C++ expression for the JIT schema default ``d`` of type ``t``.

    Single-quoted string defaults are rewritten as double-quoted C++ literals.
    ``None`` becomes ``{}`` for ``Tensor?`` and ``::std::nullopt`` for other
    optionals, and ``[...]`` list literals become ``{...}``.  Anything else is
    looked up in ``JIT_TO_CPP_DEFAULT``, falling back to ``d`` unchanged.

    Raises:
        ValueError: if an unsized list type has a default that is neither a
            ``[...]`` literal nor (with ``symint``) a digit string for ``SymInt[]``.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"

    if isinstance(t, BaseType) and t.name is BaseTy.str:
        # Schema allows single quotes but C++ needs double
        if len(d) >= 2 and d[0] == "'" and d[-1] == "'":
            pieces: list[str] = []
            pos, stop = 1, len(d) - 1
            while pos < stop:
                ch = d[pos]
                if ch == "\\":
                    # Escape sequence: \' turns into a bare ' (legal inside a
                    # C++ "..." literal); any other escape is copied verbatim.
                    follower = d[pos + 1]
                    pieces.append("'" if follower == "'" else ch + follower)
                    pos += 2
                else:
                    # A bare " must be escaped once the literal is double-quoted.
                    pieces.append('\\"' if ch == '"' else ch)
                    pos += 1
            return '"' + "".join(pieces) + '"'

    if isinstance(t, OptionalType):
        return "::std::nullopt" if d == "None" else default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        if symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)


# Convert an argument into its C++ API form


def argument(
    a: Argument | TensorOptionsArguments | SelfArgument,
    *,
    cpp_no_default_args: set[str],
    method: bool,
    faithful: bool,
    symint: bool = False,
    has_tensor_options: bool,
) -> list[Binding]:
    """Translate one schema argument into zero or more C++ API bindings.

    A plain ``Argument`` yields a single binding, with its C++ default unless
    its name is in ``cpp_no_default_args``.  ``TensorOptionsArguments`` yield
    dtype/layout/device/pin_memory bindings when ``faithful``, otherwise one
    ``options`` binding.  ``SelfArgument`` yields nothing for methods, where
    ``this`` is implicit, and the underlying argument otherwise.
    """

    def recurse(inner: Argument | TensorOptionsArguments | SelfArgument) -> list[Binding]:
        # Translate a nested argument with the same settings as this call.
        return argument(
            inner,
            cpp_no_default_args=cpp_no_default_args,
            method=method,
            faithful=faithful,
            symint=symint,
            has_tensor_options=has_tensor_options,
        )

    if isinstance(a, Argument):
        # With TensorOptions present, memory_format gets a special binding name.
        binds: ArgName = (
            SpecialArgName.possibly_redundant_memory_format
            if a.name == "memory_format" and has_tensor_options
            else a.name
        )
        arg_default: str | None = (
            default_expr(a.default, a.type, symint=symint)
            if a.name not in cpp_no_default_args and a.default is not None
            else None
        )
        return [
            Binding(
                nctype=argument_type(a, binds=binds, symint=symint),
                name=a.name,
                default=arg_default,
                argument=a,
            )
        ]

    if isinstance(a, TensorOptionsArguments):
        if faithful:
            # Faithful signatures spell out each tensor-option field separately.
            expanded: list[Binding] = []
            for opt in (a.dtype, a.layout, a.device, a.pin_memory):
                expanded += recurse(opt)
            return expanded
        # Enforced by NativeFunction.__post_init__
        assert "options" not in cpp_no_default_args
        if all(x.default == "None" for x in a.all()):
            options_default: str | None = "{}"
        elif a.dtype.default == "long":
            options_default = "at::kLong"  # TODO: this is wrong
        else:
            options_default = None
        return [
            Binding(
                nctype=NamedCType("options", BaseCType(tensorOptionsT)),
                name="options",
                default=options_default,
                argument=a,
            )
        ]

    if isinstance(a, SelfArgument):
        # Caller is responsible for installing implicit this in context!
        return [] if method else recurse(a.argument)

    assert_never(a)


def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    """Return the C++ API bindings for all of ``arguments``.

    Faithful signatures list non-out arguments before out arguments and strip
    every default; non-faithful signatures put out arguments first.
    """
    ordered: list[Argument | TensorOptionsArguments | SelfArgument] = (
        [*arguments.non_out, *arguments.out]
        if faithful
        else [*arguments.out, *arguments.non_out]
    )
    has_options = arguments.tensor_options is not None
    bindings: list[Binding] = []
    for arg in ordered:
        for binding in argument(
            arg,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=has_options,
            cpp_no_default_args=cpp_no_default_args,
        ):
            bindings.append(binding.no_default() if faithful else binding)
    return bindings