"""Generate C++ vmap "plumbing" wrappers for native functions.

Each generated wrapper is a C++ function template, parameterised on a batch
rule, that unwraps tensor arguments at the current level, calls the batch
rule on the unwrapped (value, bdim) pairs and re-wraps the results.
"""
from __future__ import annotations

import textwrap
from dataclasses import dataclass
from typing import Sequence

from torchgen.api.translate import translate
from torchgen.api.types import DispatcherSignature
from torchgen.context import method_with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    FunctionSchema,
    ListType,
    NativeFunction,
    OptionalType,
    Return,
    SchemaKind,
    Type,
)
from torchgen.utils import mapMaybe


def is_tensor(typ: Type) -> bool:
    """Return True if ``typ`` is the plain ``Tensor`` base type."""
    return isinstance(typ, BaseType) and typ.name == BaseTy.Tensor


def is_optional_tensor(typ: Type) -> bool:
    """Return True if ``typ`` is an optional whose element is a plain Tensor."""
    return isinstance(typ, OptionalType) and is_tensor(typ.elem)


def is_tensor_list(typ: Type) -> bool:
    """Return True if ``typ`` is a list whose element is a plain Tensor."""
    return isinstance(typ, ListType) and is_tensor(typ.elem)


def unwrap_tensor(name: str, cur_level_var: str) -> list[str]:
    """Return the C++ lines that unwrap tensor ``name`` at ``cur_level_var``.

    The emitted code declares ``{name}_value`` and ``{name}_bdim``.
    """
    result = f"""\
    auto [{name}_value, {name}_bdim] = unwrapTensorAtLevel({name}, {cur_level_var});"""
    return textwrap.dedent(result).split("\n")


def unwrap_optional_tensor(name: str, cur_level_var: str) -> list[str]:
    """Return the C++ lines that unwrap optional tensor ``name``.

    The emitted code declares ``{name}_value`` and ``{name}_bdim`` as
    ``std::optional`` values and only fills them when ``name`` is set.
    """
    result = f"""\
    std::optional<Tensor> {name}_value;
    std::optional<int64_t> {name}_bdim;
    if ({name}) {{
        std::tie({name}_value, {name}_bdim) = unwrapTensorAtLevel({name}.value(), {cur_level_var});
    }}"""
    return textwrap.dedent(result).split("\n")


def gen_unwraps(
    flat_arguments: Sequence[Argument], cur_level_var: str
) -> tuple[str, list[str]]:
    """Generate unwrapping code for every Tensor / optional Tensor argument.

    Returns:
        A pair ``(unwrap_code, unwrapped_arg_list)``. ``unwrap_code`` is the
        newline-joined C++ unwrapping code (all plain tensors first, then all
        optional tensors). ``unwrapped_arg_list`` lists, in argument order,
        the expressions to pass on: ``{name}_value`` and ``{name}_bdim`` for
        each unwrapped argument, and the bare name for everything else
        (tensor lists included, since they are not unwrapped here).
    """
    tensors = [a.name for a in flat_arguments if is_tensor(a.type)]
    optional_tensors = [a.name for a in flat_arguments if is_optional_tensor(a.type)]

    unwraps: list[str] = []
    for name in tensors:
        unwraps += unwrap_tensor(name, cur_level_var)
    for name in optional_tensors:
        unwraps += unwrap_optional_tensor(name, cur_level_var)
    unwrap_code = "\n".join(unwraps)

    # Set membership keeps the per-argument lookup below O(1).
    unwrapped_names = set(tensors) | set(optional_tensors)
    unwrapped_arg_list: list[str] = []
    for arg in flat_arguments:
        if arg.name in unwrapped_names:
            unwrapped_arg_list += [f"{arg.name}_value", f"{arg.name}_bdim"]
        else:
            unwrapped_arg_list.append(arg.name)
    return unwrap_code, unwrapped_arg_list


def gen_case_where_all_bdims_are_none(
    outer_sig: DispatcherSignature, schema: FunctionSchema, cur_level_var: str
) -> str:
    """Generate the C++ early return taken when no argument is batched.

    The emitted ``if`` tests ``!isBatchedAtLevel(...)`` for every tensor-like
    argument and, when all hold, calls the operator's ``at::_ops`` entry
    directly with arguments translated from ``outer_sig``.

    NOTE: if ``schema`` has no tensor-like argument the emitted condition is
    empty (``if ()``); callers in this file check
    ``accepts_at_least_one_tensor_input`` before reaching this function.
    """
    conditions = [
        f"!isBatchedAtLevel({arg.name}, {cur_level_var})"
        for arg in schema.arguments.flat_all
        if arg.type.is_tensor_like()
    ]

    sig = DispatcherSignature.from_schema(schema)
    translated_args = ", ".join(
        e.expr for e in translate(outer_sig.arguments(), sig.arguments())
    )
    return f"""\
if ({' && '.join(conditions)}) {{
  return at::_ops::{sig.func.name.unambiguous_name()}::call({translated_args});
}}"""


