# mypy: allow-untyped-defs
import inspect

from torch._custom_op.impl import (
    _custom_op_with_schema,
    _find_custom_op,
    infer_schema,
    parse_qualname,
    validate_namespace,
)
from torch.library import get_ctx


__all__ = [
    "custom_op",
    "impl",
    "impl_abstract",
    "get_ctx",
    "impl_save_for_backward",
    "impl_backward",
]


def custom_op(qualname, func_or_schema=None):
25     r"""Register a new custom operator
26 
27     In PyTorch, defining an op (short for "operator") is a two step-process:
28     - we need to define the op (by providing an operator name and schema)
29     - we need to implement behavior for how the operator interacts with
30       various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.
31 
32     This entrypoint defines the custom operator (the first step)
33     you must then perform the second step by calling various
34     ``impl_*`` APIs.
35 
36     This API may be used as a decorator (see examples).
37 
38     For a detailed guide on custom ops, please see
39     https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
40 
41     Arguments:
42         qualname (str): Should be a string that looks like
43             "namespace::operator_name". Operators in PyTorch need a namespace to
44             avoid name collisions; a given operator may only be created once.
45             If you are writing a Python library, we recommend the namespace to
46             be the name of your top-level module.
47         func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
48             schema that tells PyTorch the types of the inputs/outputs.
49             If this is a Callable, we will automatically infer the schema from
50             the type annotations on the function (see examples). Otherwise,
51             if you don't want to use type annotations, you may provide us the
52             schema string.
53 
    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_sin
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
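        >>>
        >>> # Alternatively (a minimal sketch; "mylibrary::numpy_tanh" is a
        >>> # hypothetical name), you may pass a schema string instead of a
        >>> # prototype function:
        >>> torch._custom_ops.custom_op(
        >>>     "mylibrary::numpy_tanh", "(Tensor x) -> Tensor")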

    """
    ns, name = parse_qualname(qualname)
    validate_namespace(ns)

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func, mutates_args=())
        _custom_op_with_schema(qualname, schema)
        return func

    if func_or_schema is None:
        return inner
    if isinstance(func_or_schema, str):
        _custom_op_with_schema(qualname, func_or_schema)
    else:
        return inner(func_or_schema)


def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
121     r"""Register an implementation for a device type for this custom op.
122 
123     If the op is passed multiple Tensor inputs with different device
124     types, it will dispatch to the registered implementation for the highest
125     priority device type among those present.
126     The supported device types, in order of priority, are {'cuda', 'cpu'}.
127 
128     This API may be used as a decorator (see examples).
129 
130     For a detailed guide on custom ops, please see
131     https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
132 
133     Arguments:
134         device_types (str or Iterable[str]): the device type(s) to register the function for.
135 
    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
        >>> def numpy_cos(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_cos
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
        >>> def numpy_cos_impl_cpu(x):
        >>>     return torch.from_numpy(np.cos(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
        >>> def numpy_cos_impl_cuda(x):
        >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl(device_types, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_abstract(qualname, *, func=None):
184     r"""Register an abstract implementation for this operator.
185 
186     An "abstract implementation" specifies the behavior of this operator on
187     Tensors that carry no data. Given some input Tensors with certain properties
188     (sizes/strides/storage_offset/device), it specifies what the properties of
189     the output Tensors are.
190 
191     The abstract implementation has the same signature as the operator.
192     It is run for both FakeTensors and meta tensors. To write an abstract
193     implementation, assume that all Tensor inputs to the operator are
194     regular CPU/CUDA/Meta tensors, but they do not have storage, and
195     you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
196     The abstract implementation must consist of only PyTorch operations
197     (and may not directly access the storage or data of any input or
198     intermediate Tensors).
199 
200     This API may be used as a decorator (see examples).
201 
202     For a detailed guide on custom ops, please see
203     https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
204 
    Examples::
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_nonzero")
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # The number of nonzero elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_ops.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     # nonzero returns one row per nonzero element, with one
        >>>     # column per input dimension.
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
        >>> def custom_nonzero_impl(x):
        >>>     x_np = x.cpu().numpy()
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)

    """
    import torch.library

    return torch.library.register_fake(qualname, func, _stacklevel=2)


def impl_save_for_backward(qualname, *, func=None):
259     r"""Register a function that tells us what to save for backward.
260 
261     Please see :func:`impl_backward` for more details.
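
    Example (a minimal sketch continuing the ``mylibrary::numpy_sin``
    example above; saving the input ``x`` is our choice of illustration)::
        >>> @torch._custom_ops.impl_save_for_backward("mylibrary::numpy_sin")
        >>> def numpy_sin_save_for_backward(inputs, output):
        >>>     # `inputs` carries one attribute per operator argument;
        >>>     # we save the input Tensor for use in the backward pass.
        >>>     return inputs.x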
262     """
263 
    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_save_for_backward(_stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_backward(qualname, output_differentiability=None, *, func=None):
275     r"""Registers a backward formula for an operator.
276 
277     In order for an operator to work with autograd, you need to register
278     a backward formula. There are two pieces to this:
279     1. You must give us a function to specify what to save for backward.
280        Call this the "save for backward" function.
281     2. You must give us a function that computes gradients. Call this the
282        "backward" function.
283 
284     Use `impl_save_for_backward` to define a "save for backward" function
285     that specifies what gets saved for backward. The function should accept
286     two arguments ``(inputs, output)`` and return the quantities to be saved
287     for backward.
288 
289     During runtime, when you call the operator in a forwards pass, PyTorch
290     will invoke the "save for backward" function with the inputs and output
291     of the operator.
292 
293     Use `impl_backward` to define the "backward" function. The backward
294     function must accept ``(ctx, saved, *grads)``:
295     - ``ctx`` is a context object where we may provide information
296     - ``saved`` is exactly what gets returned from the "save for backward"
297       function
298     - ``grads`` is one or more gradients. The number of gradients matches
299       the number of outputs of the operator.
300 
301     The backward function must return a dict that maps the name of
302     an input to the operator to its corresponding gradient. All inputs that
303     were declared to be Tensors in the operator definition must be accounted
304     for in the dict. The gradient may be a Tensor or None.
305 
306     For a detailed guide on custom ops, please see
307     https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
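
    Example (a minimal sketch pairing with the "save for backward" example
    in :func:`impl_save_for_backward`; the formula is the usual derivative
    of sin)::
        >>> @torch._custom_ops.impl_backward("mylibrary::numpy_sin")
        >>> def numpy_sin_backward(ctx, saved, grad_out):
        >>>     # `saved` is the Tensor returned by the "save for backward"
        >>>     # function. The dict maps each Tensor input's name to its
        >>>     # gradient (a Tensor or None).
        >>>     return {"x": grad_out * saved.cos()}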

    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def _destroy(qualname):
    """De-registers a custom op. For testing purposes only."""
    custom_op = _find_custom_op(qualname)
    custom_op._destroy()