# xref: /aosp_15_r20/external/pytorch/torch/_custom_op/functional.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: allow-untyped-defs
import weakref

import torch
import torch.utils._pytree as pytree
from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet
from torch._ops import OpOverload
from torch.library import Library
from torchgen.model import (
    BaseTy,
    BaseType,
    FunctionSchema,
    OperatorName,
    OptionalType,
    SchemaKind,
)

from .autograd import autograd_not_implemented


def register_functional_op(
    lib: Library,
    new_op_name: str,
    mutable_op: OpOverload,
) -> None:
26    """Given a mutable operator, registers the functional variant.
27
28    This API also correctly links the functional variant with the mutable
29    operator for the purposes of functionalization.
30
31    All of the new registrations are performed on the ``lib`` passed in.
32
33    Arguments:
34        lib (Library): Should be a torch.library.Library object that has
35            the same namespace as ``mutable_op``'s namespace.
36            lib will be used to register the new functional op as well
37            as a functionalization kernel for the ``mutable_op``
38            If you don't have a library handy, use
39            ``torch.library.Library(ns, 'FRAGMENT')`` to construct one.
40        new_op_name (str): The name of the functional operator (without the
41            namespace). If no namespace, the new functional variant will be
42            accessible under ``torch.ops.{lib.ns}.new_op_name``.
43        mutable_op (OpOverload): The mutable custom operator. Note
44            that you may need to add a `.default` to it, like
45            `torch.ops.aten.abs_.default`.
46
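    Example (a minimal sketch; the ``mylib`` namespace and the ops in it are
    hypothetical, not part of PyTorch)::

        lib = torch.library.Library("mylib", "FRAGMENT")
        # A "mutable" op: it writes into ``x`` in-place and returns nothing.
        lib.define("fill_twos(Tensor(a!) x) -> ()")

        def fill_twos_impl(x):
            x.fill_(2)

        lib.impl("fill_twos", fill_twos_impl, "CompositeExplicitAutograd")

        # Registers ``torch.ops.mylib.fill_twos_functional``, which clones
        # ``x``, mutates the clone, and returns the clone instead of
        # mutating its input.
        register_functional_op(
            lib, "fill_twos_functional", torch.ops.mylib.fill_twos.default)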
47    """
48    validate(mutable_op)
49    schema = functional_schema(new_op_name, mutable_op)
50    lib.define(schema)
51
52    functional_impl = construct_functional_impl(mutable_op)
53    lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd')
54
55    functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default
56
    # There's no easy way for us to generate the autograd kernel, so we
    # use autograd_not_implemented. This also prevents the user from
    # registering an autograd formula themselves. That shouldn't be a
    # problem if the user doesn't use the functional op directly in their
    # program, but we may need to revisit this in the future.
    lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd')

    f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op)

    lib.impl(mutable_op, f_kernel, 'Functionalize')


def construct_functional_impl(mutable_op):
    def functional_impl(*args):
        # Strategy:
        # - clone args that would have been mutated
        # - run mutable_op
        # - return the cloned args as additional outputs
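        # For example, for a hypothetical ``foo(Tensor x, Tensor(a!) y) -> ()``,
        # the functional variant behaves like ``foo_functional(x, y) -> y_new``:
        # ``y`` is cloned, ``mutable_op`` runs on the clone, and the clone is
        # returned as the output.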
        new_args = []
        extra_rets = []
        for is_write, arg in zip(mutable_args(mutable_op), args):
            if is_write:
                cloned = arg.clone() if arg is not None else None
                new_args.append(cloned)
                extra_rets.append(cloned)
            else:
                new_args.append(arg)
        result = mutable_op(*new_args)
        if result is None:
            return tuple(extra_rets)
        if isinstance(result, tuple):
            return (*result, *extra_rets)
        return (result, *extra_rets)
    return functional_impl


def construct_functionalization_kernel(mutable_op, functional_op):
    def kernel(*args):
        # There's nothing to be functionalized!
        # We can still end up here because DispatchKey::Functionalize is a mode key
        if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args):
            with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
                return mutable_op(*args)

        # NB: This differs from the codegen -- codegen handles cases where
        # FunctionalTensorWrapper and non-FunctionalTensorWrapper tensors are mixed.
        # That only really matters for XLA (mixed CPU-XLA tensors) and for running
        # functionalization without the PT2 stack (the PT2 stack guarantees that
        # all tensors are FunctionalTensorWrapper).
        if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args):
            raise RuntimeError(
                f"{mutable_op}: expected all args to be FunctionalTensorWrapper")

        unwrapped_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg):
                torch._sync(arg)
                unwrapped = torch._from_functional_tensor(arg)
                unwrapped_args.append(unwrapped)
            else:
                unwrapped_args.append(arg)

        with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
            output = functional_op(*unwrapped_args)

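        # ``functional_op`` returns the mutable op's original outputs first,
        # followed by one updated value per mutated input (in argument order).
        # E.g. for a hypothetical ``foo(Tensor(a!) x, Tensor y) -> Tensor``,
        # ``output`` is ``(real_out, x_updated)`` and ``num_actual_output`` is 1.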
        num_actual_output = len(mutable_op._schema.returns)
        actual_output = pytree.tree_map(
            torch._to_functional_tensor, output[:num_actual_output])

        new_values_to_propagate = output[num_actual_output:]
        inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args)
                             if is_write]
        assert len(new_values_to_propagate) == len(inputs_to_replace)
        for new_value, arg in zip(new_values_to_propagate, inputs_to_replace):
            # Optional tensor args that were not passed (both the input and its
            # updated value are None) have nothing to propagate back.
            if arg is None and new_value is None:
                continue
            torch._C._propagate_xla_data(arg, new_value)
            torch._C._replace_(arg, new_value)
            torch._C._commit_update(arg)
            torch._sync(arg)

        if len(actual_output) == 1:
            return actual_output[0]
        elif len(actual_output) == 0:
            return None
        return actual_output

    return kernel


def validate(mutable_op: OpOverload):
    if not isinstance(mutable_op, OpOverload):
        raise TypeError(
            f"register_functional_op(mutable_op): expected mutable_op to be an "
            f"instance of OpOverload but got {type(mutable_op)}")

    # There are generally three types of "in-place" or "mutable" ops.
    # Each has its own convention:
    # - inplace (first input modified in-place and returned as the only output)
    # - out= (some args modified in-place and returned as outputs)
    # - mutable (some args modified in-place but none of those returned as outputs)
    # In theory we can support all three, but we just support the last
    # option right now for simplicity. For example:
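    #   inplace: abs_(Tensor(a!) self) -> Tensor(a!)
    #   out=:    abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
    #   mutable: foo(Tensor(a!) x, Tensor y) -> ()   (``foo`` is a made-up op)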
    schema = FunctionSchema.parse(str(mutable_op._schema))
    if schema.kind() != SchemaKind.mutable:
        raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)")
    for ret in schema.returns:
        # construct_functionalization_kernel assumes this for simplicity
        if ret.annotation is not None:
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op returns a mutated or aliased value. "
                "Please file an issue (and as a workaround, modify your operator to "
                "not return the mutated value or aliases)")
    for arg in schema.arguments.flat_all:
        # construct_functionalization_kernel assumes this for simplicity
        if arg.type.is_tensor_like() and (
            arg.type != BaseType(BaseTy.Tensor)
            and arg.type != OptionalType(BaseType(BaseTy.Tensor))
        ):
            raise NotImplementedError(
                "NYI: register_functional_op(op) where op has a List[Tensor] input. "
                "Please file an issue.")


def functional_schema(new_op_name, op: OpOverload):
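    # E.g. for a hypothetical mutable schema ``foo(Tensor(a!) x, Tensor y) -> ()``
    # and new_op_name="foo_functional", this returns
    # ``foo_functional(Tensor x, Tensor y) -> Tensor``: the alias annotation is
    # dropped and the mutated input becomes an extra return.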
    schema = FunctionSchema.parse(str(op._schema))
    schema = schema.signature().with_name(OperatorName.parse(new_op_name))
    return str(schema)


def mutable_args(op: OpOverload):
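    # Returns one bool per argument: True iff that argument is mutated.
    # E.g. for a hypothetical ``foo(Tensor(a!) x, Tensor y) -> ()`` this
    # returns ``(True, False)``.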
    return tuple(False if arg.alias_info is None else arg.alias_info.is_write
                 for arg in op._schema.arguments)