# mypy: allow-untyped-defs
import inspect

from torch._custom_op.impl import (
    _custom_op_with_schema,
    _find_custom_op,
    infer_schema,
    parse_qualname,
    validate_namespace,
)
from torch.library import get_ctx


__all__ = [
    "custom_op",
    "impl",
    "impl_abstract",
    "get_ctx",
    "impl_save_for_backward",
    "impl_backward",
]


def custom_op(qualname, func_or_schema=None):
    r"""Register a new custom operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs.

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        qualname (str): Should be a string that looks like
            "namespace::operator_name". Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        func_or_schema (Union[Callable, str]): Each PyTorch operator needs a
            schema that tells PyTorch the types of the inputs/outputs.
            If this is a Callable, we will automatically infer the schema from
            the type annotations on the function (see examples). Otherwise,
            if you don't want to use type annotations, you may provide us the
            schema string.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_sin")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_sin
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cpu")
        >>> def numpy_sin_impl_cpu(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_sin", device_types="cuda")
        >>> def numpy_sin_impl_cuda(x):
        >>>     return torch.from_numpy(np.sin(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_sin(x)  # calls numpy_sin_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_sin(x_cuda)  # calls numpy_sin_impl_cuda
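        >>>
        >>> # Alternatively, if you don't want to use type annotations, you may
        >>> # pass the schema string directly instead of a prototype function.
        >>> # A minimal sketch; the op name "mylibrary::numpy_tanh" is
        >>> # hypothetical:
        >>> torch._custom_ops.custom_op(
        >>>     "mylibrary::numpy_tanh", "(Tensor x) -> Tensor"
        >>> )
        >>> torch.ops.mylibrary.numpy_tanh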

    """
    ns, name = parse_qualname(qualname)
    validate_namespace(ns)

    def inner(func):
        if not inspect.isfunction(func):
            raise ValueError(
                f"custom_op(...)(func): Expected `func` to be a Python "
                f"function, got: {type(func)}"
            )

        if func.__name__ != name:
            raise ValueError(
                f"custom_op(qualname='{qualname}', ...)(func): expected `func` "
                f"to have name '{name}' but got '{func.__name__}'. "
                f"Please either change the name of `func` or the qualname that "
                f"is passed to `custom_op`"
            )

        schema = infer_schema(func, mutates_args=())
        _custom_op_with_schema(qualname, schema)
        return func

    if func_or_schema is None:
        return inner
    if isinstance(func_or_schema, str):
        _custom_op_with_schema(qualname, func_or_schema)
    else:
        return inner(func_or_schema)


def impl(qualname, *, device_types=("cpu", "cuda"), func=None):
    r"""Register an implementation for a device type for this custom op.

    If the op is passed multiple Tensor inputs with different device
    types, it will dispatch to the registered implementation for the highest
    priority device type among those present.
    The supported device types, in order of priority, are {'cuda', 'cpu'}.

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Arguments:
        device_types (str or Iterable[str]): the device type(s) to register the function for.

    Example::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Step 1: define the custom op.
        >>> # We need to provide the API a "prototype function"
        >>> # (a function that raises NotImplementedError), from which
        >>> # we will infer the types of the inputs and outputs.
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_cos")
        >>> def numpy_cos(x: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> # The custom op is now accessible via the torch.ops module:
        >>> torch.ops.mylibrary.numpy_cos
        >>>
        >>> # Step 2: Register an implementation for various PyTorch subsystems
        >>>
        >>> # Register an implementation for CPU tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cpu")
        >>> def numpy_cos_impl_cpu(x):
        >>>     return torch.from_numpy(np.cos(x.numpy()))
        >>>
        >>> # Register an implementation for CUDA tensors
        >>> @torch._custom_ops.impl("mylibrary::numpy_cos", device_types="cuda")
        >>> def numpy_cos_impl_cuda(x):
        >>>     return torch.from_numpy(np.cos(x.cpu().numpy())).to(x.device)
        >>>
        >>> x = torch.randn(3)
        >>> torch.ops.mylibrary.numpy_cos(x)  # calls numpy_cos_impl_cpu
        >>>
        >>> x_cuda = x.cuda()
        >>> torch.ops.mylibrary.numpy_cos(x_cuda)  # calls numpy_cos_impl_cuda

    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl(device_types, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_abstract(qualname, *, func=None):
    r"""Register an abstract implementation for this operator.

    An "abstract implementation" specifies the behavior of this operator on
    Tensors that carry no data. Given some input Tensors with certain properties
    (sizes/strides/storage_offset/device), it specifies what the properties of
    the output Tensors are.

    The abstract implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write an abstract
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The abstract implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk

    Examples::
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_linear")
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_linear")
        >>> def custom_linear_abstract(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
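        >>>
        >>> # A hypothetical check of the registration above: since the
        >>> # abstract impl runs for meta tensors (which carry no data),
        >>> # calling the op on meta inputs exercises it directly.
        >>> x = torch.randn(2, 4, device="meta")
        >>> weight = torch.randn(3, 4, device="meta")
        >>> bias = torch.randn(3, device="meta")
        >>> out = torch.ops.mylibrary.custom_linear(x, weight, bias)
        >>> assert out.device.type == "meta"
        >>> assert out.shape == (2, 3)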
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch._custom_ops.custom_op("mylibrary::custom_nonzero")
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     ...
        >>>
        >>> @torch._custom_ops.impl_abstract("mylibrary::custom_nonzero")
        >>> def custom_nonzero_abstract(x):
        >>>     # The number of nonzero elements is data-dependent.
        >>>     # Since we cannot peek at the data in an abstract impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch._custom_ops.get_ctx()
        >>>     nnz = ctx.create_unbacked_symint()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.long)
        >>>     return result
        >>>
        >>> @torch._custom_ops.impl("mylibrary::custom_nonzero")
        >>> def custom_nonzero_impl(x):
        >>>     x_np = x.detach().cpu().numpy()
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     # unbacked symbolic ints in PyTorch must be >= 2, so we
        >>>     # constrain the range to at least 2
        >>>     if res.shape[0] <= 1:
        >>>         raise RuntimeError("not supported")
        >>>     return torch.tensor(res, device=x.device)

    """
    import torch.library

    return torch.library.register_fake(qualname, func, _stacklevel=2)


def impl_save_for_backward(qualname, *, func=None):
    r"""Register a function that tells us what to save for backward.

    Please see :func:`impl_backward` for more details and an example.
    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_save_for_backward(_stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def impl_backward(qualname, output_differentiability=None, *, func=None):
    r"""Register a backward formula for an operator.

    In order for an operator to work with autograd, you need to register
    a backward formula. There are two pieces to this:
    1. You must give us a function to specify what to save for backward.
       Call this the "save for backward" function.
    2. You must give us a function that computes gradients. Call this the
       "backward" function.

    Use :func:`impl_save_for_backward` to define a "save for backward" function
    that specifies what gets saved for backward. The function should accept
    two arguments ``(inputs, output)`` and return the quantities to be saved
    for backward.

    At runtime, when you call the operator in a forward pass, PyTorch
    will invoke the "save for backward" function with the inputs and output
    of the operator.

    Use :func:`impl_backward` to define the "backward" function. The backward
    function must accept ``(ctx, saved, *grads)``:
    - ``ctx`` is a context object where we may provide information
    - ``saved`` is exactly what gets returned from the "save for backward"
      function
    - ``grads`` is one or more gradients. The number of gradients matches
      the number of outputs of the operator.

    The backward function must return a dict that maps the name of
    an input to the operator to its corresponding gradient. All inputs that
    were declared to be Tensors in the operator definition must be accounted
    for in the dict. The gradient may be a Tensor or None.

    For a detailed guide on custom ops, please see
    https://docs.google.com/document/d/1aGWtgxV3HppuxQAdddyPrs74_aEntpkYt9MalnCKnhk
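
    Example::
        >>> # A minimal sketch of both registrations; the op
        >>> # "mylibrary::numpy_mul" and the function names are hypothetical.
        >>> import torch
        >>> from torch import Tensor
        >>>
        >>> @torch._custom_ops.custom_op("mylibrary::numpy_mul")
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     raise NotImplementedError
        >>>
        >>> @torch._custom_ops.impl("mylibrary::numpy_mul", device_types="cpu")
        >>> def numpy_mul_impl(x, y):
        >>>     return torch.from_numpy(x.numpy() * y.numpy())
        >>>
        >>> # The "save for backward" function receives (inputs, output);
        >>> # inputs exposes the operator's arguments by name.
        >>> @torch._custom_ops.impl_save_for_backward("mylibrary::numpy_mul")
        >>> def numpy_mul_save_for_backward(inputs, output):
        >>>     return (inputs.x, inputs.y)
        >>>
        >>> # The backward function receives (ctx, saved, *grads) and returns
        >>> # a dict mapping each Tensor input's name to its gradient.
        >>> @torch._custom_ops.impl_backward("mylibrary::numpy_mul")
        >>> def numpy_mul_backward(ctx, saved, grad):
        >>>     x, y = saved
        >>>     return {"x": grad * y, "y": grad * x}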

    """

    def inner(func):
        custom_op = _find_custom_op(qualname, also_check_torch_library=True)
        custom_op.impl_backward(output_differentiability, _stacklevel=3)(func)
        return func

    if func is None:
        return inner
    return inner(func)


def _destroy(qualname):
    """De-registers a custom op. For testing purposes only."""
    custom_op = _find_custom_op(qualname)
    custom_op._destroy()