xref: /aosp_15_r20/external/pytorch/torchgen/api/cpp.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence
4*da0073e9SAndroid Build Coastguard Worker
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen import local
6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import (
7*da0073e9SAndroid Build Coastguard Worker    ArgName,
8*da0073e9SAndroid Build Coastguard Worker    ArrayCType,
9*da0073e9SAndroid Build Coastguard Worker    ArrayRefCType,
10*da0073e9SAndroid Build Coastguard Worker    BaseCType,
11*da0073e9SAndroid Build Coastguard Worker    BaseTypeToCppMapping,
12*da0073e9SAndroid Build Coastguard Worker    Binding,
13*da0073e9SAndroid Build Coastguard Worker    boolT,
14*da0073e9SAndroid Build Coastguard Worker    ConstRefCType,
15*da0073e9SAndroid Build Coastguard Worker    CType,
16*da0073e9SAndroid Build Coastguard Worker    dimnameListT,
17*da0073e9SAndroid Build Coastguard Worker    intArrayRefT,
18*da0073e9SAndroid Build Coastguard Worker    iTensorListRefT,
19*da0073e9SAndroid Build Coastguard Worker    ListCType,
20*da0073e9SAndroid Build Coastguard Worker    longT,
21*da0073e9SAndroid Build Coastguard Worker    MutRefCType,
22*da0073e9SAndroid Build Coastguard Worker    NamedCType,
23*da0073e9SAndroid Build Coastguard Worker    OptionalCType,
24*da0073e9SAndroid Build Coastguard Worker    optionalIntArrayRefT,
25*da0073e9SAndroid Build Coastguard Worker    optionalSymIntArrayRefT,
26*da0073e9SAndroid Build Coastguard Worker    scalarT,
27*da0073e9SAndroid Build Coastguard Worker    SpecialArgName,
28*da0073e9SAndroid Build Coastguard Worker    symIntArrayRefT,
29*da0073e9SAndroid Build Coastguard Worker    SymIntT,
30*da0073e9SAndroid Build Coastguard Worker    tensorListT,
31*da0073e9SAndroid Build Coastguard Worker    tensorOptionsT,
32*da0073e9SAndroid Build Coastguard Worker    tensorT,
33*da0073e9SAndroid Build Coastguard Worker    TupleCType,
34*da0073e9SAndroid Build Coastguard Worker    VectorCType,
35*da0073e9SAndroid Build Coastguard Worker    voidT,
36*da0073e9SAndroid Build Coastguard Worker)
37*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
38*da0073e9SAndroid Build Coastguard Worker    Argument,
39*da0073e9SAndroid Build Coastguard Worker    Arguments,
40*da0073e9SAndroid Build Coastguard Worker    BaseTy,
41*da0073e9SAndroid Build Coastguard Worker    BaseType,
42*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
43*da0073e9SAndroid Build Coastguard Worker    ListType,
44*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
45*da0073e9SAndroid Build Coastguard Worker    OptionalType,
46*da0073e9SAndroid Build Coastguard Worker    Return,
47*da0073e9SAndroid Build Coastguard Worker    SelfArgument,
48*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
49*da0073e9SAndroid Build Coastguard Worker    Type,
50*da0073e9SAndroid Build Coastguard Worker)
51*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never
52*da0073e9SAndroid Build Coastguard Worker
53*da0073e9SAndroid Build Coastguard Worker
54*da0073e9SAndroid Build Coastguard Worker# This file describes the translation of JIT schema to the public C++
55*da0073e9SAndroid Build Coastguard Worker# API, which is what people use when they call functions like at::add.
56*da0073e9SAndroid Build Coastguard Worker#
57*da0073e9SAndroid Build Coastguard Worker# Prominent characteristics of the C++ API:
58*da0073e9SAndroid Build Coastguard Worker#
59*da0073e9SAndroid Build Coastguard Worker#   - dtype, layout, device and pin_memory are collected into
60*da0073e9SAndroid Build Coastguard Worker#     a single C++ type TensorOptions  (the native functions API
61*da0073e9SAndroid Build Coastguard Worker#     also has this, but tensor options is really most relevant
62*da0073e9SAndroid Build Coastguard Worker#     for the C++ API; it makes calling kwarg factory functions
63*da0073e9SAndroid Build Coastguard Worker#     pleasant)
64*da0073e9SAndroid Build Coastguard Worker#
65*da0073e9SAndroid Build Coastguard Worker#   - defaulting lives here (in fact, the dispatcher is completely
66*da0073e9SAndroid Build Coastguard Worker#     oblivious of defaults!)
67*da0073e9SAndroid Build Coastguard Worker#
68*da0073e9SAndroid Build Coastguard Worker# BTW: policy on name collisions: we try not to have types with
69*da0073e9SAndroid Build Coastguard Worker# collisions, but functions are fair game to collide
70*da0073e9SAndroid Build Coastguard Worker
71*da0073e9SAndroid Build Coastguard Worker
def name(
    func: FunctionSchema,
    *,
    faithful_name_for_out_overloads: bool = False,
    symint_overload: bool = False,
) -> str:
    """Return the public C++ API function name for ``func``.

    The base name is ``str(func.name.name)``.  ``_symint`` is appended when
    ``symint_overload`` is set.  For out functions, ``_outf`` is then appended
    when ``faithful_name_for_out_overloads`` is set, and ``_out`` otherwise.
    """
    base = str(func.name.name)
    suffixes: list[str] = []
    if symint_overload:
        suffixes.append("_symint")
    if func.is_out_fn():
        suffixes.append("_outf" if faithful_name_for_out_overloads else "_out")
    return base + "".join(suffixes)
