xref: /aosp_15_r20/external/pytorch/torchgen/api/unboxing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
4*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import Binding, CppSignatureGroup, CType
5*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
6*da0073e9SAndroid Build Coastguard Worker    Argument,
7*da0073e9SAndroid Build Coastguard Worker    BaseTy,
8*da0073e9SAndroid Build Coastguard Worker    BaseType,
9*da0073e9SAndroid Build Coastguard Worker    ListType,
10*da0073e9SAndroid Build Coastguard Worker    NativeFunction,
11*da0073e9SAndroid Build Coastguard Worker    OptionalType,
12*da0073e9SAndroid Build Coastguard Worker    Type,
13*da0073e9SAndroid Build Coastguard Worker)
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Worker
16*da0073e9SAndroid Build Coastguard Worker# This file generates the code for unboxing wrappers, i.e., the glue logic to unbox a boxed operator and convert the
17*da0073e9SAndroid Build Coastguard Worker# ivalues from stack to correct arguments to the unboxed kernel, based on corresponding JIT schema. This codegen is
18*da0073e9SAndroid Build Coastguard Worker# an alternative way to generate unboxing wrappers similar to the existing C++ metaprogramming approach but gets the
19*da0073e9SAndroid Build Coastguard Worker# job done statically. These generated unboxing wrappers will be useful under the scenario where we need to register
20*da0073e9SAndroid Build Coastguard Worker# a fixed set of operators known at compile time and thus can save some time in runtime initialization phase.
21*da0073e9SAndroid Build Coastguard Worker#
22*da0073e9SAndroid Build Coastguard Worker# Here's an example on how the codegen works:
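# Here's an example of how the codegen works. The snippets below are simplified: the real output first moves each
# argument off the stack into a local `c10::IValue`, and hoists `std::vector` buffers above the conversion code.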
24*da0073e9SAndroid Build Coastguard Worker# - Function Schema (source of truth)
25*da0073e9SAndroid Build Coastguard Worker#
26*da0073e9SAndroid Build Coastguard Worker#      aten::empty.names(int[] size, *, Dimname[]? names,
27*da0073e9SAndroid Build Coastguard Worker#                        ScalarType? dtype=None, Layout? layout=None,
28*da0073e9SAndroid Build Coastguard Worker#                        Device? device=None, bool? pin_memory=None,
29*da0073e9SAndroid Build Coastguard Worker#                        MemoryFormat? memory_format=None) -> Tensor
30*da0073e9SAndroid Build Coastguard Worker# - Argument Conversion
31*da0073e9SAndroid Build Coastguard Worker#       Generates C++ code to convert an ivalue (from stack) to its underlying C++ type.
32*da0073e9SAndroid Build Coastguard Worker#    - int[] size
33*da0073e9SAndroid Build Coastguard Worker#        ```cpp
34*da0073e9SAndroid Build Coastguard Worker#           const c10::List<c10::IValue> size_list_in = (std::move(peek(stack, 0, 7))).toList();
35*da0073e9SAndroid Build Coastguard Worker#
36*da0073e9SAndroid Build Coastguard Worker#           std::vector<int64_t> size_vec;
37*da0073e9SAndroid Build Coastguard Worker#           for (c10::IValue size_elem: size_list_in) {
38*da0073e9SAndroid Build Coastguard Worker#               int64_t size_base = size_elem.to<int64_t>();
39*da0073e9SAndroid Build Coastguard Worker#               size_vec.push_back(size_base);
40*da0073e9SAndroid Build Coastguard Worker#           }
41*da0073e9SAndroid Build Coastguard Worker#           at::ArrayRef<int64_t> size_list_out(size_vec);
42*da0073e9SAndroid Build Coastguard Worker#                                 ~~~~~~~~~~~~~ <-- The converted argument from ivalues in the stack.
43*da0073e9SAndroid Build Coastguard Worker#                                                   Will be passed to unboxed kernel.
44*da0073e9SAndroid Build Coastguard Worker#       ```
45*da0073e9SAndroid Build Coastguard Worker#    - Dimname[]? names
46*da0073e9SAndroid Build Coastguard Worker#       ```cpp
47*da0073e9SAndroid Build Coastguard Worker#           ::std::optional<c10::IValue> names_opt = (std::move(peek(stack, 1, 7))).toOptional<c10::IValue>();
48*da0073e9SAndroid Build Coastguard Worker#           ::std::optional<at::ArrayRef<at::Dimname>> names_opt_out;
49*da0073e9SAndroid Build Coastguard Worker#           if (names_opt.has_value()) {
50*da0073e9SAndroid Build Coastguard Worker#                         ~~~~~~~~~~~ <-- Unwrapping optional shell
51*da0073e9SAndroid Build Coastguard Worker#               const c10::IValue names_opt_in = names_opt.value();
52*da0073e9SAndroid Build Coastguard Worker#               const c10::List<c10::IValue> names_list_in = names_opt_in.toList();
53*da0073e9SAndroid Build Coastguard Worker#
54*da0073e9SAndroid Build Coastguard Worker#               std::vector<at::Dimname> names_vec;
55*da0073e9SAndroid Build Coastguard Worker#               for (c10::IValue names_elem: names_list_in) {
56*da0073e9SAndroid Build Coastguard Worker#                                ~~~~~~~~~~~~~~~~~~~~~~~~~ <-- Unrolling list, then convert elements one by one.
57*da0073e9SAndroid Build Coastguard Worker#                   at::Dimname names_base = names_elem.to<at::Dimname>();
58*da0073e9SAndroid Build Coastguard Worker#                   names_vec.push_back(names_base);
59*da0073e9SAndroid Build Coastguard Worker#               }
60*da0073e9SAndroid Build Coastguard Worker#               at::ArrayRef<at::Dimname> names_list_out(names_vec);
61*da0073e9SAndroid Build Coastguard Worker#
62*da0073e9SAndroid Build Coastguard Worker#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>(names_list_out);
63*da0073e9SAndroid Build Coastguard Worker#           } else {
64*da0073e9SAndroid Build Coastguard Worker#               names_opt_out = ::std::optional<at::ArrayRef<at::Dimname>>();
65*da0073e9SAndroid Build Coastguard Worker#           }
66*da0073e9SAndroid Build Coastguard Worker#       ```
67*da0073e9SAndroid Build Coastguard Worker#    - ScalarType? dtype (similarly for the rest of the arguments)
68*da0073e9SAndroid Build Coastguard Worker#       ```cpp
69*da0073e9SAndroid Build Coastguard Worker#           ::std::optional<c10::IValue> dtype_opt = (std::move(peek(stack, 2, 7))).toOptional<c10::IValue>();
70*da0073e9SAndroid Build Coastguard Worker#           ::std::optional<at::ScalarType> dtype_opt_out;
71*da0073e9SAndroid Build Coastguard Worker#           if (dtype_opt.has_value()) {
72*da0073e9SAndroid Build Coastguard Worker#               const c10::IValue dtype_opt_in = dtype_opt.value();
73*da0073e9SAndroid Build Coastguard Worker#               at::ScalarType dtype_base = dtype_opt_in.to<at::ScalarType>();
74*da0073e9SAndroid Build Coastguard Worker#                                                        ~~~~~~~~~~~~~~~~~~~~ <-- For base types, convert ivalue to it
75*da0073e9SAndroid Build Coastguard Worker#                                                                                 directly using ".to<T>()" API.
76*da0073e9SAndroid Build Coastguard Worker#               dtype_opt_out = ::std::optional<at::ScalarType>(dtype_base);
77*da0073e9SAndroid Build Coastguard Worker#           } else {
78*da0073e9SAndroid Build Coastguard Worker#               dtype_opt_out = ::std::optional<at::ScalarType>();
79*da0073e9SAndroid Build Coastguard Worker#           }
80*da0073e9SAndroid Build Coastguard Worker#       ```
81*da0073e9SAndroid Build Coastguard Worker#
82*da0073e9SAndroid Build Coastguard Worker# - Unboxed Kernel Call
83*da0073e9SAndroid Build Coastguard Worker#   ```cpp
84*da0073e9SAndroid Build Coastguard Worker#       auto result_ = torch::empty(
85*da0073e9SAndroid Build Coastguard Worker#           size_list_out,
86*da0073e9SAndroid Build Coastguard Worker#           names_opt_out,
87*da0073e9SAndroid Build Coastguard Worker#           options,
88*da0073e9SAndroid Build Coastguard Worker#           memory_format_opt_out
89*da0073e9SAndroid Build Coastguard Worker#       );
90*da0073e9SAndroid Build Coastguard Worker#   ```
91*da0073e9SAndroid Build Coastguard Worker#
92*da0073e9SAndroid Build Coastguard Worker# - Push Result Back to Stack
93*da0073e9SAndroid Build Coastguard Worker#   ```cpp
94*da0073e9SAndroid Build Coastguard Worker#       drop(stack, 7);
95*da0073e9SAndroid Build Coastguard Worker#       pack(stack, std::move(result_));
96*da0073e9SAndroid Build Coastguard Worker#   ```
# Separator used to splice a list of generated C++ lines into an f-string
# template via `connector.join(...)`: every line after the first starts on a
# new line prefixed with a tab.
connector: str = "\n\t"
98*da0073e9SAndroid Build Coastguard Worker
99*da0073e9SAndroid Build Coastguard Worker
def name(f: NativeFunction) -> str:
    """Return the name used for the unboxing wrapper of ``f``.

    This is simply the operator's unambiguous name, as produced by
    ``f.func.name.unambiguous_name()``.
    """
    operator_name = f.func.name
    return operator_name.unambiguous_name()
103*da0073e9SAndroid Build Coastguard Worker
104*da0073e9SAndroid Build Coastguard Worker
def convert_arguments(f: NativeFunction) -> tuple[list[Binding], list[str]]:
    """Generate C++ that pops every argument of ``f`` off the stack and unboxes it.

    Returns:
        A pair ``(bindings, code)``:

        - ``bindings``: one ``Binding`` per C++ argument, renamed to the local
          variable that holds the unboxed value.
        - ``code``: C++ source lines. It starts with one
          ``c10::IValue <name> = std::move(peek(stack, i, N));`` line per
          argument, then an empty line, then, for each argument in order, its
          hoisted declarations followed by its conversion code.

    Raises:
        Exception: if a binding's ``argument`` is not an ``Argument``.
    """
    # `method=False` so that `self` is an explicit argument we unbox like the rest.
    signature = CppSignatureGroup.from_native_function(
        f, method=False
    ).most_faithful_signature()
    args = signature.arguments()
    num_args = len(args)

    code_list = [
        f"c10::IValue {binding.name} = std::move(peek(stack, {idx}, {num_args}));"
        for idx, binding in enumerate(args)
    ]
    code_list.append("")

    binding_list = []
    for binding in args:
        # expecting only Argument
        if not isinstance(binding.argument, Argument):
            raise Exception(  # noqa: TRY002
                f"Unexpected argument type, expecting `Argument` but got {binding}"
            )
        argument: Argument = binding.argument
        unboxed_name, _, conversion_code, declarations = argumenttype_ivalue_convert(
            argument.type,
            argument.name,
            mutable=argument.is_write,
        )
        # Declarations (e.g. std::vector buffers backing an ArrayRef) are emitted
        # before the conversion code so they stay alive outside its scopes.
        code_list.extend(declarations)
        code_list.extend(conversion_code)
        binding_list.append(binding.with_name(unboxed_name))
    return binding_list, code_list
134*da0073e9SAndroid Build Coastguard Worker
135*da0073e9SAndroid Build Coastguard Worker
def argumenttype_ivalue_convert(
    t: Type, arg_name: str, *, mutable: bool = False
) -> tuple[str, CType, list[str], list[str]]:
    """Generate C++ that converts the ``c10::IValue`` named ``arg_name`` to type ``t``.

    Base types are converted directly. Optional and list types are unwrapped,
    and their element type is converted recursively.

    Args:
        t: The schema type of the argument.
        arg_name: Name of the C++ variable holding the boxed ``c10::IValue``.
        mutable: Whether the argument is written to. It is forwarded to
            ``cpp.argumenttype_type`` for this level only; the recursive
            element conversions use the default.

    Returns:
        A 4-tuple ``(out_name, ctype, code, decl)``:

        - ``out_name``: name of the C++ variable holding the unboxed value
          (``<arg_name>_base``, ``<arg_name>_opt_out`` or ``<arg_name>_list_out``).
        - ``ctype``: the C++ argument type computed by ``cpp.argumenttype_type``.
          ``out_name`` is declared as this type with references stripped.
        - ``code``: C++ lines performing the conversion.
        - ``decl``: C++ declarations that must be emitted before ``code``
          (outside any scope it opens), e.g. ``std::vector`` buffers that back
          an ``ArrayRef``.

    Raises:
        Exception: if ``t`` is not a ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # Unboxing is for mobile, which doesn't care about SymInts
    ctype = cpp.argumenttype_type(
        t=t, mutable=mutable, binds=arg_name, symint=False
    ).type

    if isinstance(t, BaseType):
        out_name = f"{arg_name}_base"
        code, decl = _gen_code_base_type(
            arg_name=arg_name, out_name=out_name, ctype=ctype
        )
    elif isinstance(t, OptionalType):
        out_name = f"{arg_name}_opt_out"
        code, decl = _gen_code_optional_type(
            arg_name=arg_name,
            out_name=out_name,
            t=t,
            ctype=ctype,
        )
    elif isinstance(t, ListType):
        out_name = f"{arg_name}_list_out"
        code, decl = _gen_code_list_type(
            arg_name=arg_name,
            out_name=out_name,
            t=t,
            ctype=ctype,
        )
    else:
        raise Exception(f"Cannot handle type {t}. arg_name: {arg_name}")  # noqa: TRY002
    return out_name, ctype, code, decl
171*da0073e9SAndroid Build Coastguard Worker
172*da0073e9SAndroid Build Coastguard Worker
173*da0073e9SAndroid Build Coastguard Workerdef _gen_code_base_type(
174*da0073e9SAndroid Build Coastguard Worker    arg_name: str, out_name: str, ctype: CType
175*da0073e9SAndroid Build Coastguard Worker) -> tuple[list[str], list[str]]:
176*da0073e9SAndroid Build Coastguard Worker    return [
177*da0073e9SAndroid Build Coastguard Worker        f"{ctype.cpp_type(strip_ref=True)} {out_name} = {arg_name}.to<{ctype.cpp_type(strip_ref=True)}>();"
178*da0073e9SAndroid Build Coastguard Worker    ], []
179*da0073e9SAndroid Build Coastguard Worker
180*da0073e9SAndroid Build Coastguard Worker
def _gen_code_optional_type(
    arg_name: str, out_name: str, t: OptionalType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Unbox an optional IValue into ``out_name``.

    The generated C++ extracts ``<arg_name>_opt`` with
    ``toOptional<c10::IValue>()``. When it holds a value, the payload is copied
    into ``<arg_name>_opt_in``, converted to ``t.elem`` recursively, and wrapped
    back into the optional C++ type. Otherwise ``out_name`` is assigned a
    default-constructed value of that type.

    Returns ``(code, decl)``. ``decl`` carries any declarations the element
    conversion needs hoisted.
    """
    in_name = f"{arg_name}_opt_in"
    elem_out, _, elem_code, decl = argumenttype_ivalue_convert(t.elem, in_name)
    opt_type = ctype.cpp_type(strip_ref=True)
    nested = connector.join(elem_code)
    # Split on newlines; the leading "" and trailing whitespace-only entries
    # produced by the template's surrounding newlines are kept as-is.
    template = f"""
auto {arg_name}_opt = {arg_name}.toOptional<c10::IValue>();
{opt_type} {out_name};
if ({arg_name}_opt.has_value()) {{
    const c10::IValue {in_name} = {arg_name}_opt.value();
    {nested}
    {out_name} = {opt_type}({elem_out});
}} else {{
    {out_name} = {opt_type}();
}}
        """
    return template.split("\n"), decl
202*da0073e9SAndroid Build Coastguard Worker
203*da0073e9SAndroid Build Coastguard Worker
def _gen_code_list_type(
    arg_name: str, out_name: str, t: ListType, ctype: CType
) -> tuple[list[str], list[str]]:
    """Unbox a list IValue into ``out_name``.

    The generated C++ first materializes ``<arg_name>_list_in`` via
    ``toList()``. It then takes one of three shapes:

    - Fixed-size ``bool`` lists (e.g. ``bool[4]``) are converted in one call
      to ``as_array<T, N>``.
    - Lists of optionals (e.g. ``Tensor?[]``) need a ``c10::List``, so each
      element is converted and pushed straight into ``out_name``.
    - Everything else: elements are collected into ``<arg_name>_vec``, and
      ``out_name`` (an ``ArrayRef`` by default) is constructed from it.

    Returns ``(code, decl)``. In the last case the ``std::vector`` declaration
    is appended to ``decl`` so the vector outlives the conversion scope and the
    ``ArrayRef`` keeps pointing at valid data.
    """
    list_in = f"{arg_name}_list_in"
    elem_in = f"{arg_name}_elem"
    code = [f"const c10::List<c10::IValue> {list_in} = {arg_name}.toList();"]
    elem_out, elem_ctype, elem_code, decl = argumenttype_ivalue_convert(
        t.elem, elem_in
    )
    elem = t.elem
    list_type = ctype.cpp_type(strip_ref=True)

    is_fixed_bool_array = (
        isinstance(elem, BaseType) and elem.name == BaseTy.bool and t.size
    )
    if is_fixed_bool_array:
        elem_type = elem_ctype.cpp_type(strip_ref=True)
        body = f"""
{list_type} {out_name} = as_array<{elem_type}, {t.size}>({list_in});
            """
    elif isinstance(elem, OptionalType):
        body = f"""
{list_type} {out_name};
for (c10::IValue {elem_in}: {list_in}) {{
    {connector.join(elem_code)}
    {out_name}.push_back({elem_out});
}}
            """
    else:
        elem_type = elem_ctype.cpp_type(strip_ref=True)
        vec_name = arg_name + "_vec"
        # Declared via `decl` (outside this code's scope) so the ArrayRef built
        # from it below does not dangle.
        decl.append(f"std::vector<{elem_type}> {vec_name};")
        body = f"""
for (c10::IValue {elem_in}: {list_in}) {{
    {connector.join(elem_code)}
    {vec_name}.push_back({elem_out});
}}
{list_type} {out_name}({vec_name});
            """
    code.extend(body.split("\n"))
    return code, decl
249*da0073e9SAndroid Build Coastguard Worker    return code, decl
250