# mypy: allow-untyped-defs
import contextlib
import functools
import inspect
import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated

import torch
import torch._library as _library
from torch._library.custom_ops import (
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
    device_types_t,
)
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._ops import OpOverload


__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "custom_op",
    "infer_schema",
]

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered.
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`.
# This set is maintained to ensure that two libraries don't try to override the exact same functionality,
# which could result in kernels being called that were never intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()

# prim is reserved by TorchScript interpreter
_reserved_namespaces = ["prim"]


def fallthrough_kernel():
    """
    A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
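
    A minimal sketch of typical usage (mirroring the ``Library.fallback`` example
    below, on the global "_" namespace):

    Example::
        >>> my_lib = Library("_", "IMPL")
        >>> # Make operators fall through the Autocast key instead of running a kernel
        >>> my_lib.fallback(fallthrough_kernel, "Autocast")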
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")


class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL", "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")
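
    Example::
        >>> # A minimal sketch: create a fragment of a (hypothetical) "mylib" namespace
        >>> # and register a kernel that simply calls into torch.sin.
        >>> my_lib = Library("mylib", "FRAGMENT")
        >>> my_lib.define("my_sin(Tensor x) -> Tensor")
        >>> my_lib.impl("my_sin", torch.sin, "CompositeExplicitAutograd")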
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError("Unsupported kind: ", kind)

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                ns,
                " is a reserved namespace. Please try creating a library with another name.",
            )

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: Set[str] = set()
        self._op_impls: Set[str] = set()
        self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to set up the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). Finalizers help the
        # situation because they let us capture references and keep them alive.
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                                       inferred from the schema (default behavior) or not ("CONSERVATIVE").
            tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
                                       operator. Tagging an operator changes the operator's behavior
                                       under various PyTorch subsystems; please read the docs for the
                                       torch.Tag carefully before applying it.

        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        assert self.m is not None
        if isinstance(tags, torch.Tag):
            tags = (tags,)

        name = schema.split("(")[0]
        packet_name = name.split(".")[0] if "." in name else name
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        name = schema.split("(")[0]
        qualname = self.ns + "::" + name

        # If the OpOverloadPacket exists already, then this means we're adding a
        # new OpOverload for it. Refresh the packet to include the new OpOverload.
        if has_preexisting_packet:
            ns = getattr(torch.ops, self.ns)
            packet = getattr(ns, packet_name)
            torch._ops._refresh_packet(packet)

        self._op_defs.add(qualname)
        _defs.add(qualname)
        return result

    def _register_fake(self, op_name, fn, _stacklevel=1):
        r"""Registers the fake impl for an operator defined in the library."""
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
        # Can be None if register_fake is called from somewhere there isn't a module
        # (e.g. __main__)
        caller_module_name = None if caller_module is None else caller_module.__name__

        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        if caller_module_name is not None:
            func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
        else:
            func_to_register = fn

        handle = entry.fake_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
        r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

        This allows for open registration to specify the behavior between the operator
        and the torch_dispatch_class without needing to modify the torch_dispatch_class
        or the operator directly.

        The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
        TorchDispatchMode.

        If it is a Tensor subclass, we expect fn to have the following signature:
        (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

        If it is a TorchDispatchMode, we expect fn to have the following signature:
        (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
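
        A hedged sketch (``MyMode`` is an assumed user-defined TorchDispatchMode and
        ``mylib::foo`` an already-defined operator):

        Example::
            >>> def my_rule(mode, func, types, args, kwargs):
            >>>     # Custom behavior for mylib::foo under MyMode; here we just
            >>>     # call the operator as-is.
            >>>     return func(*args, **kwargs)
            >>> my_lib = Library("mylib", "FRAGMENT")
            >>> my_lib._register_torch_dispatch_rule("foo", MyMode, my_rule)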
        """
        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                "as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, name.split("::")[-1], dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "impl should be passed either a name or an OpOverload object as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        assert self.m is not None
        self.m.impl(
            name,
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                          the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                         to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using a library fragment on the global namespace "_" but it is {self.ns}"""
            )

        assert dispatch_key != ""
        assert self.m is not None

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        global _impls
        _impls -= self._op_impls
        for name in self._op_defs:
            # Delete the cached torch.ops.ns.foo if it was registered.
            # Otherwise, accessing it leads to a segfault.
            # It's possible that we only registered an overload in this Library
            # and another library owns an alive overload.
            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
            # recomputed to point at the right collection of overloads.
            ns, name_with_overload = name.split("::")
            name = name_with_overload.split(".")[0]
            if not hasattr(torch.ops, ns):
                continue
            namespace = getattr(torch.ops, ns)
            if not hasattr(namespace, name):
                continue
            delattr(namespace, name)


def _del_library(
    captured_impls,
    op_impls,
    captured_defs,
    op_defs,
    registration_handles,
):
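    # Called from the weakref finalizer installed in Library.__init__: drop this
    # library's keys from the module-level registries and free its handles.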
    captured_impls -= op_impls
    captured_defs -= op_defs
    for handle in registration_handles:
        handle.destroy()


@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
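    """Context-manager form of ``Library`` that destroys the library on exit.

    A minimal usage sketch (the "mylib" namespace and operator are illustrative)::

        >>> with _scoped_library("mylib", "FRAGMENT") as lib:
        >>>     lib.define("my_sin(Tensor x) -> Tensor")
        >>>     lib.impl("my_sin", torch.sin, "CompositeExplicitAutograd")
    """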
    try:
        lib = Library(*args, **kwargs)
        yield lib
    finally:
        lib._destroy()


_keep_alive: List[Library] = []

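# Matches an operator schema that is missing the operator name,
# e.g. "(Tensor x) -> Tensor"; torch.library.define receives the name via ``qualname``.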
NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")


@functools.singledispatch
def define(qualname, schema, *, lib=None, tags=()):
    r"""Defines a new operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
    various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.

    Args:
        qualname (str): The qualified name for the operator. Should be
            a string that looks like "namespace::name", e.g. "aten::sin".
            Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
            for an op that accepts one Tensor and returns one Tensor. It does
            not contain the operator name (that is passed in ``qualname``).
        lib (Optional[Library]): If provided, the lifetime of this operator
            will be tied to the lifetime of the Library object.
        tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
            operator. Tagging an operator changes the operator's behavior
            under various PyTorch subsystems; please read the docs for the
            torch.Tag carefully before applying it.

    Example::
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.sin(x)
        >>> assert torch.allclose(y, x.sin())

    """
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    if not NAMELESS_SCHEMA.fullmatch(schema):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    lib.define(name + schema, alias_analysis="", tags=tags)


@define.register
def _(lib: Library, schema, alias_analysis=""):
    """The old torch.library.define.
    We're keeping this around for BC reasons
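
    A hedged sketch of the legacy decorator usage (operator name and body are
    illustrative):

    Example::
        >>> my_lib = Library("mylib", "DEF")
        >>> @torch.library.define(my_lib, "my_abs(Tensor x) -> Tensor")
        >>> def my_abs(x):
        >>>     return x.abs()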
    """

    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f

    return wrap


@functools.singledispatch
def impl(qualname, types, func=None, *, lib=None):
    """Register an implementation for a device type for this operator.

    You may pass "default" for ``types`` to register this implementation as the
    default implementation for ALL device types.
    Please only use this if the implementation truly supports all device types;
    for example, this is true if it is a composition of built-in PyTorch operators.

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
    """
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)


def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
    if isinstance(types, str):
        types = (types,)
    keys = set()
    for typ in types:
        is_dispatch_key = torch._C._parse_dispatch_key(typ)
        if is_dispatch_key:
            # We also support passing a DispatchKey to impl. Please prefer using
            # the higher-level torch.library APIs and only pass DispatchKey to
            # torch.library.impl with caution (or even better, don't use this
            # option and file an issue on GitHub for what you need).
            # We don't advertise this to users because
            # it is very easy to shoot yourself in the foot.
            keys.add(typ)
        else:
            keys.add(_device_type_to_key(typ))

    def register(func):
        namespace, _ = torch._library.utils.parse_namespace(qualname)

        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            for key in keys:
                use_lib.impl(qualname, func_no_dynamo, key)
        else:
            for key in keys:
                use_lib.impl(qualname, func, key)

    if func is None:
        return register
    else:
        register(func)


def _device_type_to_key(device_type: str) -> str:
    if device_type == "default":
        # This is technically not correct, because although all device_type
        # DispatchKeys are included in CompositeExplicitAutograd,
        # not everything in CompositeExplicitAutograd is associated with a
        # device_type. I don't really care that much about the difference.
        return "CompositeExplicitAutograd"
    return torch._C._dispatch_key_for_device(device_type)


@impl.register
def _(lib: Library, name, dispatch_key=""):
    """Legacy torch.library.impl API. Kept around for BC"""

    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f

    return wrap


@deprecated(
    "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
    "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
    category=FutureWarning,
)
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
    Please use that instead.
    """
    if func is not None:
        _stacklevel = _stacklevel + 1
    return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]


def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        op (str | OpOverload | CustomOpDef): The operator to register an impl for,
            e.g. "mylib::numpy_sin".
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.
        func (Callable): The function to register as the implementation for
            the given device types.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

    """

    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"

    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_fake(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
    _stacklevel: int = 1,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

694*da0073e9SAndroid Build Coastguard Worker    Also sometimes known as a "meta kernel" or "abstract impl".
695*da0073e9SAndroid Build Coastguard Worker
696*da0073e9SAndroid Build Coastguard Worker    A "FakeTensor implementation" specifies the behavior of this operator on
697*da0073e9SAndroid Build Coastguard Worker    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
698*da0073e9SAndroid Build Coastguard Worker    certain properties (sizes/strides/storage_offset/device), it specifies
699*da0073e9SAndroid Build Coastguard Worker    what the properties of the output Tensors are.
700*da0073e9SAndroid Build Coastguard Worker
701*da0073e9SAndroid Build Coastguard Worker    The FakeTensor implementation has the same signature as the operator.
702*da0073e9SAndroid Build Coastguard Worker    It is run for both FakeTensors and meta tensors. To write a FakeTensor
703*da0073e9SAndroid Build Coastguard Worker    implementation, assume that all Tensor inputs to the operator are
704*da0073e9SAndroid Build Coastguard Worker    regular CPU/CUDA/Meta tensors, but they do not have storage, and
705*da0073e9SAndroid Build Coastguard Worker    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
706*da0073e9SAndroid Build Coastguard Worker    The FakeTensor implementation must consist of only PyTorch operations
707*da0073e9SAndroid Build Coastguard Worker    (and may not directly access the storage or data of any input or
708*da0073e9SAndroid Build Coastguard Worker    intermediate Tensors).
709*da0073e9SAndroid Build Coastguard Worker
710*da0073e9SAndroid Build Coastguard Worker    This API may be used as a decorator (see examples).
711*da0073e9SAndroid Build Coastguard Worker
712*da0073e9SAndroid Build Coastguard Worker    For a detailed guide on custom ops, please see
713*da0073e9SAndroid Build Coastguard Worker    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html
714*da0073e9SAndroid Build Coastguard Worker
715*da0073e9SAndroid Build Coastguard Worker    Examples:
716*da0073e9SAndroid Build Coastguard Worker        >>> import torch
717*da0073e9SAndroid Build Coastguard Worker        >>> import numpy as np
718*da0073e9SAndroid Build Coastguard Worker        >>> from torch import Tensor
719*da0073e9SAndroid Build Coastguard Worker        >>>
720*da0073e9SAndroid Build Coastguard Worker        >>> # Example 1: an operator without data-dependent output shape
721*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
722*da0073e9SAndroid Build Coastguard Worker        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
723*da0073e9SAndroid Build Coastguard Worker        >>>     raise NotImplementedError("Implementation goes here")
724*da0073e9SAndroid Build Coastguard Worker        >>>
725*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.register_fake("mylib::custom_linear")
726*da0073e9SAndroid Build Coastguard Worker        >>> def _(x, weight, bias):
727*da0073e9SAndroid Build Coastguard Worker        >>>     assert x.dim() == 2
728*da0073e9SAndroid Build Coastguard Worker        >>>     assert weight.dim() == 2
729*da0073e9SAndroid Build Coastguard Worker        >>>     assert bias.dim() == 1
730*da0073e9SAndroid Build Coastguard Worker        >>>     assert x.shape[1] == weight.shape[1]
731*da0073e9SAndroid Build Coastguard Worker        >>>     assert weight.shape[0] == bias.shape[0]
732*da0073e9SAndroid Build Coastguard Worker        >>>     assert x.device == weight.device
733*da0073e9SAndroid Build Coastguard Worker        >>>
734*da0073e9SAndroid Build Coastguard Worker        >>>     return (x @ weight.t()) + bias
735*da0073e9SAndroid Build Coastguard Worker        >>>
736*da0073e9SAndroid Build Coastguard Worker        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
737*da0073e9SAndroid Build Coastguard Worker        >>>     x = torch.randn(2, 3)
738*da0073e9SAndroid Build Coastguard Worker        >>>     w = torch.randn(3, 3)
739*da0073e9SAndroid Build Coastguard Worker        >>>     b = torch.randn(3)
740*da0073e9SAndroid Build Coastguard Worker        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
741*da0073e9SAndroid Build Coastguard Worker        >>>
742*da0073e9SAndroid Build Coastguard Worker        >>> assert y.shape == (2, 3)
743*da0073e9SAndroid Build Coastguard Worker        >>>
744*da0073e9SAndroid Build Coastguard Worker        >>> # Example 2: an operator with data-dependent output shape
745*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
746*da0073e9SAndroid Build Coastguard Worker        >>> def custom_nonzero(x: Tensor) -> Tensor:
747*da0073e9SAndroid Build Coastguard Worker        >>>     x_np = x.numpy(force=True)
748*da0073e9SAndroid Build Coastguard Worker        >>>     res = np.stack(np.nonzero(x_np), axis=1)
749*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.tensor(res, device=x.device)
750*da0073e9SAndroid Build Coastguard Worker        >>>
751*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.register_fake("mylib::custom_nonzero")
752*da0073e9SAndroid Build Coastguard Worker        >>> def _(x):
753*da0073e9SAndroid Build Coastguard Worker        >>>     # Number of nonzero-elements is data-dependent.
754*da0073e9SAndroid Build Coastguard Worker        >>>     # Since we cannot peek at the data in a fake impl,
755*da0073e9SAndroid Build Coastguard Worker        >>>     # we use the ctx object to construct a new symint that
756*da0073e9SAndroid Build Coastguard Worker        >>>     # represents the data-dependent size.
757*da0073e9SAndroid Build Coastguard Worker        >>>     ctx = torch.library.get_ctx()
758*da0073e9SAndroid Build Coastguard Worker        >>>     nnz = ctx.new_dynamic_size()
759*da0073e9SAndroid Build Coastguard Worker        >>>     shape = [nnz, x.dim()]
760*da0073e9SAndroid Build Coastguard Worker        >>>     result = x.new_empty(shape, dtype=torch.int64)
761*da0073e9SAndroid Build Coastguard Worker        >>>     return result
762*da0073e9SAndroid Build Coastguard Worker        >>>
763*da0073e9SAndroid Build Coastguard Worker        >>> from torch.fx.experimental.proxy_tensor import make_fx
764*da0073e9SAndroid Build Coastguard Worker        >>>
765*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
766*da0073e9SAndroid Build Coastguard Worker        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
767*da0073e9SAndroid Build Coastguard Worker        >>> trace.print_readable()
768*da0073e9SAndroid Build Coastguard Worker        >>>
769*da0073e9SAndroid Build Coastguard Worker        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))
770*da0073e9SAndroid Build Coastguard Worker
771*da0073e9SAndroid Build Coastguard Worker    """
772*da0073e9SAndroid Build Coastguard Worker    if not isinstance(
773*da0073e9SAndroid Build Coastguard Worker        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
774*da0073e9SAndroid Build Coastguard Worker    ):
775*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"register_fake(op): got unexpected type for op: {type(op)}")
776*da0073e9SAndroid Build Coastguard Worker    if isinstance(op, torch._ops.OpOverload):
777*da0073e9SAndroid Build Coastguard Worker        op = op._name
778*da0073e9SAndroid Build Coastguard Worker    opdef = _maybe_get_opdef(op)
779*da0073e9SAndroid Build Coastguard Worker    if opdef is not None:
780*da0073e9SAndroid Build Coastguard Worker        if func is None:
781*da0073e9SAndroid Build Coastguard Worker            return opdef.register_fake
782*da0073e9SAndroid Build Coastguard Worker        else:
783*da0073e9SAndroid Build Coastguard Worker            return opdef.register_fake(func)
784*da0073e9SAndroid Build Coastguard Worker    assert isinstance(op, str)
785*da0073e9SAndroid Build Coastguard Worker
786*da0073e9SAndroid Build Coastguard Worker    stacklevel = _stacklevel
787*da0073e9SAndroid Build Coastguard Worker
788*da0073e9SAndroid Build Coastguard Worker    def register(func):
789*da0073e9SAndroid Build Coastguard Worker        namespace, op_name = torch._library.utils.parse_namespace(op)
790*da0073e9SAndroid Build Coastguard Worker        if lib is None:
791*da0073e9SAndroid Build Coastguard Worker            use_lib = Library(namespace, "FRAGMENT")
792*da0073e9SAndroid Build Coastguard Worker            _keep_alive.append(use_lib)
793*da0073e9SAndroid Build Coastguard Worker        else:
794*da0073e9SAndroid Build Coastguard Worker            use_lib = lib
795*da0073e9SAndroid Build Coastguard Worker        use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
796*da0073e9SAndroid Build Coastguard Worker        return func
797*da0073e9SAndroid Build Coastguard Worker
798*da0073e9SAndroid Build Coastguard Worker    if func is None:
799*da0073e9SAndroid Build Coastguard Worker        return register
800*da0073e9SAndroid Build Coastguard Worker    else:
801*da0073e9SAndroid Build Coastguard Worker        stacklevel += 1
802*da0073e9SAndroid Build Coastguard Worker        return register(func)
803*da0073e9SAndroid Build Coastguard Worker
804*da0073e9SAndroid Build Coastguard Worker
805*da0073e9SAndroid Build Coastguard Workerdef register_autograd(
806*da0073e9SAndroid Build Coastguard Worker    op: _op_identifier,
807*da0073e9SAndroid Build Coastguard Worker    backward: Callable,
808*da0073e9SAndroid Build Coastguard Worker    /,
809*da0073e9SAndroid Build Coastguard Worker    *,
810*da0073e9SAndroid Build Coastguard Worker    setup_context: Optional[Callable] = None,
811*da0073e9SAndroid Build Coastguard Worker    lib=None,
812*da0073e9SAndroid Build Coastguard Worker) -> None:
813*da0073e9SAndroid Build Coastguard Worker    r"""Register a backward formula for this custom op.
814*da0073e9SAndroid Build Coastguard Worker
815*da0073e9SAndroid Build Coastguard Worker    In order for an operator to work with autograd, you need to register
816*da0073e9SAndroid Build Coastguard Worker    a backward formula:
817*da0073e9SAndroid Build Coastguard Worker    1. You must tell us how to compute gradients during the backward pass
818*da0073e9SAndroid Build Coastguard Worker    by providing us a "backward" function.
819*da0073e9SAndroid Build Coastguard Worker    2. If you need any values from the forward to compute gradients, you can
820*da0073e9SAndroid Build Coastguard Worker    use `setup_context` to save values for backward.
821*da0073e9SAndroid Build Coastguard Worker
822*da0073e9SAndroid Build Coastguard Worker    ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
823*da0073e9SAndroid Build Coastguard Worker    - ``grads`` is one or more gradients. The number of gradients matches
824*da0073e9SAndroid Build Coastguard Worker    the number of outputs of the operator.
825*da0073e9SAndroid Build Coastguard Worker    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
826*da0073e9SAndroid Build Coastguard Worker    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
827*da0073e9SAndroid Build Coastguard Worker    same as :meth:`torch.autograd.Function.backward`.
828*da0073e9SAndroid Build Coastguard Worker
829*da0073e9SAndroid Build Coastguard Worker    ``setup_context(ctx, inputs, output)`` runs during the forward pass.
830*da0073e9SAndroid Build Coastguard Worker    Please save quantities needed for backward onto the ``ctx`` object via
831*da0073e9SAndroid Build Coastguard Worker    either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
832*da0073e9SAndroid Build Coastguard Worker    or assigning them as attributes of ``ctx``. If your custom op has
833*da0073e9SAndroid Build Coastguard Worker    kwarg-only arguments, we expect the signature of ``setup_context``
834*da0073e9SAndroid Build Coastguard Worker    to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.
835*da0073e9SAndroid Build Coastguard Worker
836*da0073e9SAndroid Build Coastguard Worker    Both ``setup_context`` and ``backward`` must be traceable. That is,
837*da0073e9SAndroid Build Coastguard Worker    they may not directly access :meth:`torch.Tensor.data_ptr` and they must
838*da0073e9SAndroid Build Coastguard Worker    not depend on or mutate global state. If you need a non-traceable backward,
839*da0073e9SAndroid Build Coastguard Worker    you can make it a separate custom_op that you call inside ``backward``.
840*da0073e9SAndroid Build Coastguard Worker
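    A minimal sketch of that pattern (``mylib::numpy_sin_backward`` is a
    hypothetical helper op; this assumes ``import numpy as np``, ``Tensor``,
    and the ``numpy_sin`` / ``setup_context`` definitions from the example
    below)::

        >>> @torch.library.custom_op("mylib::numpy_sin_backward", mutates_args=())
        >>> def numpy_sin_backward(grad: Tensor, x: Tensor) -> Tensor:
        >>>     # The non-traceable (NumPy) work lives inside its own custom op.
        >>>     cos_x = torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>     return grad * cos_x
        >>>
        >>> def backward(ctx, grad):
        >>>     x, = ctx.saved_tensors
        >>>     # This backward stays traceable: it only calls other operators.
        >>>     return numpy_sin_backward(grad, x)
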
841*da0073e9SAndroid Build Coastguard Worker    Examples:
842*da0073e9SAndroid Build Coastguard Worker        >>> import torch
843*da0073e9SAndroid Build Coastguard Worker        >>> import numpy as np
844*da0073e9SAndroid Build Coastguard Worker        >>> from torch import Tensor
845*da0073e9SAndroid Build Coastguard Worker        >>>
846*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
847*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_sin(x: Tensor) -> Tensor:
848*da0073e9SAndroid Build Coastguard Worker        >>>     x_np = x.cpu().numpy()
849*da0073e9SAndroid Build Coastguard Worker        >>>     y_np = np.sin(x_np)
850*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.from_numpy(y_np).to(device=x.device)
851*da0073e9SAndroid Build Coastguard Worker        >>>
852*da0073e9SAndroid Build Coastguard Worker        >>> def setup_context(ctx, inputs, output) -> None:
853*da0073e9SAndroid Build Coastguard Worker        >>>     x, = inputs
854*da0073e9SAndroid Build Coastguard Worker        >>>     ctx.save_for_backward(x)
855*da0073e9SAndroid Build Coastguard Worker        >>>
856*da0073e9SAndroid Build Coastguard Worker        >>> def backward(ctx, grad):
857*da0073e9SAndroid Build Coastguard Worker        >>>     x, = ctx.saved_tensors
858*da0073e9SAndroid Build Coastguard Worker        >>>     return grad * x.cos()
859*da0073e9SAndroid Build Coastguard Worker        >>>
860*da0073e9SAndroid Build Coastguard Worker        >>> torch.library.register_autograd(
861*da0073e9SAndroid Build Coastguard Worker        ...     "mylib::numpy_sin", backward, setup_context=setup_context
862*da0073e9SAndroid Build Coastguard Worker        ... )
863*da0073e9SAndroid Build Coastguard Worker        >>>
864*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(3, requires_grad=True)
865*da0073e9SAndroid Build Coastguard Worker        >>> y = numpy_sin(x)
866*da0073e9SAndroid Build Coastguard Worker        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
867*da0073e9SAndroid Build Coastguard Worker        >>> assert torch.allclose(grad_x, x.cos())
868*da0073e9SAndroid Build Coastguard Worker        >>>
869*da0073e9SAndroid Build Coastguard Worker        >>> # Example with a keyword-only arg
870*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
871*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
872*da0073e9SAndroid Build Coastguard Worker        >>>     x_np = x.cpu().numpy()
873*da0073e9SAndroid Build Coastguard Worker        >>>     y_np = x_np * val
874*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.from_numpy(y_np).to(device=x.device)
875*da0073e9SAndroid Build Coastguard Worker        >>>
876*da0073e9SAndroid Build Coastguard Worker        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
877*da0073e9SAndroid Build Coastguard Worker        >>>     ctx.val = keyword_only_inputs["val"]
878*da0073e9SAndroid Build Coastguard Worker        >>>
879*da0073e9SAndroid Build Coastguard Worker        >>> def backward(ctx, grad):
880*da0073e9SAndroid Build Coastguard Worker        >>>     return grad * ctx.val
881*da0073e9SAndroid Build Coastguard Worker        >>>
882*da0073e9SAndroid Build Coastguard Worker        >>> torch.library.register_autograd(
883*da0073e9SAndroid Build Coastguard Worker        ...     "mylib::numpy_mul", backward, setup_context=setup_context
884*da0073e9SAndroid Build Coastguard Worker        ... )
885*da0073e9SAndroid Build Coastguard Worker        >>>
886*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(3, requires_grad=True)
887*da0073e9SAndroid Build Coastguard Worker        >>> y = numpy_mul(x, val=3.14)
888*da0073e9SAndroid Build Coastguard Worker        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
889*da0073e9SAndroid Build Coastguard Worker        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))
890*da0073e9SAndroid Build Coastguard Worker
891*da0073e9SAndroid Build Coastguard Worker    """
892*da0073e9SAndroid Build Coastguard Worker    if not isinstance(
893*da0073e9SAndroid Build Coastguard Worker        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
894*da0073e9SAndroid Build Coastguard Worker    ):
895*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
896*da0073e9SAndroid Build Coastguard Worker            f"register_autograd(op): got unexpected type for op: {type(op)}"
897*da0073e9SAndroid Build Coastguard Worker        )
898*da0073e9SAndroid Build Coastguard Worker    if isinstance(op, torch._ops.OpOverload):
899*da0073e9SAndroid Build Coastguard Worker        op = op._name
900*da0073e9SAndroid Build Coastguard Worker    opdef = _maybe_get_opdef(op)
901*da0073e9SAndroid Build Coastguard Worker    if opdef is not None:
902*da0073e9SAndroid Build Coastguard Worker        opdef.register_autograd(backward, setup_context=setup_context)
903*da0073e9SAndroid Build Coastguard Worker        return
904*da0073e9SAndroid Build Coastguard Worker
905*da0073e9SAndroid Build Coastguard Worker    assert isinstance(op, str)
906*da0073e9SAndroid Build Coastguard Worker    qualname = op
907*da0073e9SAndroid Build Coastguard Worker    op = torch._library.utils.lookup_op(qualname)
908*da0073e9SAndroid Build Coastguard Worker    schema = op._schema
909*da0073e9SAndroid Build Coastguard Worker    if not _library.utils.is_functional_schema(schema):
910*da0073e9SAndroid Build Coastguard Worker        raise RuntimeError(
911*da0073e9SAndroid Build Coastguard Worker            f"Cannot register autograd formula for non-functional operator "
912*da0073e9SAndroid Build Coastguard Worker            f"{op} with schema {schema}. Please create "
913*da0073e9SAndroid Build Coastguard Worker            f"a functional operator and register an autograd formula for that."
914*da0073e9SAndroid Build Coastguard Worker        )
915*da0073e9SAndroid Build Coastguard Worker    if _library.utils.has_kwarg_only_tensors(schema):
916*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError(
917*da0073e9SAndroid Build Coastguard Worker            f"register_autograd with kwarg-only Tensor args. In the original "
918*da0073e9SAndroid Build Coastguard Worker            f"definition of the op, please make your tensors not kwarg-only. "
919*da0073e9SAndroid Build Coastguard Worker            f"Got: {schema}"
920*da0073e9SAndroid Build Coastguard Worker        )
921*da0073e9SAndroid Build Coastguard Worker
922*da0073e9SAndroid Build Coastguard Worker    info = _library.autograd.Info(backward, setup_context)
923*da0073e9SAndroid Build Coastguard Worker    autograd_kernel = _library.autograd.make_autograd_impl(op, info)
924*da0073e9SAndroid Build Coastguard Worker    namespace, opname = torch._library.utils.parse_namespace(qualname)
925*da0073e9SAndroid Build Coastguard Worker    if lib is None:
926*da0073e9SAndroid Build Coastguard Worker        lib = Library(namespace, "FRAGMENT")
927*da0073e9SAndroid Build Coastguard Worker        _keep_alive.append(lib)
928*da0073e9SAndroid Build Coastguard Worker    lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True)
929*da0073e9SAndroid Build Coastguard Worker
930*da0073e9SAndroid Build Coastguard Worker
931*da0073e9SAndroid Build Coastguard Workerdef register_torch_dispatch(
932*da0073e9SAndroid Build Coastguard Worker    op: _op_identifier,
933*da0073e9SAndroid Build Coastguard Worker    torch_dispatch_class: Any,
934*da0073e9SAndroid Build Coastguard Worker    func: Optional[Callable] = None,
935*da0073e9SAndroid Build Coastguard Worker    /,
936*da0073e9SAndroid Build Coastguard Worker    *,
937*da0073e9SAndroid Build Coastguard Worker    lib: Optional[Library] = None,
938*da0073e9SAndroid Build Coastguard Worker):
939*da0073e9SAndroid Build Coastguard Worker    r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``.
940*da0073e9SAndroid Build Coastguard Worker
941*da0073e9SAndroid Build Coastguard Worker    This allows for open registration to specify the behavior between the operator
942*da0073e9SAndroid Build Coastguard Worker    and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class``
943*da0073e9SAndroid Build Coastguard Worker    or the operator directly.
944*da0073e9SAndroid Build Coastguard Worker
945*da0073e9SAndroid Build Coastguard Worker    The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a
946*da0073e9SAndroid Build Coastguard Worker    TorchDispatchMode.
947*da0073e9SAndroid Build Coastguard Worker
948*da0073e9SAndroid Build Coastguard Worker    If it is a Tensor subclass, we expect ``func`` to have the following signature:
949*da0073e9SAndroid Build Coastguard Worker    ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
950*da0073e9SAndroid Build Coastguard Worker
951*da0073e9SAndroid Build Coastguard Worker    If it is a TorchDispatchMode, we expect ``func`` to have the following signature:
952*da0073e9SAndroid Build Coastguard Worker    ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any``
953*da0073e9SAndroid Build Coastguard Worker
954*da0073e9SAndroid Build Coastguard Worker    ``args`` and ``kwargs`` will have been normalized the same way they are
955*da0073e9SAndroid Build Coastguard Worker    in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`).
956*da0073e9SAndroid Build Coastguard Worker
957*da0073e9SAndroid Build Coastguard Worker    Examples:
958*da0073e9SAndroid Build Coastguard Worker
959*da0073e9SAndroid Build Coastguard Worker        >>> import torch
960*da0073e9SAndroid Build Coastguard Worker        >>>
961*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::foo", mutates_args={})
962*da0073e9SAndroid Build Coastguard Worker        >>> def foo(x: torch.Tensor) -> torch.Tensor:
963*da0073e9SAndroid Build Coastguard Worker        >>>     return x.clone()
964*da0073e9SAndroid Build Coastguard Worker        >>>
965*da0073e9SAndroid Build Coastguard Worker        >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
966*da0073e9SAndroid Build Coastguard Worker        >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
967*da0073e9SAndroid Build Coastguard Worker        >>>         return func(*args, **kwargs)
968*da0073e9SAndroid Build Coastguard Worker        >>>
969*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
970*da0073e9SAndroid Build Coastguard Worker        >>> def _(mode, func, types, args, kwargs):
971*da0073e9SAndroid Build Coastguard Worker        >>>     x, = args
972*da0073e9SAndroid Build Coastguard Worker        >>>     return x + 1
973*da0073e9SAndroid Build Coastguard Worker        >>>
974*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(3)
975*da0073e9SAndroid Build Coastguard Worker        >>> y = foo(x)
976*da0073e9SAndroid Build Coastguard Worker        >>> assert torch.allclose(y, x)
977*da0073e9SAndroid Build Coastguard Worker        >>>
978*da0073e9SAndroid Build Coastguard Worker        >>> with MyMode():
979*da0073e9SAndroid Build Coastguard Worker        >>>     y = foo(x)
980*da0073e9SAndroid Build Coastguard Worker        >>> assert torch.allclose(y, x + 1)
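        >>>
        >>> # For a Tensor subclass (rather than a TorchDispatchMode), the
        >>> # registered rule instead receives ``(cls, func, types, args, kwargs)``.
        >>> # A minimal sketch of the registration only; ``MyTensor`` here is a
        >>> # hypothetical plain subclass, not a complete wrapper-subclass
        >>> # implementation.
        >>> class MyTensor(torch.Tensor):
        >>>     @classmethod
        >>>     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        >>>         return func(*args, **(kwargs or {}))
        >>>
        >>> @torch.library.register_torch_dispatch("mylib::foo", MyTensor)
        >>> def _(cls, func, types, args, kwargs):
        >>>     x, = args
        >>>     return x + 2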
981*da0073e9SAndroid Build Coastguard Worker
982*da0073e9SAndroid Build Coastguard Worker    """
983*da0073e9SAndroid Build Coastguard Worker    if not isinstance(
984*da0073e9SAndroid Build Coastguard Worker        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
985*da0073e9SAndroid Build Coastguard Worker    ):
986*da0073e9SAndroid Build Coastguard Worker        raise ValueError(
987*da0073e9SAndroid Build Coastguard Worker            f"register_torch_dispatch(op): got unexpected type for op: {type(op)}"
988*da0073e9SAndroid Build Coastguard Worker        )
989*da0073e9SAndroid Build Coastguard Worker    if isinstance(op, torch._ops.OpOverload):
990*da0073e9SAndroid Build Coastguard Worker        op = op._name
991*da0073e9SAndroid Build Coastguard Worker    opdef = _maybe_get_opdef(op)
992*da0073e9SAndroid Build Coastguard Worker    if opdef is not None:
993*da0073e9SAndroid Build Coastguard Worker        return opdef.register_torch_dispatch(torch_dispatch_class, func)
994*da0073e9SAndroid Build Coastguard Worker    assert isinstance(op, str)
995*da0073e9SAndroid Build Coastguard Worker
996*da0073e9SAndroid Build Coastguard Worker    def register(func):
997*da0073e9SAndroid Build Coastguard Worker        namespace, op_name = torch._library.utils.parse_namespace(op)
998*da0073e9SAndroid Build Coastguard Worker        if lib is None:
999*da0073e9SAndroid Build Coastguard Worker            use_lib = Library(namespace, "FRAGMENT")
1000*da0073e9SAndroid Build Coastguard Worker            _keep_alive.append(use_lib)
1001*da0073e9SAndroid Build Coastguard Worker        else:
1002*da0073e9SAndroid Build Coastguard Worker            use_lib = lib
1003*da0073e9SAndroid Build Coastguard Worker        use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
1004*da0073e9SAndroid Build Coastguard Worker        return func
1005*da0073e9SAndroid Build Coastguard Worker
1006*da0073e9SAndroid Build Coastguard Worker    if func is None:
1007*da0073e9SAndroid Build Coastguard Worker        return register
1008*da0073e9SAndroid Build Coastguard Worker    else:
1009*da0073e9SAndroid Build Coastguard Worker        return register(func)
1010*da0073e9SAndroid Build Coastguard Worker
1011*da0073e9SAndroid Build Coastguard Worker
1012*da0073e9SAndroid Build Coastguard Workerdef register_vmap(
1013*da0073e9SAndroid Build Coastguard Worker    op: _op_identifier,
1014*da0073e9SAndroid Build Coastguard Worker    func: Optional[Callable] = None,
1015*da0073e9SAndroid Build Coastguard Worker    /,
1016*da0073e9SAndroid Build Coastguard Worker    *,
1017*da0073e9SAndroid Build Coastguard Worker    lib=None,
1018*da0073e9SAndroid Build Coastguard Worker):
1019*da0073e9SAndroid Build Coastguard Worker    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.
1020*da0073e9SAndroid Build Coastguard Worker
1021*da0073e9SAndroid Build Coastguard Worker    This API may be used as a decorator (see examples).
1022*da0073e9SAndroid Build Coastguard Worker
1023*da0073e9SAndroid Build Coastguard Worker    In order for an operator to work with :func:`torch.vmap`, you may need to register a
1024*da0073e9SAndroid Build Coastguard Worker    vmap implementation in the following signature:
1025*da0073e9SAndroid Build Coastguard Worker
1026*da0073e9SAndroid Build Coastguard Worker        ``vmap_func(info, in_dims: Tuple[Optional[int], ...], *args, **kwargs)``,
1027*da0073e9SAndroid Build Coastguard Worker
1028*da0073e9SAndroid Build Coastguard Worker    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
1029*da0073e9SAndroid Build Coastguard Worker    We do not support kwarg-only Tensor args.
1030*da0073e9SAndroid Build Coastguard Worker
1031*da0073e9SAndroid Build Coastguard Worker    It specifies how we compute the batched version of ``op`` given inputs with an additional
1032*da0073e9SAndroid Build Coastguard Worker    dimension (specified by ``in_dims``).
1033*da0073e9SAndroid Build Coastguard Worker
1034*da0073e9SAndroid Build Coastguard Worker    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
1035*da0073e9SAndroid Build Coastguard Worker    if the arg is not a Tensor or if the arg is not being vmapped over; otherwise, it is an integer
1036*da0073e9SAndroid Build Coastguard Worker    specifying what dimension of the Tensor is being vmapped over.
1037*da0073e9SAndroid Build Coastguard Worker
1038*da0073e9SAndroid Build Coastguard Worker    ``info`` is a collection of additional metadata that may be helpful:
1039*da0073e9SAndroid Build Coastguard Worker    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
1040*da0073e9SAndroid Build Coastguard Worker    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.
1041*da0073e9SAndroid Build Coastguard Worker
1042*da0073e9SAndroid Build Coastguard Worker    The return of ``vmap_func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
1043*da0073e9SAndroid Build Coastguard Worker    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
1044*da0073e9SAndroid Build Coastguard Worker    per output that specifies if the output has the vmapped dimension and what index it is in.
1045*da0073e9SAndroid Build Coastguard Worker
1046*da0073e9SAndroid Build Coastguard Worker    Examples:
1047*da0073e9SAndroid Build Coastguard Worker        >>> import torch
1048*da0073e9SAndroid Build Coastguard Worker        >>> import numpy as np
1049*da0073e9SAndroid Build Coastguard Worker        >>> from torch import Tensor
1050*da0073e9SAndroid Build Coastguard Worker        >>> from typing import Tuple
1051*da0073e9SAndroid Build Coastguard Worker        >>>
1052*da0073e9SAndroid Build Coastguard Worker        >>> def to_numpy(tensor):
1053*da0073e9SAndroid Build Coastguard Worker        >>>     return tensor.cpu().numpy()
1054*da0073e9SAndroid Build Coastguard Worker        >>>
1055*da0073e9SAndroid Build Coastguard Worker        >>> lib = torch.library.Library("mylib", "FRAGMENT")
1056*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
1057*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
1058*da0073e9SAndroid Build Coastguard Worker        >>>     x_np = to_numpy(x)
1059*da0073e9SAndroid Build Coastguard Worker        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
1060*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
1061*da0073e9SAndroid Build Coastguard Worker        >>>
1062*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_cube_vmap(info, in_dims, x):
1063*da0073e9SAndroid Build Coastguard Worker        >>>     result = numpy_cube(x)
1064*da0073e9SAndroid Build Coastguard Worker        >>>     return result, (in_dims[0], in_dims[0])
1065*da0073e9SAndroid Build Coastguard Worker        >>>
1066*da0073e9SAndroid Build Coastguard Worker        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
1067*da0073e9SAndroid Build Coastguard Worker        >>>
1068*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(3)
1069*da0073e9SAndroid Build Coastguard Worker        >>> torch.vmap(numpy_cube)(x)
1070*da0073e9SAndroid Build Coastguard Worker        >>>
1071*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
1072*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
1073*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
1074*da0073e9SAndroid Build Coastguard Worker        >>>
1075*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.register_vmap("mylib::numpy_mul")
1076*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_mul_vmap(info, in_dims, x, y):
1077*da0073e9SAndroid Build Coastguard Worker        >>>     x_bdim, y_bdim = in_dims
1078*da0073e9SAndroid Build Coastguard Worker        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
1079*da0073e9SAndroid Build Coastguard Worker        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
1080*da0073e9SAndroid Build Coastguard Worker        >>>     result = x * y
1081*da0073e9SAndroid Build Coastguard Worker        >>>     result = result.movedim(-1, 0)
1082*da0073e9SAndroid Build Coastguard Worker        >>>     return result, 0
1083*da0073e9SAndroid Build Coastguard Worker        >>>
1084*da0073e9SAndroid Build Coastguard Worker        >>>
1085*da0073e9SAndroid Build Coastguard Worker        >>> x = torch.randn(3)
1086*da0073e9SAndroid Build Coastguard Worker        >>> y = torch.randn(3)
1087*da0073e9SAndroid Build Coastguard Worker        >>> torch.vmap(numpy_mul)(x, y)
1088*da0073e9SAndroid Build Coastguard Worker
1089*da0073e9SAndroid Build Coastguard Worker    .. note::
1090*da0073e9SAndroid Build Coastguard Worker        The vmap function should aim to preserve the semantics of the entire custom operator.
1091*da0073e9SAndroid Build Coastguard Worker        That is, ``grad(vmap(op))`` should be replaceable with ``grad(map(op))``.
1092*da0073e9SAndroid Build Coastguard Worker
1093*da0073e9SAndroid Build Coastguard Worker        If your custom operator has any custom behavior in the backward pass, please
1094*da0073e9SAndroid Build Coastguard Worker        keep this in mind.
1095*da0073e9SAndroid Build Coastguard Worker
1096*da0073e9SAndroid Build Coastguard Worker    """
1097*da0073e9SAndroid Build Coastguard Worker    if not isinstance(
1098*da0073e9SAndroid Build Coastguard Worker        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
1099*da0073e9SAndroid Build Coastguard Worker    ):
1100*da0073e9SAndroid Build Coastguard Worker        raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
1101*da0073e9SAndroid Build Coastguard Worker    if isinstance(op, torch._ops.OpOverload):
1102*da0073e9SAndroid Build Coastguard Worker        op = op._name
1103*da0073e9SAndroid Build Coastguard Worker    opdef = _maybe_get_opdef(op)
1104*da0073e9SAndroid Build Coastguard Worker    if opdef is not None:
1105*da0073e9SAndroid Build Coastguard Worker        return opdef.register_vmap(func)
1106*da0073e9SAndroid Build Coastguard Worker    assert isinstance(op, str)
1107*da0073e9SAndroid Build Coastguard Worker    qualname = op
1108*da0073e9SAndroid Build Coastguard Worker    op = torch._library.utils.lookup_op(qualname)
1109*da0073e9SAndroid Build Coastguard Worker    schema = op._schema
1110*da0073e9SAndroid Build Coastguard Worker    if _library.utils.has_kwarg_only_tensors(schema):
1111*da0073e9SAndroid Build Coastguard Worker        raise NotImplementedError(
1112*da0073e9SAndroid Build Coastguard Worker            f"register_vmap with kwarg-only Tensor args. In the original "
1113*da0073e9SAndroid Build Coastguard Worker            f"definition of the op, please make your tensors not kwarg-only. "
1114*da0073e9SAndroid Build Coastguard Worker            f"Got: {schema}"
1115*da0073e9SAndroid Build Coastguard Worker        )
1116*da0073e9SAndroid Build Coastguard Worker
1117*da0073e9SAndroid Build Coastguard Worker    def register(func):
1118*da0073e9SAndroid Build Coastguard Worker        nonlocal op, lib
1119*da0073e9SAndroid Build Coastguard Worker
1120*da0073e9SAndroid Build Coastguard Worker        namespace, opname = torch._library.utils.parse_namespace(qualname)
1121*da0073e9SAndroid Build Coastguard Worker        if lib is None:
1122*da0073e9SAndroid Build Coastguard Worker            lib = Library(namespace, "FRAGMENT")
1123*da0073e9SAndroid Build Coastguard Worker            _keep_alive.append(lib)
1124*da0073e9SAndroid Build Coastguard Worker
1125*da0073e9SAndroid Build Coastguard Worker        from torch._functorch.autograd_function import custom_function_call_vmap_helper
1126*da0073e9SAndroid Build Coastguard Worker        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter
1127*da0073e9SAndroid Build Coastguard Worker
1128*da0073e9SAndroid Build Coastguard Worker        def wrapped_func(keyset, *args, **kwargs):
1129*da0073e9SAndroid Build Coastguard Worker            interpreter = retrieve_current_functorch_interpreter()
1130*da0073e9SAndroid Build Coastguard Worker            return custom_function_call_vmap_helper(
1131*da0073e9SAndroid Build Coastguard Worker                interpreter, func, op, *args, **kwargs
1132*da0073e9SAndroid Build Coastguard Worker            )
1133*da0073e9SAndroid Build Coastguard Worker
1134*da0073e9SAndroid Build Coastguard Worker        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)
1135*da0073e9SAndroid Build Coastguard Worker
1136*da0073e9SAndroid Build Coastguard Worker    if func is None:
1137*da0073e9SAndroid Build Coastguard Worker        return register
1138*da0073e9SAndroid Build Coastguard Worker    else:
1139*da0073e9SAndroid Build Coastguard Worker        return register(func)
1140*da0073e9SAndroid Build Coastguard Worker
1141*da0073e9SAndroid Build Coastguard Worker
1142*da0073e9SAndroid Build Coastguard Worker# If the op was defined in C++, then we want to make sure there was an
1143*da0073e9SAndroid Build Coastguard Worker# m.set_python_module(module, ...) call and that the module is the
1144*da0073e9SAndroid Build Coastguard Worker# same as the module that called torch.library.register_fake.
1145*da0073e9SAndroid Build Coastguard Workerdef _check_pystubs_once(func, qualname, actual_module_name):
1146*da0073e9SAndroid Build Coastguard Worker    checked = False
1147*da0073e9SAndroid Build Coastguard Worker
1148*da0073e9SAndroid Build Coastguard Worker    def inner(*args, **kwargs):
1149*da0073e9SAndroid Build Coastguard Worker        nonlocal checked
1150*da0073e9SAndroid Build Coastguard Worker        if checked:
1151*da0073e9SAndroid Build Coastguard Worker            return func(*args, **kwargs)
1152*da0073e9SAndroid Build Coastguard Worker
1153*da0073e9SAndroid Build Coastguard Worker        op = torch._library.utils.lookup_op(qualname)
1154*da0073e9SAndroid Build Coastguard Worker        if op._defined_in_python:
1155*da0073e9SAndroid Build Coastguard Worker            checked = True
1156*da0073e9SAndroid Build Coastguard Worker            return func(*args, **kwargs)
1157*da0073e9SAndroid Build Coastguard Worker
1158*da0073e9SAndroid Build Coastguard Worker        maybe_pystub = torch._C._dispatch_pystub(
1159*da0073e9SAndroid Build Coastguard Worker            op._schema.name, op._schema.overload_name
1160*da0073e9SAndroid Build Coastguard Worker        )
1161*da0073e9SAndroid Build Coastguard Worker        if maybe_pystub is None:
1162*da0073e9SAndroid Build Coastguard Worker            if torch._library.utils.requires_set_python_module():
1163*da0073e9SAndroid Build Coastguard Worker                namespace = op.namespace
1164*da0073e9SAndroid Build Coastguard Worker                cpp_filename = op._handle.debug()
1165*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1166*da0073e9SAndroid Build Coastguard Worker                    f"Operator '{qualname}' was defined in C++ and has a Python "
1167*da0073e9SAndroid Build Coastguard Worker                    f"fake impl. In this situation, we require there to also be a "
1168*da0073e9SAndroid Build Coastguard Worker                    f'companion C++ `m.set_python_module("{actual_module_name}")` '
1169*da0073e9SAndroid Build Coastguard Worker                    f"call, but we could not find one. Please add that to "
1170*da0073e9SAndroid Build Coastguard Worker                    f"the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
1171*da0073e9SAndroid Build Coastguard Worker                    f"operator was registered in ({cpp_filename})"
1172*da0073e9SAndroid Build Coastguard Worker                )
1173*da0073e9SAndroid Build Coastguard Worker        else:
1174*da0073e9SAndroid Build Coastguard Worker            pystub_module = maybe_pystub[0]
1175*da0073e9SAndroid Build Coastguard Worker            if actual_module_name != pystub_module:
1176*da0073e9SAndroid Build Coastguard Worker                cpp_filename = op._handle.debug()
1177*da0073e9SAndroid Build Coastguard Worker                raise RuntimeError(
1178*da0073e9SAndroid Build Coastguard Worker                    f"Operator '{qualname}' specified that its python fake impl "
1179*da0073e9SAndroid Build Coastguard Worker                    f"is in the Python module '{pystub_module}' but it was actually found "
1180*da0073e9SAndroid Build Coastguard Worker                    f"in '{actual_module_name}'. Please either move the fake impl "
1181*da0073e9SAndroid Build Coastguard Worker                    f"or correct the m.set_python_module call ({cpp_filename})"
1182*da0073e9SAndroid Build Coastguard Worker                )
1183*da0073e9SAndroid Build Coastguard Worker        checked = True
1184*da0073e9SAndroid Build Coastguard Worker        return func(*args, **kwargs)
1185*da0073e9SAndroid Build Coastguard Worker
1186*da0073e9SAndroid Build Coastguard Worker    return inner
1187*da0073e9SAndroid Build Coastguard Worker
1188*da0073e9SAndroid Build Coastguard Worker
1189*da0073e9SAndroid Build Coastguard Worker# NOTE [ctx inside the fake implementation]
1190*da0073e9SAndroid Build Coastguard Worker# If a user has an operator with data-dependent output shape, then when writing
1191*da0073e9SAndroid Build Coastguard Worker# a fake implementation they must query the current ctx and use methods on the
1192*da0073e9SAndroid Build Coastguard Worker# ctx to construct a new unbacked symint.
1193*da0073e9SAndroid Build Coastguard Worker#
1194*da0073e9SAndroid Build Coastguard Worker# This is done via us setting the global_ctx_getter function every time a fake
1195*da0073e9SAndroid Build Coastguard Worker# implementation is invoked.
1196*da0073e9SAndroid Build Coastguard Workerdef get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
1197*da0073e9SAndroid Build Coastguard Worker    """get_ctx() returns the current FakeImplCtx object.
1198*da0073e9SAndroid Build Coastguard Worker
1199*da0073e9SAndroid Build Coastguard Worker    Calling ``get_ctx()`` is only valid inside of a fake impl
1200*da0073e9SAndroid Build Coastguard Worker    (see :func:`torch.library.register_fake` for more usage details).
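
    Example (a minimal sketch of a fake impl for a hypothetical
    ``mylib::custom_nonzero`` operator with data-dependent output shape)::

        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # get_ctx() is only callable while the fake impl is running.
        >>>     ctx = torch.library.get_ctx()
        >>>     # Allocate an unbacked symint for the data-dependent count.
        >>>     nnz = ctx.new_dynamic_size()
        >>>     return x.new_empty([nnz, x.dim()], dtype=torch.int64)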
1201*da0073e9SAndroid Build Coastguard Worker    """
1202*da0073e9SAndroid Build Coastguard Worker    return torch._library.fake_impl.global_ctx_getter()
1203*da0073e9SAndroid Build Coastguard Worker
1204*da0073e9SAndroid Build Coastguard Worker
1205*da0073e9SAndroid Build Coastguard Worker_OPCHECK_DEFAULT_UTILS = (
1206*da0073e9SAndroid Build Coastguard Worker    "test_schema",
1207*da0073e9SAndroid Build Coastguard Worker    "test_autograd_registration",
1208*da0073e9SAndroid Build Coastguard Worker    "test_faketensor",
1209*da0073e9SAndroid Build Coastguard Worker    "test_aot_dispatch_dynamic",
1210*da0073e9SAndroid Build Coastguard Worker)
1211*da0073e9SAndroid Build Coastguard Worker
1212*da0073e9SAndroid Build Coastguard Worker
1213*da0073e9SAndroid Build Coastguard Workerdef opcheck(
1214*da0073e9SAndroid Build Coastguard Worker    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
1215*da0073e9SAndroid Build Coastguard Worker    args: Tuple[Any, ...],
1216*da0073e9SAndroid Build Coastguard Worker    kwargs: Optional[Dict[str, Any]] = None,
1217*da0073e9SAndroid Build Coastguard Worker    *,
1218*da0073e9SAndroid Build Coastguard Worker    test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
1219*da0073e9SAndroid Build Coastguard Worker    raise_exception: bool = True,
1220*da0073e9SAndroid Build Coastguard Worker) -> Dict[str, str]:
1221*da0073e9SAndroid Build Coastguard Worker    """Given an operator and some sample arguments, tests if the operator is
1222*da0073e9SAndroid Build Coastguard Worker    registered correctly.
1223*da0073e9SAndroid Build Coastguard Worker
1224*da0073e9SAndroid Build Coastguard Worker    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
1225*da0073e9SAndroid Build Coastguard Worker    custom op, you specify metadata (e.g. mutability info) about the custom op
1226*da0073e9SAndroid Build Coastguard Worker    and these APIs require that the functions you pass them satisfy certain
1227*da0073e9SAndroid Build Coastguard Worker    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
1228*da0073e9SAndroid Build Coastguard Worker    ``opcheck`` tests these metadata and properties.
1229*da0073e9SAndroid Build Coastguard Worker
1230*da0073e9SAndroid Build Coastguard Worker    Concretely, we test the following:
1231*da0073e9SAndroid Build Coastguard Worker
1232*da0073e9SAndroid Build Coastguard Worker    - test_schema: If the schema matches the implementation of
1233*da0073e9SAndroid Build Coastguard Worker      the operator. For example: if the schema specifies a Tensor is mutated,
1234*da0073e9SAndroid Build Coastguard Worker      then we check the implementation mutates the Tensor. If the schema
1235*da0073e9SAndroid Build Coastguard Worker      specifies that we return a new Tensor, then we check that the
1236*da0073e9SAndroid Build Coastguard Worker      implementation returns a new Tensor (instead of an existing one or
1237*da0073e9SAndroid Build Coastguard Worker      a view of an existing one).
1238*da0073e9SAndroid Build Coastguard Worker    - test_autograd_registration: If the operator supports training
1239*da0073e9SAndroid Build Coastguard Worker      (autograd): we check that its autograd formula is registered via
1240*da0073e9SAndroid Build Coastguard Worker      torch.library.register_autograd or a manual registration to one
1241*da0073e9SAndroid Build Coastguard Worker      or more DispatchKey::Autograd keys. Any other DispatchKey-based
1242*da0073e9SAndroid Build Coastguard Worker      registrations may lead to undefined behavior.
1243*da0073e9SAndroid Build Coastguard Worker    - test_faketensor: If the operator has a FakeTensor kernel
1244*da0073e9SAndroid Build Coastguard Worker      (and if it is correct). The FakeTensor kernel is necessary
1245*da0073e9SAndroid Build Coastguard Worker      (but not sufficient) for the operator to work with PyTorch compilation
1246*da0073e9SAndroid Build Coastguard Worker      APIs (torch.compile/export/FX). We check that a FakeTensor kernel
1247*da0073e9SAndroid Build Coastguard Worker      (also sometimes known as a meta kernel) was registered for the
1248*da0073e9SAndroid Build Coastguard Worker      operator and that it is correct. This test takes the result of
1249*da0073e9SAndroid Build Coastguard Worker      running the operator on real tensors and the result of running
1250*da0073e9SAndroid Build Coastguard Worker      the operator on FakeTensors and checks that they have the same
1251*da0073e9SAndroid Build Coastguard Worker      Tensor metadata (sizes/strides/dtype/device/etc).
1252*da0073e9SAndroid Build Coastguard Worker    - test_aot_dispatch_dynamic: If the operator has correct behavior
1253*da0073e9SAndroid Build Coastguard Worker      with PyTorch compilation APIs (torch.compile/export/FX).
1254*da0073e9SAndroid Build Coastguard Worker      This checks that the outputs (and gradients, if applicable) are the
1255*da0073e9SAndroid Build Coastguard Worker      same under eager-mode PyTorch and torch.compile.
1256*da0073e9SAndroid Build Coastguard Worker      This test is a superset of ``test_faketensor`` and is an e2e test;
1257*da0073e9SAndroid Build Coastguard Worker      other things it tests are that the operator supports
1258*da0073e9SAndroid Build Coastguard Worker      functionalization and that the backward pass (if it exists) also
1259*da0073e9SAndroid Build Coastguard Worker      supports FakeTensor and functionalization.
1260*da0073e9SAndroid Build Coastguard Worker
1261*da0073e9SAndroid Build Coastguard Worker    For best results, please call ``opcheck`` multiple times with a
1262*da0073e9SAndroid Build Coastguard Worker    representative set of inputs. If your operator supports
1263*da0073e9SAndroid Build Coastguard Worker    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
1264*da0073e9SAndroid Build Coastguard Worker    if your operator supports multiple devices (e.g. CPU and CUDA), please
1265*da0073e9SAndroid Build Coastguard Worker    use ``opcheck`` with inputs on all supported devices.
1266*da0073e9SAndroid Build Coastguard Worker
1267*da0073e9SAndroid Build Coastguard Worker    Args:
1268*da0073e9SAndroid Build Coastguard Worker        op: The operator. Must either be a function decorated with
1269*da0073e9SAndroid Build Coastguard Worker            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
1270*da0073e9SAndroid Build Coastguard Worker            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
1271*da0073e9SAndroid Build Coastguard Worker        args: The args to the operator
1272*da0073e9SAndroid Build Coastguard Worker        kwargs: The kwargs to the operator
1273*da0073e9SAndroid Build Coastguard Worker        test_utils: Tests that we should run. Default: all of them.
1274*da0073e9SAndroid Build Coastguard Worker            Example: ("test_schema", "test_faketensor")
1275*da0073e9SAndroid Build Coastguard Worker        raise_exception: If we should raise an exception on the first
1276*da0073e9SAndroid Build Coastguard Worker            error. If False, we will return a dict with information
1277*da0073e9SAndroid Build Coastguard Worker            on if each test passed or not.
1278*da0073e9SAndroid Build Coastguard Worker
1279*da0073e9SAndroid Build Coastguard Worker    .. warning::
1280*da0073e9SAndroid Build Coastguard Worker
1281*da0073e9SAndroid Build Coastguard Worker        opcheck and :func:`torch.autograd.gradcheck` test different things;
1282*da0073e9SAndroid Build Coastguard Worker        opcheck tests if your usage of torch.library APIs is correct while
1283*da0073e9SAndroid Build Coastguard Worker        :func:`torch.autograd.gradcheck` tests if your autograd formula is
1284*da0073e9SAndroid Build Coastguard Worker        mathematically correct. Use both to test custom ops that support
1285*da0073e9SAndroid Build Coastguard Worker        gradient computation.
1286*da0073e9SAndroid Build Coastguard Worker
1287*da0073e9SAndroid Build Coastguard Worker    Example:
1288*da0073e9SAndroid Build Coastguard Worker
1289*da0073e9SAndroid Build Coastguard Worker        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
1290*da0073e9SAndroid Build Coastguard Worker        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
1291*da0073e9SAndroid Build Coastguard Worker        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
1292*da0073e9SAndroid Build Coastguard Worker        >>>     x_np = x.numpy(force=True)
1293*da0073e9SAndroid Build Coastguard Worker        >>>     z_np = x_np * y
1294*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.from_numpy(z_np).to(x.device)
1295*da0073e9SAndroid Build Coastguard Worker        >>>
1296*da0073e9SAndroid Build Coastguard Worker        >>> @numpy_mul.register_fake
1297*da0073e9SAndroid Build Coastguard Worker        >>> def _(x, y):
1298*da0073e9SAndroid Build Coastguard Worker        >>>     return torch.empty_like(x)
1299*da0073e9SAndroid Build Coastguard Worker        >>>
1300*da0073e9SAndroid Build Coastguard Worker        >>> def setup_context(ctx, inputs, output):
1301*da0073e9SAndroid Build Coastguard Worker        >>>     x, y = inputs
1302*da0073e9SAndroid Build Coastguard Worker        >>>     ctx.y = y
1303*da0073e9SAndroid Build Coastguard Worker        >>>
1304*da0073e9SAndroid Build Coastguard Worker        >>> def backward(ctx, grad):
1305*da0073e9SAndroid Build Coastguard Worker        >>>     return grad * ctx.y, None
1306*da0073e9SAndroid Build Coastguard Worker        >>>
1307*da0073e9SAndroid Build Coastguard Worker        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
1308*da0073e9SAndroid Build Coastguard Worker        >>>
1309*da0073e9SAndroid Build Coastguard Worker        >>> sample_inputs = [
1310*da0073e9SAndroid Build Coastguard Worker        >>>     (torch.randn(3), 3.14),
1311*da0073e9SAndroid Build Coastguard Worker        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
1312*da0073e9SAndroid Build Coastguard Worker        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
1313*da0073e9SAndroid Build Coastguard Worker        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
1314*da0073e9SAndroid Build Coastguard Worker        >>> ]
1315*da0073e9SAndroid Build Coastguard Worker        >>>
1316*da0073e9SAndroid Build Coastguard Worker        >>> for args in sample_inputs:
1317*da0073e9SAndroid Build Coastguard Worker        >>>     torch.library.opcheck(numpy_mul, args)
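        >>>
        >>> # A sketch of other call patterns, reusing the names defined above:
        >>> # run only a subset of the checks and collect results as a dict
        >>> # instead of raising on the first failure.
        >>> torch.library.opcheck(
        ...     numpy_mul,
        ...     sample_inputs[0],
        ...     test_utils=("test_schema", "test_faketensor"),
        ...     raise_exception=False,
        ... )
        >>>
        >>> # opcheck does not check numerical gradient correctness; pair it
        >>> # with torch.autograd.gradcheck (double-precision inputs assumed).
        >>> x = torch.randn(3, dtype=torch.double, requires_grad=True)
        >>> torch.autograd.gradcheck(lambda t: numpy_mul(t, 3.14), (x,))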
1318*da0073e9SAndroid Build Coastguard Worker
1319*da0073e9SAndroid Build Coastguard Worker    """
1320*da0073e9SAndroid Build Coastguard Worker    import torch.testing._internal.optests as optests
1321*da0073e9SAndroid Build Coastguard Worker
1322*da0073e9SAndroid Build Coastguard Worker    return optests.opcheck(
1323*da0073e9SAndroid Build Coastguard Worker        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
1324*da0073e9SAndroid Build Coastguard Worker    )
1325