xref: /aosp_15_r20/external/pytorch/torchgen/gen_vmap_plumbing.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1from __future__ import annotations
2
3import textwrap
4from dataclasses import dataclass
5from typing import Sequence
6
7from torchgen.api.translate import translate
8from torchgen.api.types import DispatcherSignature
9from torchgen.context import method_with_native_function
10from torchgen.model import (
11    Argument,
12    BaseTy,
13    BaseType,
14    FunctionSchema,
15    ListType,
16    NativeFunction,
17    OptionalType,
18    Return,
19    SchemaKind,
20    Type,
21)
22from torchgen.utils import mapMaybe
23
24
def is_tensor(typ: Type) -> bool:
    """Return True iff ``typ`` is the plain ``Tensor`` base type."""
    if not isinstance(typ, BaseType):
        return False
    return typ.name == BaseTy.Tensor
27
28
def is_optional_tensor(typ: Type) -> bool:
    """Return True iff ``typ`` is ``Tensor?`` (optional wrapping the Tensor base type)."""
    if not isinstance(typ, OptionalType):
        return False
    return is_tensor(typ.elem)
31
32
def is_tensor_list(typ: Type) -> bool:
    """Return True iff ``typ`` is ``Tensor[]`` (a list of the Tensor base type)."""
    if not isinstance(typ, ListType):
        return False
    return is_tensor(typ.elem)
35
36
def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit the C++ line(s) that unwrap a Tensor argument at the current level.

    Produces ``auto [<name>_value, <name>_bdim] = unwrapTensorAtLevel(...);``
    as a single-element list of lines.
    """
    unwrap = (
        f"auto [{name}_value, {name}_bdim] = "
        f"unwrapTensorAtLevel({name}, {cur_level_var});"
    )
    return [unwrap]
41
42
def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
    """Emit the C++ lines that unwrap an ``optional<Tensor>`` argument.

    Declares ``<name>_value``/``<name>_bdim`` and fills them only when the
    optional holds a value. Returns the lines as a list of strings.
    """
    return [
        f"std::optional<Tensor> {name}_value;",
        f"std::optional<int64_t> {name}_bdim;",
        f"if ({name}) {{",
        f"    std::tie({name}_value, {name}_bdim) = "
        f"unwrapTensorAtLevel({name}.value(), {cur_level_var});",
        "}",
    ]
51
52
53def gen_unwraps(
54    flat_arguments: Sequence[Argument], cur_level_var: str
55) -> tuple[str, list[str]]:
56    arg_names = [a.name for a in flat_arguments]
57    arg_types = [a.type for a in flat_arguments]
58
59    tensors = [name for typ, name in zip(arg_types, arg_names) if is_tensor(typ)]
60    optional_tensors = [
61        name for typ, name in zip(arg_types, arg_names) if is_optional_tensor(typ)
62    ]
63
64    unwraps = []
65    for tensor in tensors:
66        unwraps += unwrap_tensor(tensor, cur_level_var)
67
68    for opt_tensor in optional_tensors:
69        unwraps += unwrap_optional_tensor(opt_tensor, cur_level_var)
70    unwrap_code = "\n".join(unwraps)
71
72    unwrapped_arg_list = []
73    for arg in arg_names:
74        if arg in tensors or arg in optional_tensors:
75            unwrapped_arg_list += [f"{arg}_value", f"{arg}_bdim"]
76        else:
77            unwrapped_arg_list.append(arg)
78    return unwrap_code, unwrapped_arg_list
79
80
def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    """Generate the early-exit C++ branch taken when no tensor-like argument
    is batched at ``cur_level_var``: it falls through to the regular op call
    via ``at::_ops::<name>::call(...)``.
    """
    conditions = [
        f"!isBatchedAtLevel({arg.name}, {cur_level_var})"
        for arg in schema.arguments.flat_all
        if arg.type.is_tensor_like()
    ]

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ", ".join(
        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
    )
    return f"""\
if ({' && '.join(conditions)}) {{
  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""
99
100
101def gen_returns(
102    returns: tuple[Return, ...], cur_level_var: str, results_var: str
103) -> str:
104    idx = 0
105    wrapped_returns = []
106    for ret in returns:
107        if is_tensor(ret.type):
108            wrapped_returns.append(
109                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
110            )
111            idx += 2
112        elif is_tensor_list(ret.type):
113            wrapped_returns.append(
114                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx+1}>({results_var}), {cur_level_var})"
115            )
116            idx += 2
117        else:
118            wrapped_returns.append(f"std::get<{idx}>({results_var})")
119            idx += 1
120    if len(wrapped_returns) == 1:
121        result = f"return {wrapped_returns[0]};"
122    else:
123        result = f'return std::make_tuple({", ".join(wrapped_returns)});'
124    return result
125
126
127def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
128    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)
129
130
131def is_mutated_arg(argument: Argument) -> bool:
132    return argument.annotation is not None and argument.annotation.is_write
133
134
def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
    """Generate C++ vmap plumbing for an in-place operator.

    Emits a ``<name>_generated_plumbing`` template function that unwraps
    batched tensor arguments, calls ``batch_rule`` for its side effect,
    and returns the mutated first argument. Returns None when the schema
    violates any of the assumptions below.
    """
    # Assumptions:
    # - only one argument is being modified in-place
    # - the argument that is being modified in-place is the first argument
    # - all returns are either Tensor, tuple of Tensor, or TensorList
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    if not is_mutated_arg(schema.arguments.flat_all[0]):
        return None
    if not len([arg for arg in schema.arguments.flat_all if is_mutated_arg(arg)]) == 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    # Unlike the functional case, the batch_rule result is discarded; the
    # plumbing returns the first (mutated) argument directly.
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {schema.arguments.flat_all[0].name};
}}"""
177
178
def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    """Generate C++ vmap plumbing for an operator with no returns.

    Same shape as the functional plumbing, but the ``batch_rule`` call's
    result is ignored and the generated function returns nothing.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
}}"""
198
199
def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
    """Generate C++ vmap plumbing for a single native function.

    Dispatches to the no-returns / in-place generators for those schema
    kinds; otherwise emits a functional plumbing template that unwraps
    batched tensor arguments, calls ``batch_rule``, and re-wraps the
    results. Returns None for schemas this generator does not support.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Only support cases where all returns are Tensors or vector<Tensor>
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if len(returns) == 0:
        return gen_vmap_plumbing_no_returns(native_function)
    # Ops exempted from the all-returns-tensor-like check below (per the
    # name, these have SymInt returns mixed in).
    return_symint_overrides = [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_cudnn_attention",
    ]
    if (
        not all(ret.type.is_tensor_like() for ret in returns)
        and schema.name.unambiguous_name() not in return_symint_overrides
    ):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    # Wrap each batch_rule output (value/bdim pairs) back into batched tensors.
    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""
249
250
@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    """Callable mapping a NativeFunction to its generated plumbing code (or None)."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        return gen_vmap_plumbing(f)
257
258
def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    """Generate the complete vmap plumbing header for ``native_functions``.

    Functions for which no plumbing can be generated (``gen_vmap_plumbing``
    returns None) are skipped via ``mapMaybe``.
    """
    # str.join accepts any iterable; no need to materialize a list first.
    body = "\n".join(mapMaybe(ComputeBatchRulePlumbing(), native_functions))
    return f"""
#pragma once
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""
272