xref: /aosp_15_r20/external/pytorch/torchgen/api/dispatcher.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Workerfrom __future__ import annotations
2*da0073e9SAndroid Build Coastguard Worker
3*da0073e9SAndroid Build Coastguard Workerimport itertools
4*da0073e9SAndroid Build Coastguard Workerfrom typing import Sequence
5*da0073e9SAndroid Build Coastguard Worker
6*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api import cpp
7*da0073e9SAndroid Build Coastguard Workerfrom torchgen.api.types import ArgName, Binding, CType, NamedCType
8*da0073e9SAndroid Build Coastguard Workerfrom torchgen.model import (
9*da0073e9SAndroid Build Coastguard Worker    Argument,
10*da0073e9SAndroid Build Coastguard Worker    FunctionSchema,
11*da0073e9SAndroid Build Coastguard Worker    Return,
12*da0073e9SAndroid Build Coastguard Worker    SelfArgument,
13*da0073e9SAndroid Build Coastguard Worker    TensorOptionsArguments,
14*da0073e9SAndroid Build Coastguard Worker    Type,
15*da0073e9SAndroid Build Coastguard Worker)
16*da0073e9SAndroid Build Coastguard Workerfrom torchgen.utils import assert_never, concatMap
17*da0073e9SAndroid Build Coastguard Worker
18*da0073e9SAndroid Build Coastguard Worker
# This file describes the translation of JIT schema to the dispatcher
# API, the *unboxed* calling convention by which invocations through
# the dispatcher are made.  Historically, the dispatcher API matched
# the C++ API, but with the establishment of the boxed API, we've
# made changes to the dispatcher API so that the unboxed API
# better aligns with the boxed API.  The dispatcher API hooks heavily
# into our template-based boxing/unboxing machinery, so changes
# to this convention will usually need template updates too.
27*da0073e9SAndroid Build Coastguard Worker#
28*da0073e9SAndroid Build Coastguard Worker# Prominent characteristics of the dispatcher API:
29*da0073e9SAndroid Build Coastguard Worker#
30*da0073e9SAndroid Build Coastguard Worker#   - dtype, layout, device and pin_memory are represented as separate
31*da0073e9SAndroid Build Coastguard Worker#     arguments.
32*da0073e9SAndroid Build Coastguard Worker#
33*da0073e9SAndroid Build Coastguard Worker
34*da0073e9SAndroid Build Coastguard Worker
def name(func: FunctionSchema) -> str:
    """Return the dispatcher-API name of ``func``.

    The dispatcher uses exactly the name computed by the C++ API, so this
    forwards to ``cpp.name``.
    """
    dispatcher_name = cpp.name(func)
    return dispatcher_name
37*da0073e9SAndroid Build Coastguard Worker
38*da0073e9SAndroid Build Coastguard Worker
def argumenttype_type(
    t: Type,
    *,
    mutable: bool,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Translate a schema ``Type`` into its dispatcher-API C++ type.

    Args:
        t: the JIT schema type to translate.
        mutable: whether the argument is written to; forwarded unchanged.
        binds: the name the resulting ``NamedCType`` is bound to.
        remove_non_owning_ref_types: forwarded unchanged to the C++ API.
        symint: forwarded unchanged to the C++ API.

    Returns:
        Whatever ``cpp.argumenttype_type`` produces for the same inputs.
    """
    # The dispatcher and C++ APIs currently agree on argument types, so this
    # is a pure pass-through.  If dispatcher-specific special cases are ever
    # needed, add them here -- or invert the dependency so that
    # cpp.argumenttype_type calls this function, or inline it entirely.
    return cpp.argumenttype_type(
        t,
        binds=binds,
        mutable=mutable,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
58*da0073e9SAndroid Build Coastguard Worker
59*da0073e9SAndroid Build Coastguard Worker
def argument_type(
    a: Argument,
    *,
    binds: ArgName,
    remove_non_owning_ref_types: bool = False,
    symint: bool = True,
) -> NamedCType:
    """Compute the dispatcher-API C++ type of a schema ``Argument``.

    Mutability is derived from ``a.is_write``.  All other options are
    forwarded to ``argumenttype_type`` together with ``a.type``.
    """
    schema_type = a.type
    is_mutable = a.is_write
    return argumenttype_type(
        schema_type,
        binds=binds,
        mutable=is_mutable,
        symint=symint,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
    )
74*da0073e9SAndroid Build Coastguard Worker
75*da0073e9SAndroid Build Coastguard Worker
def returns_type(rs: Sequence[Return], *, symint: bool = True) -> CType:
    """Return the dispatcher-API C++ return type for the returns ``rs``.

    This is currently identical to the C++ API's return type.  It is kept as
    a separate entry point so the two conventions can diverge later.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
79*da0073e9SAndroid Build Coastguard Worker
80*da0073e9SAndroid Build Coastguard Worker
def jit_arguments(func: FunctionSchema) -> list[Argument]:
    """Flatten ``func``'s arguments into the flat list used by the dispatcher.

    Positional, keyword-only and out arguments are visited in that order.
    Each ``SelfArgument`` is unwrapped to its underlying ``Argument``.  Each
    ``TensorOptionsArguments`` is expanded into its dtype, layout, device and
    pin_memory arguments, because the dispatcher API passes those separately.
    """

    def flatten(
        item: Argument | TensorOptionsArguments | SelfArgument,
    ) -> list[Argument]:
        # Plain arguments pass through untouched.
        if isinstance(item, Argument):
            return [item]
        # Unwrap the implicit `self` into the Argument it wraps.
        if isinstance(item, SelfArgument):
            return [item.argument]
        # Explode the TensorOptions bundle into its four component arguments.
        if isinstance(item, TensorOptionsArguments):
            return [item.dtype, item.layout, item.device, item.pin_memory]
        assert_never(item)

    all_args = func.arguments
    groups = (all_args.positional, all_args.kwarg_only, all_args.out)
    return list(concatMap(flatten, itertools.chain.from_iterable(groups)))
102*da0073e9SAndroid Build Coastguard Worker
103*da0073e9SAndroid Build Coastguard Worker
def argument(
    a: Argument, *, remove_non_owning_ref_types: bool = False, symint: bool = True
) -> Binding:
    """Build the dispatcher ``Binding`` for a single schema argument.

    The binding takes its name from ``a.name``.  Its C++ type comes from
    ``argument_type``, bound to that same name.
    """
    arg_name = a.name
    nctype = argument_type(
        a,
        binds=arg_name,
        remove_non_owning_ref_types=remove_non_owning_ref_types,
        symint=symint,
    )
    return Binding(argument=a, name=arg_name, nctype=nctype)
117*da0073e9SAndroid Build Coastguard Worker
118*da0073e9SAndroid Build Coastguard Worker
def arguments(func: FunctionSchema, *, symint: bool = True) -> list[Binding]:
    """Return one dispatcher ``Binding`` per flattened argument of ``func``.

    Bindings are produced in the order given by ``jit_arguments``.
    """
    flat_args = jit_arguments(func)
    return [argument(arg, symint=symint) for arg in flat_args]
121