# Generates ViewFuncs.h/cpp
#
# NOTE: If any changes are being made to the ViewFunc codegen, please also check
# if updates are needed in torch/csrc/autograd/autograd_not_implemented_fallback.cpp.
# The fallback is expected to mimic this codegen, so we should keep the two in sync.

from __future__ import annotations

from typing import TYPE_CHECKING

import torchgen.api.dispatcher as dispatcher
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    NamedCType,
    SymIntT,
    tensorT,
    VectorCType,
)
from torchgen.code_template import CodeTemplate
from torchgen.model import Argument, NativeFunction, OptionalType
from torchgen.utils import FileManager

from .gen_inplace_or_view_type import (
    CALL_DISPATCH,
    extract_bindings,
    get_view_info,
    modifies_arguments,
    use_derived,
)


if TYPE_CHECKING:
    from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo


FUNCTION_DECLARATION = CodeTemplate(
    """\
#define ${uppercase_op}_AVAILABLE
struct ${op} : public ${superclass} {
  ${op}(${constructor_args}) ${initializer_list}
  {};
  virtual ~${op}() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = ::std::nullopt,
      std::optional<std::vector<at::Tensor>> = ::std::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  ${state}
};

"""
)
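
# For orientation, a filled-in declaration might look roughly like the sketch
# below for a narrow()-style view op (illustrative only; the exact constructor
# args and state come from the op's dispatcher signature):
#
#   #define NARROW_VIEW_FUNC_AVAILABLE
#   struct NarrowViewFunc : public torch::autograd::ViewFunc {
#     NarrowViewFunc(int64_t dim, c10::SymInt start, c10::SymInt length)
#         : dim(dim), start(start), length(length)
#     {};
#     ...
#   private:
#     int64_t dim;
#     c10::SymInt start;
#     c10::SymInt length;
#   };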

FUNCTION_DEFINITION = CodeTemplate(
    """\
std::vector<c10::SymInt> ${op}::get_symints() const {
  ${get_symints}
}

size_t ${op}::num_symints() const {
  return static_cast<size_t>(${num_symints});
}

void ${op}::set_symints(std::vector<c10::SymInt> ${symints_vec}) {
  TORCH_INTERNAL_ASSERT(${symints_vec}.size() == num_symints());
  ${set_symints}
}

std::vector<at::Tensor> ${op}::get_tensors() const {
  ${get_tensors}
}

size_t ${op}::num_tensors() const {
  return static_cast<size_t>(${num_tensors});
}

void ${op}::set_tensors(std::vector<at::Tensor> ${tensors_vec}) {
  TORCH_INTERNAL_ASSERT(${tensors_vec}.size() == num_tensors());
  ${set_tensors}
}

at::Tensor ${op}::operator()(const at::Tensor& ${call_input_name}) const {
  return ${op_call};
}

std::unique_ptr<ViewFunc> ${op}::clone_and_set(
    std::optional<std::vector<c10::SymInt>> ${symints_vec},
    std::optional<std::vector<at::Tensor>> ${tensors_vec}) const {
  auto output = std::make_unique<${op}>(${clone_args});
  if (${symints_vec}.has_value()) {
    output->set_symints(std::move(*(${symints_vec})));
  }
  if (${tensors_vec}.has_value()) {
    output->set_tensors(std::move(*(${tensors_vec})));
  }
  return output;
}

"""
)


# e.g. as_strided -> AsStridedViewFunc for camel case or
# as_strided_view_func otherwise
def view_func_name(
    f: NativeFunction, include_namespace: bool = False, camel_case: bool = True
) -> str:
    name = f.func.name.unambiguous_name()
    view_func_name = f"{name.replace('.', '_')}_view_func"
    if camel_case:
        is_private = view_func_name.startswith("_")
        view_func_name = "".join(
            [p.title() for p in view_func_name.replace(".", "_").split("_")]
        )
        if is_private:
            # put the leading underscore back in
            view_func_name = f"_{view_func_name}"
    namespace = "torch::autograd::generated::" if include_namespace else ""
    return f"{namespace}{view_func_name}"
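
# Illustrative mappings (assuming a NativeFunction whose unambiguous name is
# "_unsafe_view"; example chosen to show the private-name handling):
#   view_func_name(f)                          -> "_UnsafeViewViewFunc"
#   view_func_name(f, camel_case=False)        -> "_unsafe_view_view_func"
#   view_func_name(f, include_namespace=True)  -> "torch::autograd::generated::_UnsafeViewViewFunc"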


def is_symint_or_tensor(arg: Argument) -> bool:
    return arg.type.is_tensor_like() or arg.type.is_symint_like()


def remove_const_ref(binding: Binding) -> Binding:
    return Binding(
        name=binding.name,
        nctype=binding.nctype.remove_const_ref(),
        argument=binding.argument,
        default=binding.default,
    )
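
# e.g. a binding typed `const at::Tensor &` becomes one typed `at::Tensor`
# (illustrative sketch of the effect), so the generated struct declares owned
# state rather than references.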


def returns_multi_tensor(fn: NativeFunction) -> bool:
    returns = fn.func.returns
    assert len(returns) == 1
    returns_list_like = returns[0].type.is_list_like() is not None
    returns_tensor_like = returns[0].type.is_tensor_like()
    return returns_list_like and returns_tensor_like
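
# e.g. a schema returning `Tensor(a)[]` (a split / unbind-style multi-output
# view, to give a hedged example) satisfies this, while a plain `Tensor(a)`
# return does not.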


# Generates strings with logic for getting / setting state of a particular type.
#
# Args:
#   bindings (list): List of state bindings of interest (may be empty)
#   state_vec_type (NamedCType): Type of vector to either return or copy from
#
# Returns:
#   tuple: (list of getter logic strings, list of setter logic strings, string
#     with num items expression)
def generate_state_getter_setter(
    bindings: list[Binding],
    state_vec_type: NamedCType,
) -> tuple[list[str], list[str], str]:
    getter_logic = []
    setter_logic = []

    state_vec = state_vec_type.name
    getter_logic.append(f"{state_vec_type.cpp_type()} {state_vec};")
    if len(bindings) > 0:
        setter_logic.append("auto i = 0;")

    num_exprs = []
    for i, b in enumerate(bindings):
        assert isinstance(b.argument, Argument)
        if b.argument.type.is_list_like():
            # Handle list-likes.
            num_expr = f"{b.name}.size()"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.insert({state_vec}.end(), {b.name}.begin(), {b.name}.end());"
            setter = f"std::copy({state_vec}.begin() + i, {state_vec}.begin() + i + {b.name}.size(), {b.name}.begin());"
        elif isinstance(b.argument.type, OptionalType):
            # Handle optionals.
            num_expr = f"({b.name}.has_value() ? 1 : 0)"
            num_exprs.append(num_expr)
            conditional = f"if({b.name}.has_value())"
            getter = (
                f"{conditional} {state_vec}.insert({state_vec}.end(), *({b.name}));"
            )
            setter = f"{conditional} {b.name} = {state_vec}[i];"
        else:
            num_expr = "1"
            num_exprs.append(num_expr)
            getter = f"{state_vec}.push_back({b.name});"
            setter = f"{b.name} = {state_vec}[i];"

        getter_logic.append(getter)
        setter_logic.append(setter)
        if i < len(bindings) - 1:
            setter_logic.append(f"i += {num_expr};")

    # Reserve / assert based on the total number of items expression.
    num_items = "0" if len(num_exprs) == 0 else " + ".join(num_exprs)
    if len(bindings) > 0:
        getter_logic.insert(1, f"{state_vec}.reserve({num_items});")

    getter_logic.append(f"return {state_vec};")

    return getter_logic, setter_logic, num_items
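
# Worked example (hypothetical bindings: a list-like `size` and an optional
# `storage_offset`, both SymInt-flavored, with `state_vec` named "symints"):
#
#   getter logic:
#     std::vector<c10::SymInt> symints;
#     symints.reserve(size.size() + (storage_offset.has_value() ? 1 : 0));
#     symints.insert(symints.end(), size.begin(), size.end());
#     if(storage_offset.has_value()) symints.insert(symints.end(), *(storage_offset));
#     return symints;
#
#   setter logic:
#     auto i = 0;
#     std::copy(symints.begin() + i, symints.begin() + i + size.size(), size.begin());
#     i += size.size();
#     if(storage_offset.has_value()) storage_offset = symints[i];
#
#   num items expression:
#     size.size() + (storage_offset.has_value() ? 1 : 0)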


def process_function(fn: NativeFunction, template: CodeTemplate) -> str:
    bindings = extract_bindings(fn)
    non_self_bindings = [b for b in bindings if b.name != "self"]

    non_self_args = fn.func.arguments.flat_all[1:]
    non_self_value_bindings = [
        dispatcher.argument(a, remove_non_owning_ref_types=True) for a in non_self_args
    ]

    # Generate constructor / clone args for the generated struct.
    constructor_args = [b.defn() for b in non_self_bindings]
    clone_args = [b.name for b in non_self_bindings]

    # Generate state variable declarations for the generated struct.
    state_variables = [
        f"{remove_const_ref(b).defn()};" for b in non_self_value_bindings
    ]

    # Generate initializer list expressions for the generated struct.
    # allow_expensive_conversions=True because we need to store e.g. SymIntArrayRefs as
    # vector<SymInt>s.
    init_exprs = translate(
        non_self_bindings, non_self_value_bindings, allow_expensive_conversions=True
    )
    initializers = []
    for b, init_expr in zip(non_self_bindings, init_exprs):
        name = b.nctype.name
        assert isinstance(name, str)
        initializers.append(f"{name}({init_expr.expr})")

    # Generate the call to the underlying view op.
    call_input_name = "input_base"
    op_call_args = [call_input_name, *(b.name for b in non_self_bindings)]
    op_call = CALL_DISPATCH.substitute(
        unambiguous_name=fn.func.name.unambiguous_name(),
        unpacked_args=op_call_args,
    )

    # Multi-output views additionally require a view_idx for disambiguation.
    if returns_multi_tensor(fn):
        view_idx_name = "view_idx"
        view_idx_typename = "int64_t"
        view_idx_decl = f"{view_idx_typename} {view_idx_name}"
        constructor_args.append(view_idx_decl)
        clone_args.append(view_idx_name)
        state_variables.append(f"{view_idx_decl};")
        initializers.append(f"{view_idx_name}({view_idx_name})")
        op_call += f"[{view_idx_name}]"

    # Generate the initializer list for the generated struct.
    initializer_list = f": {', '.join(initializers)}" if len(initializers) > 0 else ""

    # Generate getter / setter logic for any symints.
    symint_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_symint_like()
    ]
    symints_vec_type = NamedCType("symints", VectorCType(BaseCType(SymIntT)))
    get_symints, set_symints, num_symints = generate_state_getter_setter(
        symint_bindings, symints_vec_type
    )

    # Generate getter / setter logic for any tensors.
    tensor_bindings = [
        b
        for b in non_self_bindings
        if isinstance(b.argument, Argument) and b.argument.type.is_tensor_like()
    ]
    tensors_vec_type = NamedCType("tensors", VectorCType(BaseCType(tensorT)))
    get_tensors, set_tensors, num_tensors = generate_state_getter_setter(
        tensor_bindings, tensors_vec_type
    )

    return template.substitute(
        op=view_func_name(fn),
        uppercase_op=view_func_name(fn, camel_case=False).upper(),
        superclass="torch::autograd::ViewFunc",
        initializer_list=initializer_list,
        state=state_variables,
        constructor_args=constructor_args,
        clone_args=clone_args,
        symints_vec=symints_vec_type.name,
        get_symints=get_symints,
        set_symints=set_symints,
        num_symints=num_symints,
        tensors_vec=tensors_vec_type.name,
        get_tensors=get_tensors,
        set_tensors=set_tensors,
        num_tensors=num_tensors,
        call_input_name=call_input_name,
        op_call=op_call,
    )


def gen_view_funcs(
    out: str,
    fns_with_infos: list[NativeFunctionWithDifferentiabilityInfo],
    template_path: str,
) -> None:
    # don't need the info parts, just the function
    fns = [fn.func for fn in fns_with_infos if use_derived(fn)]
    # only want out-of-place views
    view_fns = [
        fn for fn in fns if get_view_info(fn) is not None and not modifies_arguments(fn)
    ]

    declarations = [process_function(fn, FUNCTION_DECLARATION) for fn in view_fns]
    definitions = [process_function(fn, FUNCTION_DEFINITION) for fn in view_fns]
    ops_headers = [f"#include <ATen/ops/{fn.root_name}_ops.h>" for fn in view_fns]

    file_basename = "ViewFuncs"
    fm = FileManager(install_dir=out, template_dir=template_path, dry_run=False)
    for suffix in [".h", ".cpp"]:
        fname = file_basename + suffix
        fm.write_with_template(
            fname,
            fname,
            lambda: {
                "generated_comment": "@"
                + f"generated from {fm.template_dir_for_comments()}/"
                + fname,
                "view_func_declarations": declarations,
                "view_func_definitions": definitions,
                "ops_headers": ops_headers,
            },
        )
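
# A minimal driver sketch (paths and setup here are assumptions for
# illustration; in-tree, this is driven by the autograd codegen entry point):
#
#   fns_with_infos = ...  # parsed + differentiability-matched elsewhere
#   gen_view_funcs(
#       out="torch/csrc/autograd/generated",       # assumed install dir
#       fns_with_infos=fns_with_infos,
#       template_path="tools/autograd/templates",  # assumed template dir
#   )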
341