# Generates ViewFuncs.h/cpp
#
# For every out-of-place view op selected in gen_view_funcs(), this emits a C++
# ViewFunc subclass that stores the op's non-self arguments as members and, via
# operator(), re-invokes the op on a given input tensor.
#
# NOTE: If any changes are being made to the ViewFunc codegen please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from __future__ import annotations

from typing import TYPE_CHECKING

import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    NamedCType,
    SymIntT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import (
    CALL_DISPATCH,
    extract_bindings,
    get_view_info,
    modifies_arguments,
    use_derived,
)


# Only needed for type annotations (gen_view_funcs' signature); guarded so the
# import is not performed at runtime.
if TYPE_CHECKING:
    from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo


# C++ struct declaration for a single view op's ViewFunc (collected into the
# "view_func_declarations" template slot). Placeholders are filled in by
# process_function():
#   uppercase_op      - snake_case struct name, upper-cased, for the
#                       *_AVAILABLE macro (presumably a feature-test macro for
#                       code that needs to know the struct exists -- confirm)
#   op / superclass   - CamelCase struct name and its base class
#   constructor_args  - one constructor parameter per non-self argument
#                       (plus view_idx for multi-output views)
#   initializer_list  - ": member(expr), ..." or empty
#   state             - owning member declarations for the stored arguments
FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {};
  virtual ~${op}() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = ::std::nullopt,
      std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)