def gen_returns(
    returns: tuple[Return, ...], cur_level_var: str, results_var: str
) -> str:
    """Generate the C++ ``return`` statement that re-wraps batch rule results.

    ``results_var`` names a C++ tuple. Each Tensor or Tensor-list return
    consumes two consecutive tuple slots (passed to ``makeBatched`` /
    ``makeBatchedVector`` together with ``cur_level_var``); any other return
    consumes one slot and is forwarded unchanged.
    """
    idx = 0
    wrapped_returns: list[str] = []
    for ret in returns:
        if is_tensor(ret.type):
            wrapped_returns.append(
                f"makeBatched(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        elif is_tensor_list(ret.type):
            wrapped_returns.append(
                f"makeBatchedVector(std::get<{idx}>({results_var}), std::get<{idx + 1}>({results_var}), {cur_level_var})"
            )
            idx += 2
        else:
            wrapped_returns.append(f"std::get<{idx}>({results_var})")
            idx += 1

    # A single return is returned bare; several are packed into a tuple.
    if len(wrapped_returns) == 1:
        return f"return {wrapped_returns[0]};"
    joined = ", ".join(wrapped_returns)
    return f"return std::make_tuple({joined});"


def accepts_at_least_one_tensor_input(schema: FunctionSchema) -> bool:
    """Return True if any argument of ``schema`` has a tensor-like type."""
    return any(a.type.is_tensor_like() for a in schema.arguments.flat_all)


def is_mutated_arg(argument: Argument) -> bool:
    """Return True if ``argument`` carries an annotation marked as a write."""
    return argument.annotation is not None and argument.annotation.is_write


def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> str | None:
    """Generate plumbing for an in-place operator, or None if unsupported.

    The caller must pass a function whose schema kind is ``inplace``
    (asserted). None is returned when any of the assumptions listed in the
    body does not hold.
    """
    # Assumptions:
    # - only one argument is being modified in-place
    # - the argument that is being modified in-place is the first argument
    # - all returns are either Tensor, tuple of Tensor, or TensorList
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns
    flat_args = schema.arguments.flat_all

    # Check assumptions. If these are invalid we return None
    # and punt the work to handle them to the future.
    assert schema.kind() == SchemaKind.inplace
    # An empty argument list has no first argument to mutate; treat it as
    # unsupported rather than failing on the index below.
    if not flat_args or not is_mutated_arg(flat_args[0]):
        return None
    if sum(1 for arg in flat_args if is_mutated_arg(arg)) != 1:
        return None

    # Only support cases where all returns are Tensors or vector<Tensor>
    if len(returns) == 0:
        return None
    if not all(is_tensor(ret.type) or is_tensor_list(ret.type) for ret in returns):
        return None
    if not accepts_at_least_one_tensor_input(schema):
        return None

    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(flat_args, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    # The batch rule's result is discarded; the generated function returns
    # the first (mutated) argument.
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
  return {flat_args[0].name};
}}"""


def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
    """Generate plumbing for an operator whose schema has no returns.

    The generated function calls the batch rule and returns nothing itself.
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  batch_rule({', '.join(unwrapped_arg_list)});
}}"""


def gen_vmap_plumbing(native_function: NativeFunction) -> str | None:
    """Generate the vmap plumbing for ``native_function``.

    Returns the C++ source of the generated function template, or None when
    the operator is not supported (see the checks in the body).
    """
    schema = native_function.func
    sig = DispatcherSignature.from_schema(schema)
    returns = schema.returns

    # Nothing to plumb when the operator takes no tensor-like input.
    if not accepts_at_least_one_tensor_input(schema):
        return None
    if not returns:
        return gen_vmap_plumbing_no_returns(native_function)

    # Only support cases where all returns are tensor-like, except for the
    # operators explicitly allow-listed here (judging by the variable name,
    # these also return SymInt values -- TODO confirm).
    return_symint_overrides = [
        "_scaled_dot_product_flash_attention",
        "_scaled_dot_product_cudnn_attention",
    ]
    if (
        not all(ret.type.is_tensor_like() for ret in returns)
        and schema.name.unambiguous_name() not in return_symint_overrides
    ):
        return None
    # in-place views need special handling
    if "inplace_view" in native_function.tags:
        return None

    if schema.kind() == SchemaKind.inplace:
        return gen_vmap_inplace_plumbing(native_function)

    # Don't support these (mutable, out, scratch)
    if schema.kind() != SchemaKind.functional:
        return None

    results_var = "results"
    cur_level_var = "cur_level"

    unwraps, unwrapped_arg_list = gen_unwraps(schema.arguments.flat_all, cur_level_var)
    bdims_all_none_case = gen_case_where_all_bdims_are_none(sig, schema, cur_level_var)

    wrapped_returns = gen_returns(returns, cur_level_var, results_var)
    return f"""\
template <typename batch_rule_t, batch_rule_t batch_rule>
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
  c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
  auto maybe_layer = maybeCurrentDynamicLayer();
  vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
  int64_t {cur_level_var} = maybe_layer->layerId();
{textwrap.indent(bdims_all_none_case, "  ")}
{textwrap.indent(unwraps, "  ")}
  auto {results_var} = batch_rule({', '.join(unwrapped_arg_list)});
  {wrapped_returns}
}}"""


@dataclass(frozen=True)
class ComputeBatchRulePlumbing:
    """Callable that maps a NativeFunction to its plumbing source (or None)."""

    @method_with_native_function
    def __call__(self, f: NativeFunction) -> str | None:
        return gen_vmap_plumbing(f)


def gen_all_vmap_plumbing(native_functions: Sequence[NativeFunction]) -> str:
    """Return the full generated C++ header for ``native_functions``.

    The per-function plumbing is joined with newlines and placed inside the
    ``at::functorch`` namespace. Functions for which no plumbing is
    generated (None) are expected to be dropped by ``mapMaybe``.
    """
    body = "\n".join(mapMaybe(ComputeBatchRulePlumbing(), native_functions))
    return f"""
#pragma once
#include <ATen/Operators.h>
#include <ATen/functorch/PlumbingHelper.h>

namespace at {{ namespace functorch {{

{body}

}}}} // namespace at::functorch
"""