88*da0073e9SAndroid Build Coastguard Worker
89*da0073e9SAndroid Build Coastguard Worker
def valuetype_type(
    t: Type,
    *,
    binds: ArgName,
    mutable: bool = True,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType | None:
    """Translate a JIT schema "value type" to its C++ API type.

    Value types look the same whether they are argument or return types.
    Returns ``None`` when ``t`` is not a value type (Tensor, Scalar, lists
    other than ``bool[N]``, or optionals of those).

    Args:
        t: the JIT schema type to translate.
        binds: the name the resulting ``NamedCType`` binds to.
        mutable: only forwarded to the Optional recursion; it does not affect
            the result of this function.
        remove_non_owning_ref_types: when set, a ``str`` base type raises,
            since ref->value conversion for strings is not implemented.
            NOTE(review): this flag is not forwarded into the Optional
            recursion below, so ``str?`` does not raise -- confirm intended.
        symint: map ``SymInt`` to ``SymIntT`` when set, otherwise to ``longT``.

    Raises:
        AssertionError: for unrecognized type kinds, or for ``str`` when
            ``remove_non_owning_ref_types`` is set.
    """
    if isinstance(t, BaseType):
        if t.name in (BaseTy.Tensor, BaseTy.Scalar):
            # Not value types; argumenttype_type/returntype_type handle them.
            return None
        if str(t) == "SymInt":
            return NamedCType(binds, BaseCType(SymIntT if symint else longT))
        if remove_non_owning_ref_types and t.name == BaseTy.str:
            raise AssertionError(
                "string ref->value conversion: not implemented yet"
            )
        # All other BaseType currently map directly to BaseCppTypes.
        return NamedCType(binds, BaseCType(BaseTypeToCppMapping[t.name]))

    if isinstance(t, OptionalType):
        inner = valuetype_type(t.elem, binds=binds, mutable=mutable, symint=symint)
        return None if inner is None else NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        # Only fixed-size bool lists (e.g. bool[3]) are value types.
        if str(t.elem) != "bool":
            return None
        assert t.size is not None
        return NamedCType(binds, ArrayCType(BaseCType(boolT), t.size))

    raise AssertionError(f"unrecognized type {repr(t)}")
129*da0073e9SAndroid Build Coastguard Worker
130*da0073e9SAndroid Build Coastguard Worker
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = False,
) -> NamedCType:
    """Translate a type occurring in a JIT argument to a C++ argument type.

    Value types are delegated to ``valuetype_type``.  Everything else
    (Tensor, Scalar, their optionals, and lists) is handled here.

    If ``remove_non_owning_ref_types`` is set, the int/SymInt list cases
    return owning ``std::vector`` types instead of ``IntArrayRef`` /
    ``SymIntArrayRef``.  See Note [translation from C++ reference to value
    types].  NOTE(review): the flag is not forwarded into the Optional/List
    recursion at the bottom of those branches -- confirm intended.

    Raises:
        AssertionError: for base types that should have been value types,
            or for unrecognized type kinds.
    """
    value = valuetype_type(
        t,
        binds=binds,
        mutable=mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
    if value is not None:
        return value

    def bind_tensor_by_mut_ref() -> bool:
        # Kept lazy: the local flag is only consulted when a mutable Tensor
        # (or Tensor?) is actually being bound.
        return mutable and not local.use_const_ref_for_mutable_tensors()

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if bind_tensor_by_mut_ref():
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        elem_str = str(t.elem)
        if elem_str == "Tensor":
            if bind_tensor_by_mut_ref():
                # TODO: fix this discrepancy (a mutable Tensor? binds as a
                # plain Tensor&, not an optional).
                return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
            return NamedCType(
                binds, ConstRefCType(OptionalCType(BaseCType(tensorT)))
            )
        if elem_str == "Scalar":
            return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))
        if isinstance(t.elem, ListType):
            list_elem = str(t.elem.elem)
            if list_elem == "int":
                return NamedCType(binds, BaseCType(optionalIntArrayRefT))
            if list_elem == "SymInt":
                chosen = optionalSymIntArrayRefT if symint else optionalIntArrayRefT
                return NamedCType(binds, BaseCType(chosen))
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, OptionalCType(inner.type))

    if isinstance(t, ListType):
        elem_str = str(t.elem)
        # TODO: remove these special cases, ArrayRef fallthrough works fine
        if elem_str in ("int", "SymInt"):
            # Plain int lists never use the SymInt types.
            use_symint = symint and elem_str == "SymInt"
            if remove_non_owning_ref_types:
                scalar_ct = SymIntT if use_symint else longT
                return NamedCType(binds, VectorCType(BaseCType(scalar_ct)))
            ref_ct = symIntArrayRefT if use_symint else intArrayRefT
            return NamedCType(binds, BaseCType(ref_ct))
        if elem_str == "Tensor":
            if local.use_ilistref_for_tensor_lists():
                return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
            return NamedCType(binds, BaseCType(tensorListT))
        if elem_str == "Scalar":
            return NamedCType(binds, ArrayRefCType(BaseCType(scalarT)))
        if elem_str == "Dimname":
            return NamedCType(binds, BaseCType(dimnameListT))
        if elem_str == "Tensor?":
            return NamedCType(
                binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
            )
        inner = argumenttype_type(t.elem, mutable=mutable, binds=binds, symint=symint)
        return NamedCType(binds, ArrayRefCType(inner.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
220*da0073e9SAndroid Build Coastguard Worker
221*da0073e9SAndroid Build Coastguard Worker
def argument_type(a: Argument, *, binds: ArgName, symint: bool = False) -> NamedCType:
    """Translate a JIT argument into its C++ type.

    The argument is treated as mutable exactly when ``a.is_write`` is set.
    """
    is_written = a.is_write
    return argumenttype_type(a.type, binds=binds, mutable=is_written, symint=symint)
225*da0073e9SAndroid Build Coastguard Worker
226*da0073e9SAndroid Build Coastguard Worker
def returntype_type(t: Type, *, mutable: bool, symint: bool = False) -> CType:
    """Translate a single (non-multi) JIT return type to a C++ type.

    This returns a ``CType``, not a ``NamedCType``, because return types and
    return names do not line up one-to-one: e.g. a function returning 'void'
    has 0 return names, and one returning 'std::tuple' has more than one.

    NB: symint is ALWAYS respected for return types, so the ``symint``
    argument here is IGNORED.

    Raises:
        AssertionError: for mutable list returns, fixed-size list returns,
            and any type not handled below.
    """
    # The bound name is a placeholder; only the type is used.
    value = valuetype_type(t, binds="__placeholder__", mutable=mutable, symint=True)
    if value is not None:
        return value.type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            if not mutable:
                # Note [Tensor Copy Returns]
                # Currently, we use "Argument.is_write" to determine
                # whether or not Tensor return types should be copies or references.
                # If that ever changes, take a look at other locations of this note!
                return BaseCType(tensorT)
            if local.use_const_ref_for_mutable_tensors():
                return ConstRefCType(BaseCType(tensorT))
            return MutRefCType(BaseCType(tensorT))
        if t.name == BaseTy.Scalar:
            return BaseCType(scalarT)
    elif isinstance(t, ListType):
        assert (
            not mutable
        ), "Native functions should never return a mutable tensor list. They should return void."
        element_ct = returntype_type(t.elem, mutable=False)
        assert t.size is None, f"fixed size list returns not supported: {t}"
        return VectorCType(element_ct)
    elif isinstance(t, OptionalType):
        # The element is translated first, so unsupported element types fail
        # inside that call; only Tensor? is accepted as an optional return.
        element_ct = returntype_type(t.elem, mutable=mutable)
        if str(t.elem) == "Tensor":
            return OptionalCType(element_ct)

    raise AssertionError(f"unrecognized return type {t}")
268*da0073e9SAndroid Build Coastguard Worker
269*da0073e9SAndroid Build Coastguard Worker
def return_type(r: Return, *, symint: bool = False) -> CType:
    """Translate a single JIT return to its C++ type.

    The return is treated as mutable exactly when ``r.is_write`` is set.
    """
    is_written = r.is_write
    return returntype_type(r.type, symint=symint, mutable=is_written)
273*da0073e9SAndroid Build Coastguard Worker
274*da0073e9SAndroid Build Coastguard Worker
def returns_type(rs: Sequence[Return], *, symint: bool = False) -> CType:
    """Translate a full (possibly multi) JIT return list to its C++ type.

    No returns map to ``void``, one return maps to its own type, and several
    returns map to a ``TupleCType`` of the individual types, in order.
    """
    if not rs:
        return BaseCType(voidT)
    translated = [return_type(r, symint=symint) for r in rs]
    if len(translated) == 1:
        return translated[0]
    return TupleCType(translated)
283*da0073e9SAndroid Build Coastguard Worker
284*da0073e9SAndroid Build Coastguard Worker
def return_names(f: NativeFunction, *, fallback_name: str = "result") -> Sequence[str]:
    """Choose the C++ variable name for each return of ``f``.

    For each return, the first matching rule wins:

    * in-place functions (``f.func.name.name.inplace``): the return is
      implicitly named ``self``;
    * out functions: the name of the corresponding ``out`` argument;
    * explicitly named returns: the return's own name, suffixed with
      ``_return`` if it collides with any argument name;
    * otherwise ``fallback_name``, with the zero-based return index appended
      when there is more than one return (``result0``, ``result1``, ...).

    Raises:
        AssertionError: if an in-place function has more than one return.
    """
    func = f.func
    multi_return = len(func.returns) > 1
    returns: list[str] = []
    for i, r in enumerate(func.returns):
        # TODO: Consider incorporating this into the data model
        if func.name.name.inplace:
            # Explicit raise (not ``assert``) so the check survives ``python -O``.
            if i != 0:
                raise AssertionError("illegal inplace function with multiple returns")
            ret_name = "self"
        elif func.is_out_fn():
            # r.name will get recorded in field_name later.
            ret_name = func.arguments.out[i].name
        elif r.name:
            # Only non-out functions reach this branch, so any clash with an
            # argument name is disambiguated.
            clashes = any(r.name == a.name for a in func.schema_order_arguments())
            ret_name = f"{r.name}_return" if clashes else r.name
        else:
            ret_name = f"{fallback_name}{i}" if multi_return else fallback_name
        returns.append(ret_name)
    return returns
315*da0073e9SAndroid Build Coastguard Worker
316*da0073e9SAndroid Build Coastguard Worker
# Literal JIT-schema default spellings and the C++ expressions they become.
# default_expr() returns the JIT text unchanged for anything not listed here.
JIT_TO_CPP_DEFAULT = {
    # Python-style booleans -> C++ booleans.
    "False": "false",
    "True": "true",
    # UGH this one is type directed: default_expr() special-cases
    # None for Tensor? and for optional types before reaching this table.
    "None": "::std::nullopt",
    "Mean": "at::Reduction::Mean",
    # An empty list literal becomes a brace-initializer.
    "[]": "{}",
    "contiguous_format": "c10::MemoryFormat::Contiguous",
    "long": "at::kLong",
}
326*da0073e9SAndroid Build Coastguard Worker
327*da0073e9SAndroid Build Coastguard Worker
def default_expr(d: str, t: Type, *, symint: bool) -> str:
    """Convert a JIT schema default ``d`` for type ``t`` into a C++ expression.

    Handled cases, in order:

    * ``None`` for ``Tensor?`` becomes ``{}``.
    * A single-quoted ``str`` default is re-quoted with double quotes for C++.
      Embedded ``"`` characters are escaped, ``\\'`` becomes ``'``, and any
      other backslash escape is copied through unchanged.
    * For optional types, ``None`` becomes ``::std::nullopt``; any other value
      is converted against the element type.
    * For list types, ``[a, b]`` becomes ``{a, b}``.  With ``symint``, a bare
      digit string for a ``SymInt`` list becomes ``c10::SymInt(<d>)``.  Any
      other non-bracketed default is a ``ValueError`` unless the list is sized,
      because sized lists may have scalar defaults.
    * Anything else goes through ``JIT_TO_CPP_DEFAULT``, falling back to ``d``.

    Raises:
        ValueError: for a non-bracketed default on an unsized list type.
    """
    if d == "None" and str(t) == "Tensor?":
        return "{}"

    # Schema allows single quotes but C++ needs double
    single_quoted_str = (
        isinstance(t, BaseType)
        and t.name is BaseTy.str
        and len(d) >= 2
        and d[0] == "'"
        and d[-1] == "'"
    )
    if single_quoted_str:
        pieces: list[str] = []
        pos = 1
        closing = len(d) - 1  # index of the closing quote
        while pos < closing:
            ch = d[pos]
            if ch == "\\":
                # Two-character escape; the second char may be the closing quote.
                follower = d[pos + 1]
                pieces.append("'" if follower == "'" else ch + follower)
                pos += 2
            else:
                pieces.append('\\"' if ch == '"' else ch)
                pos += 1
        return '"' + "".join(pieces) + '"'

    if isinstance(t, OptionalType):
        if d == "None":
            return "::std::nullopt"
        return default_expr(d, t.elem, symint=symint)

    if isinstance(t, ListType):
        if d.startswith("[") and d.endswith("]"):
            return "{" + d[1:-1] + "}"
        if symint and d.isdigit() and str(t.elem) == "SymInt":
            return f"c10::SymInt({d})"
        if t.size is None:
            # NOTE: Sized lists can have scalar defaults
            raise ValueError(f"Expected a list default '[...]' but found: '{d}'")

    return JIT_TO_CPP_DEFAULT.get(d, d)
369*da0073e9SAndroid Build Coastguard Worker
370*da0073e9SAndroid Build Coastguard Worker
371*da0073e9SAndroid Build Coastguard Worker# Convert an argument into its C++ API form
372*da0073e9SAndroid Build Coastguard Worker
373*da0073e9SAndroid Build Coastguard Worker
374*da0073e9SAndroid Build Coastguard Workerdef argument(
375*da0073e9SAndroid Build Coastguard Worker    a: Argument | TensorOptionsArguments | SelfArgument,
376*da0073e9SAndroid Build Coastguard Worker    *,
377*da0073e9SAndroid Build Coastguard Worker    cpp_no_default_args: set[str],
378*da0073e9SAndroid Build Coastguard Worker    method: bool,
379*da0073e9SAndroid Build Coastguard Worker    faithful: bool,
380*da0073e9SAndroid Build Coastguard Worker    symint: bool = False,
381*da0073e9SAndroid Build Coastguard Worker    has_tensor_options: bool,
382*da0073e9SAndroid Build Coastguard Worker) -> list[Binding]:
383*da0073e9SAndroid Build Coastguard Worker    def sub_argument(
384*da0073e9SAndroid Build Coastguard Worker        a: Argument | TensorOptionsArguments | SelfArgument,
385*da0073e9SAndroid Build Coastguard Worker    ) -> list[Binding]:
386*da0073e9SAndroid Build Coastguard Worker        return argument(
387*da0073e9SAndroid Build Coastguard Worker            a,
388*da0073e9SAndroid Build Coastguard Worker            cpp_no_default_args=cpp_no_default_args,
389*da0073e9SAndroid Build Coastguard Worker            method=method,
390*da0073e9SAndroid Build Coastguard Worker            faithful=faithful,
391*da0073e9SAndroid Build Coastguard Worker            symint=symint,
392*da0073e9SAndroid Build Coastguard Worker            has_tensor_options=has_tensor_options,
393*da0073e9SAndroid Build Coastguard Worker        )
394*da0073e9SAndroid Build Coastguard Worker
395*da0073e9SAndroid Build Coastguard Worker    if isinstance(a, Argument):
396*da0073e9SAndroid Build Coastguard Worker        binds: ArgName
397*da0073e9SAndroid Build Coastguard Worker        if a.name == "memory_format" and has_tensor_options:
398*da0073e9SAndroid Build Coastguard Worker            binds = SpecialArgName.possibly_redundant_memory_format
399*da0073e9SAndroid Build Coastguard Worker        else:
400*da0073e9SAndroid Build Coastguard Worker            binds = a.name
401*da0073e9SAndroid Build Coastguard Worker        default: str | None = None
402*da0073e9SAndroid Build Coastguard Worker        if a.name not in cpp_no_default_args and a.default is not None:
403*da0073e9SAndroid Build Coastguard Worker            default = default_expr(a.default, a.type, symint=symint)
404*da0073e9SAndroid Build Coastguard Worker        return [
405*da0073e9SAndroid Build Coastguard Worker            Binding(
406*da0073e9SAndroid Build Coastguard Worker                nctype=argument_type(a, binds=binds, symint=symint),
407*da0073e9SAndroid Build Coastguard Worker                name=a.name,
408*da0073e9SAndroid Build Coastguard Worker                default=default,
409*da0073e9SAndroid Build Coastguard Worker                argument=a,
410*da0073e9SAndroid Build Coastguard Worker            )
411*da0073e9SAndroid Build Coastguard Worker        ]
412*da0073e9SAndroid Build Coastguard Worker    elif isinstance(a, TensorOptionsArguments):
413*da0073e9SAndroid Build Coastguard Worker        if faithful:
414*da0073e9SAndroid Build Coastguard Worker            return (
415*da0073e9SAndroid Build Coastguard Worker                sub_argument(a.dtype)
416*da0073e9SAndroid Build Coastguard Worker                + sub_argument(a.layout)
417*da0073e9SAndroid Build Coastguard Worker                + sub_argument(a.device)
418*da0073e9SAndroid Build Coastguard Worker                + sub_argument(a.pin_memory)
419*da0073e9SAndroid Build Coastguard Worker            )
420*da0073e9SAndroid Build Coastguard Worker        else:
421*da0073e9SAndroid Build Coastguard Worker            default = None
422*da0073e9SAndroid Build Coastguard Worker            # Enforced by NativeFunction.__post_init__
423*da0073e9SAndroid Build Coastguard Worker            assert "options" not in cpp_no_default_args
424*da0073e9SAndroid Build Coastguard Worker            if all(x.default == "None" for x in a.all()):
425*da0073e9SAndroid Build Coastguard Worker                default = "{}"
426*da0073e9SAndroid Build Coastguard Worker            elif a.dtype.default == "long":
427*da0073e9SAndroid Build Coastguard Worker                default = "at::kLong"  # TODO: this is wrong
428*da0073e9SAndroid Build Coastguard Worker            return [
429*da0073e9SAndroid Build Coastguard Worker                Binding(
430*da0073e9SAndroid Build Coastguard Worker                    nctype=NamedCType("options", BaseCType(tensorOptionsT)),
431*da0073e9SAndroid Build Coastguard Worker                    name="options",
432*da0073e9SAndroid Build Coastguard Worker                    default=default,
433*da0073e9SAndroid Build Coastguard Worker                    argument=a,
434*da0073e9SAndroid Build Coastguard Worker                )
435*da0073e9SAndroid Build Coastguard Worker            ]
436*da0073e9SAndroid Build Coastguard Worker    elif isinstance(a, SelfArgument):
437*da0073e9SAndroid Build Coastguard Worker        if method:
438*da0073e9SAndroid Build Coastguard Worker            # Caller is responsible for installing implicit this in context!
439*da0073e9SAndroid Build Coastguard Worker            return []
440*da0073e9SAndroid Build Coastguard Worker        else:
441*da0073e9SAndroid Build Coastguard Worker            return sub_argument(a.argument)
442*da0073e9SAndroid Build Coastguard Worker    else:
443*da0073e9SAndroid Build Coastguard Worker        assert_never(a)
444*da0073e9SAndroid Build Coastguard Worker
445*da0073e9SAndroid Build Coastguard Worker
def arguments(
    arguments: Arguments,
    *,
    faithful: bool,
    symint: bool = False,
    method: bool,
    cpp_no_default_args: set[str],
) -> list[Binding]:
    """Flatten a schema's ``Arguments`` into the list of C++ API bindings.

    Each argument is expanded via ``argument``. Its output may be empty (the
    ``self`` argument of a method), a single binding, or, in faithful mode,
    one binding per TensorOptions component.

    Ordering:
      * faithful     -> non-out arguments first, then out arguments; every
                        resulting binding is passed through
                        ``Binding.no_default()``.
      * non-faithful -> out arguments first, then non-out arguments; bindings
                        are returned exactly as ``argument`` produced them.

    Args:
        arguments: the parsed schema arguments to translate.
        faithful: select the faithful (schema-order, no-default) signature.
        symint: forwarded to ``argument`` for SymInt-aware type selection.
        method: forwarded to ``argument``; when true the self argument
            produces no binding.
        cpp_no_default_args: names whose C++ default must be suppressed.

    Returns:
        The concatenated bindings, in the order described above.
    """
    if faithful:
        ordered: list[Argument | TensorOptionsArguments | SelfArgument] = [
            *arguments.non_out,
            *arguments.out,
        ]
    else:
        ordered = [*arguments.out, *arguments.non_out]

    # ``argument`` consults this flag when deciding how a ``memory_format``
    # argument is bound, so it only needs to be computed once.
    has_options = arguments.tensor_options is not None

    bindings: list[Binding] = []
    for arg in ordered:
        expanded = argument(
            arg,
            faithful=faithful,
            symint=symint,
            method=method,
            has_tensor_options=has_options,
            cpp_no_default_args=cpp_no_default_args,
        )
        for binding in expanded:
            bindings.append(binding.no_default() if faithful else binding)
    return bindings
473