# C++ out-of-line method definitions matching FUNCTION_DECLARATION (collected
# into the "view_func_definitions" template slot). The get_/set_ bodies and the
# num_* expressions come from generate_state_getter_setter(); symints_vec /
# tensors_vec name the vector parameters those bodies refer to. operator()
# re-runs the view op (op_call) on call_input_name, and clone_and_set copies the
# struct via clone_args and then optionally overwrites its SymInt / Tensor state.
FUNCTION_DEFINITION = CodeTemplate(
    """\
std::vector<c10::SymInt> ${op}::get_symints() const {
  ${get_symints}
}

size_t ${op}::num_symints() const {
  return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
  ${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
  ${get_tensors}
}

size_t ${op}::num_tensors() const {
  return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
  ${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
  return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
    std::optional<std::vector<c10::SymInt>> ${symints_vec},
    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
  auto output = std::make_unique<${op}>(${clone_args});
  if (${symints_vec}.has_value()) {
    output->set_symints(std::move(*(${symints_vec})));
  }
  if (${tensors_vec}.has_value()) {
    output->set_tensors(std::move(*(${tensors_vec})));
  }
  return output;
}

"""
)
# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    """Build the C++ ViewFunc struct name for ``f``.

    The base form is ``<unambiguous_name>_view_func``, with any '.' mapped to
    '_'. With ``camel_case`` each '_'-separated piece is title-cased and joined;
    a leading underscore on the snake_case form is kept. With
    ``include_namespace`` the result is prefixed by
    ``torch::autograd::generated::``.
    """
    base = f.func.name.unambiguous_name().replace(".", "_")
    snake = f"{base}_view_func"

    if not camel_case:
        stem = snake
    else:
        stem = "".join(piece.title() for piece in snake.split("_"))
        # Splitting "_foo" yields an empty first piece, which drops the
        # underscore; re-attach it so private ops stay recognizable.
        if snake.startswith("_"):
            stem = "_" + stem

    prefix = "torch::autograd::generated::" if include_namespace else ""
    return prefix + stem


def is_symint_or_tensor(arg: Argument) -> bool:
    """Return whether ``arg`` has a tensor-like or SymInt-like type."""
    arg_type = arg.type
    return arg_type.is_tensor_like() or arg_type.is_symint_like()


def remove_const_ref(binding: Binding) -> Binding:
    """Copy ``binding`` with const/reference qualifiers stripped from its C++ type."""
    bare_nctype = binding.nctype.remove_const_ref()
    return Binding(
        name=binding.name,
        nctype=bare_nctype,
        argument=binding.argument,
        default=binding.default,
    )


def returns_multi_tensor(fn: NativeFunction) -> bool:
    """Return whether ``fn``'s single return is a list of tensor-like values.

    Asserts that ``fn`` has exactly one return.
    """
    outputs = fn.func.returns
    assert len(outputs) == 1
    out_type = outputs[0].type
    is_list = out_type.is_list_like() is not None
    is_tensor = out_type.is_tensor_like()
    return is_list and is_tensor
# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
#   bindings (list): List of state bindings of interest (may be empty)
#   state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
#   tuple: (list of getter logic strings, list of setter logic strings, string
#     with num items expression)
#
# The getter declares a local vector named after state_vec_type, reserves the
# total item count (when there are bindings), appends every binding's value(s)
# and returns it. The setter walks the incoming vector with a C++ offset `i`,
# copying consecutive items back into each member in binding order.
def generate_state_getter_setter(
    bindings: list[Binding],
    state_vec_type: NamedCType,
) -> tuple[list[str], list[str], str]:
    getter_logic = []
    setter_logic = []

    state_vec = state_vec_type.name
    getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
    if len(bindings) > 0:
        # C++ running offset into the incoming vector; only emitted when some
        # binding will read it (presumably to avoid an unused-variable warning).
        setter_logic.append("auto i = 0;")

    num_exprs = []
    # NB: the Python `i` below is only used to detect the last binding; the
    # C++ `i` in the emitted setter code is the separate offset declared above.
    for i, b in enumerate(bindings):
        assert isinstance(b.argument, Argument)
        if b.argument.type.is_list_like():
            # Handle list-likes.
            num_expr = f"{b.name}.size()"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
            setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
        elif isinstance(b.argument.type, OptionalType):
            # Handle optionals.
            num_expr = f"({b.name}.has_value() ? 1 : 0)"
            num_exprs.append(num_expr)
            conditional = f"if({b.name}.has_value())"
            getter = (
                f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
            )
            setter = f"{conditional} {b.name} = {state_vec}[i];"
        else:
            num_expr = "1"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.push_back({b.name});"
            setter = f"{b.name} = {state_vec}[i];"

        getter_logic.append(getter)
        setter_logic.append(setter)
        # Advance the C++ offset past this binding, except after the last one.
        if i < len(bindings) - 1:
            setter_logic.append(f"i += {num_expr};")

    # Reserve / assert based on the total number of items expression.
    num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
    if len(bindings) > 0:
        # Index 1: right after the vector declaration, before any insertion.
        getter_logic.insert(1, f"{state_vec}.reserve({num_items});")

    getter_logic.append(f"return {state_vec};")

    return getter_logic, setter_logic, num_items


def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    """Render the ViewFunc struct for view op ``fn`` using ``template``.

    ``template`` is FUNCTION_DECLARATION or FUNCTION_DEFINITION; both are
    substituted with the same set of placeholder values computed here.
    The first flat argument of ``fn`` is skipped as ``self`` (assumes it is
    named ``self``, matching the ``b.name != "self"`` filter -- TODO confirm
    for all view ops).
    """
    bindings = extract_bindings(fn)
    non_self_bindings = [b for b in bindings if b.name != "self"]

    non_self_args = fn.func.arguments.flat_all[1:]
    # Owning (non-reference) bindings used for the struct's stored members.
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in non_self_bindings]
    clone_args = [b.name for b in non_self_bindings]

    # Generate state variable declarations for the generated struct.
    state_variables = [
        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
    ]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vector<SymInt>s.
    init_exprs = translate(
        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
    )
    initializers = []
    for b, init_expr in zip(non_self_bindings, init_exprs):
        name = b.nctype.name
        assert isinstance(name, str)
        initializers.append(f"{name}({init_expr.expr})")

    # Generate call to underlying view op
    call_input_name = "input_base"
    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=op_call_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        view_idx_name = "view_idx"
        view_idx_typename = "int64_t"
        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
        constructor_args.append(view_idx_decl)
        clone_args.append(view_idx_name)
        state_variables.append(f"{view_idx_decl};")
        initializers.append(f"{view_idx_name}({view_idx_name})")
        # Select this view's output from the op's list result.
        op_call += f"[{view_idx_name}]"

    # Generate initializer list for the generated struct.
    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

    # Generate getter / setter logic for any symints.
    symint_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
    ]
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_bindings, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
    ]
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_bindings, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )


def gen_view_funcs(
    out: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    """Write ViewFuncs.h and ViewFuncs.cpp into ``out``.

    Only functions that pass ``use_derived``, have view info, and do not modify
    their arguments (i.e. out-of-place views) get a ViewFunc struct. The files
    are rendered from the same-named templates found under ``template_path``.
    """
    # don't need the info parts, just the function
    fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
    # only want out-of-place views
    view_fns = [
        fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
    ]

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    # The include line depends only on root_name, so overloads of the same op
    # (e.g. several view_fns sharing a root_name) would emit duplicate lines.
    # Dedupe while preserving first-seen order to keep the output deterministic.
    ops_headers = list(
        dict.fromkeys(f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns)
    )

    file_basename = "ViewFuncs"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            # Bind the current ``fname`` as a default argument: a plain closure
            # would read the loop variable at call time, not definition time.
            lambda fname=fname: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            },
        )