from __future__ import annotations

from dataclasses import dataclass
from typing import Sequence, TYPE_CHECKING

import torchgen.api.ufunc as ufunc
from torchgen.api.translate import translate
from torchgen.api.types import (
    BaseCType,
    Binding,
    CType,
    Expr,
    NamedCType,
    opmath_t,
    scalar_t,
    StructuredImplSignature,
    VectorizedCType,
)
from torchgen.context import with_native_function
from torchgen.model import (
    Argument,
    BaseTy,
    BaseType,
    DispatchKey,
    NativeFunctionsGroup,
    ScalarType,
    UfuncKey,
)
from torchgen.utils import OrderedSet


if TYPE_CHECKING:
    from torchgen.api.ufunc import UfunctorBindings


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                  CUDA STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #

# NB: not bothering to generate dispatch stub forward declaration in header,
# we can just paste it wherever necessary

# TODO: use BackendIndex
# dispatch_key: DispatchKey  # only CPU/CUDA right now


# Represents functors for implementing CUDA ufuncs.
# Functors are templated over scalar_t, because they are instantiated at their
# use sites for each concrete dtype.  A functor looks something like this:
#
#   template <typename scalar_t>
#   struct CUDAFunctorOnSelf_add {
#     using opmath_t = at::opmath_type<scalar_t>;
#     opmath_t other_;
#     opmath_t alpha_;
#     CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha)
#         : other_(other), alpha_(alpha) {}
#     __device__ scalar_t operator()(scalar_t self) {
#       return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
#     }
#   };
#
@dataclass(frozen=True)
class UfunctorSignature:
    g: NativeFunctionsGroup
    scalar_tensor_idx: int | None
    name: str

    def arguments(self) -> UfunctorBindings:
        return ufunc.ufunctor_arguments(
            self.g, scalar_tensor_idx=self.scalar_tensor_idx, scalar_t=scalar_t
        )

    def fields(self) -> list[Binding]:
        # fields are renamed to have a trailing underscore, as is conventional
        return [b.rename(f"{b.name}_") for b in self.arguments().ctor]

    def returns_type(self) -> CType:
        # TODO: don't hardcode; return type will be inferred based on tags on
        # the native function
        return BaseCType(scalar_t)

    def decl_fields(self) -> str:
        return "\n".join(f"{f.type} {f.name};" for f in self.fields())

    def inline_defn_ctor(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().ctor)
        # NB: hypothetically could do this with translate but the
        # transition here is very regular
        init_str = ", ".join(f"{a.name}_({a.name})" for a in self.arguments().ctor)
        return f"{self.name}({args_str}) : {init_str} {{}}"

    def decl_apply(self) -> str:
        args_str = ", ".join(a.decl() for a in self.arguments().apply)
        return f"{self.returns_type().cpp_type()} operator()({args_str}) const"

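# For the CUDAFunctorOnSelf_add example above, these helpers would render
# roughly as follows (a sketch; the exact strings come from the bindings):
#   decl_fields()      ~> "opmath_t other_;\nopmath_t alpha_;"
#   inline_defn_ctor() ~> "CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}"
#   decl_apply()       ~> "scalar_t operator()(scalar_t self) const"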

@dataclass(frozen=True)
class UfuncSignature:
    g: NativeFunctionsGroup
    name: str
    compute_t: CType

    def arguments(self) -> list[Binding]:
        return ufunc.ufunc_arguments(self.g, compute_t=self.compute_t)

    def call(self, ctx: Sequence[Binding | Expr]) -> str:
        return f"{self.name}({', '.join(a.expr for a in translate(ctx, self.arguments()))})"
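
# For example, for add with compute_t = opmath_t, call(...) would render
# something like "ufunc::add(static_cast<opmath_t>(self), other_, alpha_)"
# (a sketch; the exact argument expressions are produced by translate).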


# steps:
#   1. take the functional signature
#   2. use api.ufunc to convert it to a template signature.  this establishes
#      the type of the template function
#   3. use api.ufunc (II) to generate a split struct / operator() signature.
#      this establishes the context in which we call the template signature
#
# StructuredImplSignature context
#   ~> functor constructor sig
#
# Functor constructor context
#   ~> functor fields sig
#
# Functor apply context (functor fields + functor apply sig)
#   ~> template sig
#


def eligible_for_binary_scalar_specialization(g: NativeFunctionsGroup) -> bool:
    num_tensors = sum(
        1 for a in g.functional.func.arguments.flat_non_out if a.type.is_tensor_like()
    )
    return num_tensors == 2
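
# E.g. add.Tensor(Tensor self, Tensor other, *, Scalar alpha) has exactly two
# tensor inputs and is eligible; a unary ufunc like neg(Tensor self) is not.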


def compute_ufunc_cuda_functors(
    g: NativeFunctionsGroup,
) -> tuple[dict[ScalarType, dict[UfuncKey, UfunctorSignature]], str]:
    # First, build the functors.
    ufunctor_sigs: dict[ScalarType, dict[UfuncKey, UfunctorSignature]] = {}
    ufunctors: list[str] = []
    loops = g.out.ufunc_inner_loop
    scalar_tensor_idx_lookup = {
        UfuncKey.CUDAFunctorOnSelf: 1,
        UfuncKey.CUDAFunctorOnOther: 0,
        UfuncKey.CUDAFunctor: None,
    }
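    # scalar_tensor_idx identifies which tensor input gets lifted into the
    # functor constructor as a CPU scalar: CUDAFunctorOnSelf applies the loop
    # to self and captures other (tensor input 1), CUDAFunctorOnOther does the
    # reverse, and CUDAFunctor captures no tensor input.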
    if eligible_for_binary_scalar_specialization(g):
        keys = [
            UfuncKey.CUDAFunctorOnSelf,
            UfuncKey.CUDAFunctorOnOther,
            UfuncKey.CUDAFunctor,
        ]
    else:
        keys = [UfuncKey.CUDAFunctor]
        for k in [UfuncKey.CUDAFunctorOnSelf, UfuncKey.CUDAFunctorOnOther]:
            assert k not in loops, f"cannot use {k} on non-binary function"
    for k in keys:
        # If the key was directly defined, skip functor codegen; we assume the
        # user has already done it for us
        if k in loops:
            ufunctor_sig = UfunctorSignature(
                g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=loops[k].name
            )
            for dtype in loops[k].supported_dtypes:
                ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig
            continue

        # Note [ScalarOnly and Generic must match names for CUDA]
        # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
        # Otherwise, look in ANY of the generic entries.  For simplicity of
        # codegen, if both ScalarOnly and Generic are defined, the ufunc name
        # must match (if they didn't match, we'd have to generate distinct
        # functors per dtype, which is awful, so we're not going to do it unless
        # someone really forces us to)
        ufunc_name = None
        supported_dtypes: OrderedSet[ScalarType] = OrderedSet()
        for lk in [UfuncKey.ScalarOnly, UfuncKey.Generic]:
            if lk not in loops:
                continue
            if ufunc_name is None:
                ufunc_name = loops[lk].name
            else:
                # See Note [ScalarOnly and Generic must match names for CUDA]
                assert (
                    ufunc_name == loops[lk].name
                ), "ScalarOnly and Generic must have same ufunc name"
            supported_dtypes |= loops[lk].supported_dtypes
        assert ufunc_name is not None

        name = f"{k}_{ufunc_name}"
        ufunctor_sig = UfunctorSignature(
            g, scalar_tensor_idx=scalar_tensor_idx_lookup[k], name=name
        )
        for dtype in supported_dtypes:
            ufunctor_sigs.setdefault(dtype, {})[k] = ufunctor_sig

        ufunc_sig = UfuncSignature(
            g, name=f"ufunc::{ufunc_name}", compute_t=BaseCType(opmath_t)
        )
        apply_ctx = ufunctor_sig.fields() + ufunctor_sig.arguments().apply
        ufunctors.append(
            f"""
template <typename scalar_t>
struct {ufunctor_sig.name} {{
  using opmath_t = at::opmath_type<scalar_t>;
  {ufunctor_sig.decl_fields()}
  {ufunctor_sig.inline_defn_ctor()}
  __device__ {ufunctor_sig.decl_apply()} {{
    return {ufunc_sig.call(apply_ctx)};
  }}
}};
"""
        )

    return ufunctor_sigs, "\n".join(ufunctors)
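
# For add, the returned mapping would look roughly like
#   {ScalarType.Float: {UfuncKey.CUDAFunctorOnSelf: <sig>, ...}, ...}
# alongside one generated functor struct per key (a sketch; the actual dtypes
# and keys depend on the ufunc_inner_loop entries in native_functions.yaml).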


@dataclass(frozen=True)
class BinaryScalarSpecializationConfig:
    scalar_idx: int
    ctor_tensor: str
    ufunc_key: UfuncKey


BinaryScalarSpecializationConfigs = [
    BinaryScalarSpecializationConfig(
        scalar_idx=0,
        ctor_tensor="self",
        ufunc_key=UfuncKey.CUDAFunctorOnOther,
    ),
    BinaryScalarSpecializationConfig(
        scalar_idx=1,
        ctor_tensor="other",
        ufunc_key=UfuncKey.CUDAFunctorOnSelf,
    ),
]
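
# scalar_idx here is relative to the *inputs*; compute_ufunc_cuda_dtype_body
# adds 1 to convert it to TensorIterator operand numbering, where the output
# comes first.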


def compute_ufunc_cuda_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfunctorSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    body = "using opmath_t = at::opmath_type<scalar_t>;\n"
    body += "if (false) {}\n"  # for ease of codegen
    for config in BinaryScalarSpecializationConfigs:
        if config.ufunc_key not in inner_loops:
            continue
        ufunctor_sig = inner_loops[config.ufunc_key]
        scalar_idx = config.scalar_idx + 1
        # Make a copy and at the same time widen the type (not permissible
        # without copy; we don't want to mutate the input argument anyway)
        ctx: list[Expr | Binding] = list(parent_ctx)
        ctx.append(
            Expr(
                expr=f"iter.scalar_value<opmath_t>({scalar_idx})",
                type=NamedCType(config.ctor_tensor, BaseCType(opmath_t)),
            )
        )
        ufunctor_ctor_exprs_str = ", ".join(
            a.expr for a in translate(ctx, ufunctor_sig.arguments().ctor)
        )

        # NB: ufunctor must be allocated before iter.remove_operand is called,
        # as it relies on iter
        body += f"""\
else if (iter.is_cpu_scalar({scalar_idx})) {{
  {ufunctor_sig.name}<scalar_t> ufunctor({ufunctor_ctor_exprs_str});
  iter.remove_operand({scalar_idx});
  gpu_kernel(iter, ufunctor);
}}"""

    ufunctor_sig = inner_loops[UfuncKey.CUDAFunctor]
    ufunctor_ctor_exprs_str = ", ".join(
        a.expr for a in translate(parent_ctx, ufunctor_sig.arguments().ctor)
    )
    body += f"""
else {{
  gpu_kernel(iter, {ufunctor_sig.name}<scalar_t>({ufunctor_ctor_exprs_str}));
}}
    """
    return body
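
# For add, the generated body would look roughly like this (a sketch showing
# only one of the two CPU-scalar branches):
#   using opmath_t = at::opmath_type<scalar_t>;
#   if (false) {}
#   else if (iter.is_cpu_scalar(2)) {
#     CUDAFunctorOnSelf_add<scalar_t> ufunctor(iter.scalar_value<opmath_t>(2), alpha);
#     iter.remove_operand(2);
#     gpu_kernel(iter, ufunctor);
#   }
#   else { gpu_kernel(iter, CUDAFunctor_add<scalar_t>(alpha)); }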


@with_native_function
def compute_ufunc_cuda(g: NativeFunctionsGroup) -> str:
    # First, build the functors, indexing them by dtype
    ufunctor_sigs, ufunctors = compute_ufunc_cuda_functors(g)

    # Next, build the conditionals
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CUDA))
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunctor_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cuda_dtype_body(g, dtype, inner_ufunc_sigs, sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)

    stub_sig = StubSignature(g)

    return f"""
{ufunctors}

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{sig.name}",
    {dtype_cases_str}
  );
}}
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});

{sig.defn()} {{
  {stub_sig.direct_call(sig.arguments())};
}}
"""


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
#
#                                   CPU STUFF
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #


@dataclass(frozen=True)
class StubSignature:
    g: NativeFunctionsGroup

    @property
    def name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_stub"

    @property
    def kernel_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_kernel"

    @property
    def type_name(self) -> str:
        return f"{str(self.g.functional.func.name.name)}_fn"

    def arguments(self) -> list[Binding]:
        return ufunc.stub_arguments(self.g)

    def type(self) -> str:
        cpp_args = self.arguments()
        return f"void(*)(TensorIteratorBase&, {', '.join(a.type for a in cpp_args)})"

    def dispatch_decl(self) -> str:
        return f"DECLARE_DISPATCH({self.type_name}, {self.name})"

    def dispatch_defn(self) -> str:
        return f"DEFINE_DISPATCH({self.name})"

    def kernel_defn(self) -> str:
        return f"void {self.kernel_name}(TensorIteratorBase& iter, {', '.join(a.defn() for a in self.arguments())})"

    def type_defn(self) -> str:
        return f"using {self.type_name} = {self.type()}"

    # must be called from a context where `this` is a TensorIteratorBase*
    def call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.name}(device_type(), *this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"

    # used in CUDA to skip the unnecessary dynamic dispatch
    def direct_call(self, ctx: Sequence[Binding]) -> str:
        return f"{self.kernel_name}(*this, {', '.join(a.expr for a in translate(ctx, self.arguments()))})"
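
# For add, these would render roughly as add_stub / add_kernel / add_fn, with
# type() producing something like "void(*)(TensorIteratorBase&, const at::Scalar&)"
# (a sketch; the trailing arguments come from ufunc.stub_arguments, e.g. alpha).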


@with_native_function
def compute_ufunc_cpu(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)
    sig = StructuredImplSignature(g, ufunc.kernel_name(g, DispatchKey.CPU))

    return f"""
{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
{stub_sig.dispatch_defn()};

{sig.defn()} {{
  {stub_sig.call(sig.arguments())};
}}
"""


def compute_ufunc_cpu_dtype_body(
    g: NativeFunctionsGroup,
    dtype: ScalarType,
    inner_loops: dict[UfuncKey, UfuncSignature],
    parent_ctx: Sequence[Binding],
) -> str:
    assert UfuncKey.CPUScalar in inner_loops, f"{dtype}, {inner_loops.keys()}"
    assert inner_loops.keys() <= {UfuncKey.CPUScalar, UfuncKey.CPUVector}
    scalar_loop = inner_loops[UfuncKey.CPUScalar]
    vec_loop = None
    if UfuncKey.CPUVector in inner_loops:
        vec_loop = inner_loops[UfuncKey.CPUVector]

    # NB: We DON'T use translate here, because translate is
    # incapable of CSE'ing the scalar accesses in case it is also
    # used by Vectorized; also, the unpacking here is very simple
    # and only affects Scalar; everything else is implicitly captured
    # by the lambda

    # Set up scalars in scope
    body = []
    ctx = []
    for b in parent_ctx:
        if isinstance(b.argument, Argument) and b.argument.type != BaseType(
            BaseTy.Scalar
        ):
            continue
        body.append(f"auto _s_{b.name} = {b.name}.to<scalar_t>();")
        ctx.append(Expr(f"_s_{b.name}", NamedCType(b.nctype.name, BaseCType(scalar_t))))
    if vec_loop is not None:
        for b in parent_ctx:
            if isinstance(b.argument, Argument) and b.argument.type != BaseType(
                BaseTy.Scalar
            ):
                continue
            body.append(
                f"auto _v_{b.name} = at::vec::Vectorized<scalar_t>(_s_{b.name});"
            )
            ctx.append(
                Expr(
                    f"_v_{b.name}",
                    NamedCType(b.nctype.name, VectorizedCType(BaseCType(scalar_t))),
                )
            )

    # Set up lambda signature
    # NB: simplified version of ufunctor_arguments
    scalar_bindings = []
    vec_bindings = []
    for a in g.functional.func.arguments.flat_non_out:
        if not a.type.is_tensor_like():
            continue
        assert a.type == BaseType(BaseTy.Tensor)
        scalar_bindings.append(
            Binding(
                name=a.name,
                nctype=NamedCType(a.name, BaseCType(scalar_t)),
                argument=a,
            )
        )
        if vec_loop is not None:
            vec_bindings.append(
                Binding(
                    name=a.name,
                    nctype=NamedCType(a.name, VectorizedCType(BaseCType(scalar_t))),
                    argument=a,
                )
            )

    def with_ctx(b: Sequence[Binding]) -> list[Expr | Binding]:
        r: list[Expr | Binding] = []
        r.extend(ctx)
        r.extend(b)
        return r

    body_str = "\n".join(body)
    if vec_loop is not None:
        return f"""
{body_str}
cpu_kernel_vec(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }},
  [=]({', '.join(b.decl() for b in vec_bindings)}) {{ return {vec_loop.call(with_ctx(vec_bindings))}; }}
);
"""
    else:
        return f"""
{body_str}
cpu_kernel(iter,
  [=]({', '.join(b.decl() for b in scalar_bindings)}) {{ return {scalar_loop.call(with_ctx(scalar_bindings))}; }}
);
"""


@with_native_function
def compute_ufunc_cpu_kernel(g: NativeFunctionsGroup) -> str:
    stub_sig = StubSignature(g)

    # Reindex the ufuncs by dtype, processing generic/scalaronly entries as well
    loops = g.out.ufunc_inner_loop
    ufunc_sigs: dict[ScalarType, dict[UfuncKey, UfuncSignature]] = {}
    for k in [UfuncKey.CPUScalar, UfuncKey.CPUVector]:
        lks = []
        # ORDER MATTERS: this specifies overriding precedence
        if k in loops:  # should happen rarely
            lks.append(k)
        if UfuncKey.ScalarOnly in loops and k is UfuncKey.CPUScalar:
            lks.append(UfuncKey.ScalarOnly)
        if UfuncKey.Generic in loops:
            lks.append(UfuncKey.Generic)
        # TODO: don't hardcode ufunc:: namespace here, should be centralized smh
        for lk in lks:
            for dtype in loops[lk].supported_dtypes:
                compute_t: CType
                if k is UfuncKey.CPUScalar:
                    compute_t = BaseCType(scalar_t)
                elif k is UfuncKey.CPUVector:
                    compute_t = VectorizedCType(BaseCType(scalar_t))
                else:
                    raise AssertionError
                inner_ufunc_sigs = ufunc_sigs.setdefault(dtype, {})
                if k not in inner_ufunc_sigs:
                    inner_ufunc_sigs[k] = UfuncSignature(
                        g, name=f"ufunc::{loops[lk].name}", compute_t=compute_t
                    )
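
    # First writer wins: if both an explicit CPUScalar/CPUVector entry and a
    # ScalarOnly/Generic entry cover the same dtype, the explicit entry takes
    # precedence, because it was appended to lks first.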

    # Build the conditionals
    dtype_cases = []
    for dtype, inner_ufunc_sigs in ufunc_sigs.items():
        dtype_cases.append(
            f"""
AT_DISPATCH_CASE(at::ScalarType::{dtype},
  [&]() {{
    {compute_ufunc_cpu_dtype_body(g, dtype, inner_ufunc_sigs, stub_sig.arguments())}
  }}
)
"""
        )

    dtype_cases_str = "\n".join(dtype_cases)
    return f"""
namespace {{

{stub_sig.kernel_defn()} {{
  AT_DISPATCH_SWITCH(iter.common_dtype(), "{stub_sig.name}",
    {dtype_cases_str}
  );
}}

}} // anonymous namespace

{stub_sig.type_defn()};
{stub_sig.dispatch_decl()};
REGISTER_DISPATCH({stub_sig.name}, &{stub_sig.kernel_name});
"""