# mypy: allow-untyped-defs
import contextlib
import functools
import inspect
import re
import sys
import traceback
import weakref
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union
from typing_extensions import deprecated

import torch
import torch._library as _library
from torch._library.custom_ops import (
    _maybe_get_opdef,
    custom_op,
    CustomOpDef,
    device_types_t,
)
from torch._library.infer_schema import infer_schema  # noqa: F401
from torch._ops import OpOverload


__all__ = [
    "Library",
    "impl",
    "define",
    "fallthrough_kernel",
    "impl_abstract",
    "register_fake",
    "register_torch_dispatch",
    "register_vmap",
    "get_ctx",
    "custom_op",
    "infer_schema",
]

# Set containing the combination of (namespace, operator, DispatchKey) for which a new kernel has been registered.
# The keys in the set are of the form `namespace + "/" + op_name + "/" + dispatch_key`, e.g. "aten/div.Tensor/CPU".
# This set is maintained to ensure that two libraries don't try to override the exact same functionality,
# which would cause calls into kernels that were never intended to be called.
_impls: Set[str] = set()
_defs: Set[str] = set()

# prim is reserved by TorchScript interpreter
_reserved_namespaces = ["prim"]


def fallthrough_kernel():
    """
    A dummy function to pass to ``Library.impl`` in order to register a fallthrough.
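
    Example (a minimal sketch; the operator and dispatch key below are illustrative)::
        >>> my_lib = Library("aten", "IMPL")
        >>> my_lib.impl("div.Tensor", fallthrough_kernel, "Autocast")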
    """
    raise NotImplementedError("fallthrough_kernel() should never be called.")


class Library:
    """
    A class to create libraries that can be used to register new operators or
    override operators in existing libraries from Python.
    A user can optionally pass in a dispatch key name if they only want to register
    kernels corresponding to one specific dispatch key.

    To create a library to override operators in an existing library (with name ns), set the kind to "IMPL".
    To create a new library (with name ns) to register new operators, set the kind to "DEF".
    To create a fragment of a possibly existing library to register operators (and bypass
    the limitation that there is only one library for a given namespace), set the kind to
    "FRAGMENT".

    Args:
        ns: library name
        kind: "DEF", "IMPL" (default: "IMPL"), "FRAGMENT"
        dispatch_key: PyTorch dispatch key (default: "")
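
    Example (a minimal sketch; the "mylib" namespace and the registered kernel are illustrative)::
        >>> my_lib = Library("mylib", "FRAGMENT")
        >>> my_lib.define("foo(Tensor x) -> Tensor")
        >>> my_lib.impl("foo", lambda x: x.clone(), "CompositeExplicitAutograd")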
    """

    def __init__(self, ns, kind, dispatch_key=""):
        if kind not in ("IMPL", "DEF", "FRAGMENT"):
            raise ValueError(f"Unsupported kind: {kind}")

        if ns in _reserved_namespaces and (kind == "DEF" or kind == "FRAGMENT"):
            raise ValueError(
                f"{ns} is a reserved namespace. Please try creating a library with another name."
            )

        frame = traceback.extract_stack(limit=3)[0]
        filename, lineno = frame.filename, frame.lineno
        self.m: Optional[Any] = torch._C._dispatch_library(
            kind, ns, dispatch_key, filename, lineno
        )
        self.ns = ns
        self._op_defs: Set[str] = set()
        self._op_impls: Set[str] = set()
        self._registration_handles: List[torch._library.utils.RegistrationHandle] = []
        self.kind = kind
        self.dispatch_key = dispatch_key
        # Use a finalizer to set up the "destructor" instead of __del__.
        # Python __del__ can lead to weird things (globals and locals may already
        # be gone when __del__ actually gets called!). Finalizers help the
        # situation because they let us capture references and keep them alive.
        weakref.finalize(
            self,
            _del_library,
            _impls,
            self._op_impls,
            _defs,
            self._op_defs,
            self._registration_handles,
        )

    def __repr__(self):
        return f"Library(kind={self.kind}, ns={self.ns}, dispatch_key={self.dispatch_key})"

    def define(self, schema, alias_analysis="", *, tags=()):
        r"""Defines a new operator and its semantics in the ns namespace.

        Args:
            schema: function schema to define a new operator.
            alias_analysis (optional): Indicates if the aliasing properties of the operator arguments can be
                inferred from the schema (default behavior) or not ("CONSERVATIVE").
            tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
                operator. Tagging an operator changes the operator's behavior
                under various PyTorch subsystems; please read the docs for the
                torch.Tag carefully before applying it.

        Returns:
            name of the operator as inferred from the schema.

        Example::
            >>> my_lib = Library("mylib", "DEF")
            >>> my_lib.define("sum(Tensor self) -> Tensor")
        """
        # This is added because we also want to disallow PURE_FUNCTION alias analysis which is a valid
        # AliasAnalysis type in C++
        if alias_analysis not in ["", "FROM_SCHEMA", "CONSERVATIVE"]:
            raise RuntimeError(f"Invalid alias_analysis type {alias_analysis}")
        assert self.m is not None
        if isinstance(tags, torch.Tag):
            tags = (tags,)

        name = schema.split("(")[0]
        packet_name = name.split(".")[0] if "." in name else name
        has_preexisting_packet = hasattr(torch.ops, self.ns) and hasattr(
            getattr(torch.ops, self.ns), packet_name
        )

        result = self.m.define(schema, alias_analysis, tuple(tags))
        name = schema.split("(")[0]
        qualname = self.ns + "::" + name

        # If the OpOverloadPacket exists already, then this means we're adding a
        # new OpOverload for it. Refresh the packet to include the new OpOverload.
        if has_preexisting_packet:
            ns = getattr(torch.ops, self.ns)
            packet = getattr(ns, packet_name)
            torch._ops._refresh_packet(packet)

        self._op_defs.add(qualname)
        _defs.add(qualname)
        return result

    def _register_fake(self, op_name, fn, _stacklevel=1):
        r"""Registers the fake impl for an operator defined in the library."""
        source = torch._library.utils.get_source(_stacklevel + 1)
        frame = sys._getframe(_stacklevel)
        caller_module = inspect.getmodule(frame)
        # Can be None if register_fake is called from somewhere that doesn't have
        # a module (e.g. __main__)
        caller_module_name = None if caller_module is None else caller_module.__name__

        # TODO(rzou): We're gonna need to stage this change with torchvision,
        # since torchvision is github first.
        if caller_module_name is not None and caller_module_name.startswith(
            "torchvision."
        ):
            caller_module_name = None

        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        if caller_module_name is not None:
            func_to_register = _check_pystubs_once(fn, qualname, caller_module_name)
        else:
            func_to_register = fn

        handle = entry.fake_impl.register(func_to_register, source)
        self._registration_handles.append(handle)

    def _register_torch_dispatch_rule(self, op_name, torch_dispatch_class, fn):
        r"""Registers a torch_dispatch rule for the given operator and torch_dispatch_class.

        This allows for open registration to specify the behavior between the operator
        and the torch_dispatch_class without needing to modify the torch_dispatch_class
        or the operator directly.

        The torch_dispatch_class is either a Tensor subclass with `__torch_dispatch__` or a
        TorchDispatchMode.

        If it is a Tensor subclass, we expect fn to have the following signature:
        (cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any

        If it is a TorchDispatchMode, we expect fn to have the following signature:
        (mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any
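
        Example (a rough sketch; ``MyMode``, ``rule``, and the ``foo`` operator are
        illustrative and not part of this library)::
            >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
            >>>     pass
            >>>
            >>> def rule(mode, func, types, args, kwargs):
            >>>     # custom behavior for this operator under MyMode goes here
            >>>     return func(*args, **(kwargs or {}))
            >>>
            >>> my_lib._register_torch_dispatch_rule("foo", MyMode, rule)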
        """
        qualname = f"{self.ns}::{op_name}"
        entry = torch._library.simple_registry.singleton.find(qualname)
        handle = entry.torch_dispatch_rules.register(torch_dispatch_class, fn)
        self._registration_handles.append(handle)

    def _impl_with_aoti_compile(self, op_name, dispatch_key=""):
        r"""Register the operator to use the AOTI-compiled implementation.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> my_lib._impl_with_aoti_compile("div.Tensor", "CPU")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key
        assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "_impl_with_aoti_compile should be passed either a name or an OpOverload object "
                "as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when _impl_with_aoti_compile is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        assert self.m is not None
        impl_fn: Callable = self.m.impl_with_aoti_compile
        impl_fn(self.ns, name.split("::")[-1], dispatch_key)

        _impls.add(key)
        self._op_impls.add(key)

    def impl(self, op_name, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation for an operator defined in the library.

        Args:
            op_name: operator name (along with the overload) or OpOverload object.
            fn: function that's the operator implementation for the input dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("aten", "IMPL")
            >>> def div_cpu(self, other):
            >>>     return self * (1 / other)
            >>> my_lib.impl("div.Tensor", div_cpu, "CPU")
        """
        if not callable(fn):
            raise TypeError(
                f"Input function is required to be a callable but found type {type(fn)}"
            )
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if isinstance(op_name, str):
            name = op_name
        elif isinstance(op_name, OpOverload):
            name = op_name._schema.name
            overload_name = op_name._schema.overload_name
            if overload_name != "":
                name = name + "." + overload_name
        else:
            raise RuntimeError(
                "impl should be passed either a name or an OpOverload object as the first argument"
            )

        key = self.ns + "/" + name.split("::")[-1] + "/" + dispatch_key
        if key in _impls:
            # TODO: in future, add more info about where the existing function is registered (this info is
            # today already returned by the C++ warning when impl is called but we error out before that)
            raise RuntimeError(
                "This is not allowed since there's already a kernel registered from python overriding {}"
                "'s behavior for {} dispatch key and {} namespace.".format(
                    name.split("::")[-1], dispatch_key, self.ns
                )
            )

        if dispatch_key == "Meta":
            dispatcher_op_name = name
            if "::" not in dispatcher_op_name:
                dispatcher_op_name = f"{self.ns}::{dispatcher_op_name}"

            # Internally, we shouldn't be registering meta kernels for any operators that
            # have CompositeImplicitAutograd kernels.
            # Instead, we should be letting those decompositions run, and writing meta kernels
            # only for the base operators.
            if torch._C._dispatch_has_kernel_for_dispatch_key(
                dispatcher_op_name, "CompositeImplicitAutograd"
            ):
                raise RuntimeError(
                    f"We should not register a meta kernel directly to the operator '{name}',"
                    " because it has a CompositeImplicitAutograd kernel in core."
                    " Instead we should let the operator decompose, and ensure that we have meta kernels"
                    " for the base ops that it decomposes into."
                )

        assert self.m is not None
        self.m.impl(
            name,
            dispatch_key if dispatch_key != "" else "CompositeImplicitAutograd",
            fn,
            with_keyset,
        )

        _impls.add(key)
        self._op_impls.add(key)

    def fallback(self, fn, dispatch_key="", *, with_keyset=False):
        r"""Registers the function implementation as the fallback for the given key.

        This function only works for a library with global namespace ("_").

        Args:
            fn: function used as fallback for the given dispatch key or :func:`~fallthrough_kernel`
                to register a fallthrough.
            dispatch_key: dispatch key that the input function should be registered for. By default, it uses
                the dispatch key that the library was created with.
            with_keyset: flag controlling if the current dispatcher call keyset should be passed as the first argument
                to :attr:`fn` when calling. This should be used to create the appropriate keyset for redispatch calls.

        Example::
            >>> my_lib = Library("_", "IMPL")
            >>> def fallback_kernel(op, *args, **kwargs):
            >>>     # Handle all autocast ops generically
            >>>     # ...
            >>> my_lib.fallback(fallback_kernel, "Autocast")
        """
        if dispatch_key == "":
            dispatch_key = self.dispatch_key

        if self.ns != "_":
            raise RuntimeError(
                f"""Fallback can only be registered using library fragment on the global namespace "_" but it is {self.ns}"""
            )

        assert dispatch_key != ""
        assert self.m is not None

        self.m.fallback(dispatch_key, fn, with_keyset)

    def _destroy(self):
        if self.m is not None:
            self.m.reset()
        self.m = None
        for handle in self._registration_handles:
            handle.destroy()
        self._registration_handles.clear()
        global _impls
        _impls -= self._op_impls
        for name in self._op_defs:
            # Delete the cached torch.ops.ns.foo if it was registered.
            # Otherwise, accessing it leads to a segfault.
            # It's possible that we only registered an overload in this Library
            # and another library owns an alive overload.
            # That's OK - the next time torch.ops.ns.foo gets called, it'll be
            # recomputed to point at the right collection of overloads.
            ns, name_with_overload = name.split("::")
            name = name_with_overload.split(".")[0]
            if not hasattr(torch.ops, ns):
                continue
            namespace = getattr(torch.ops, ns)
            if not hasattr(namespace, name):
                continue
            delattr(namespace, name)


def _del_library(
    captured_impls,
    op_impls,
    captured_defs,
    op_defs,
    registration_handles,
):
    captured_impls -= op_impls
    captured_defs -= op_defs
    for handle in registration_handles:
        handle.destroy()


@contextlib.contextmanager
def _scoped_library(*args, **kwargs):
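    """Create a ``Library`` that is destroyed when the ``with`` block exits.

    A minimal usage sketch (the namespace and schema are illustrative)::

        >>> with _scoped_library("mylib", "FRAGMENT") as lib:
        >>>     lib.define("foo(Tensor x) -> Tensor")
    """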
    lib = Library(*args, **kwargs)
    try:
        yield lib
    finally:
        lib._destroy()


_keep_alive: List[Library] = []


NAMELESS_SCHEMA = re.compile(r"\(.*\) -> .*")


@functools.singledispatch
def define(qualname, schema, *, lib=None, tags=()):
    r"""Defines a new operator.

    In PyTorch, defining an op (short for "operator") is a two-step process:
    - we need to define the op (by providing an operator name and schema)
    - we need to implement behavior for how the operator interacts with
      various PyTorch subsystems, like CPU/CUDA Tensors, Autograd, etc.

    This entrypoint defines the custom operator (the first step);
    you must then perform the second step by calling various
    ``impl_*`` APIs, like :func:`torch.library.impl` or
    :func:`torch.library.register_fake`.

    Args:
        qualname (str): The qualified name for the operator. Should be
            a string that looks like "namespace::name", e.g. "aten::sin".
            Operators in PyTorch need a namespace to
            avoid name collisions; a given operator may only be created once.
            If you are writing a Python library, we recommend the namespace to
            be the name of your top-level module.
        schema (str): The schema of the operator. E.g. "(Tensor x) -> Tensor"
            for an op that accepts one Tensor and returns one Tensor. It does
            not contain the operator name (that is passed in ``qualname``).
        lib (Optional[Library]): If provided, the lifetime of this operator
            will be tied to the lifetime of the Library object.
        tags (Tag | Sequence[Tag]): one or more torch.Tag to apply to this
            operator. Tagging an operator changes the operator's behavior
            under various PyTorch subsystems; please read the docs for the
            torch.Tag carefully before applying it.

    Example::
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::sin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the operator
        >>> @torch.library.impl("mylib::sin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> # Call the new operator from torch.ops.
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.sin(x)
        >>> assert torch.allclose(y, x.sin())

    """
    if not isinstance(qualname, str):
        raise ValueError(
            f"define(qualname, schema): expected qualname "
            f"to be instance of str, got {type(qualname)}"
        )
    namespace, name = torch._library.utils.parse_namespace(qualname)
    if lib is None:
        lib = Library(namespace, "FRAGMENT")
        _keep_alive.append(lib)
    if not NAMELESS_SCHEMA.fullmatch(schema):
        raise ValueError(
            f"define(qualname, schema, ...): expected schema "
            f'to look like e.g. "(Tensor x) -> Tensor" but '
            f'got "{schema}"'
        )
    lib.define(name + schema, alias_analysis="", tags=tags)


@define.register
def _(lib: Library, schema, alias_analysis=""):
    """The old torch.library.define.
    We're keeping this around for BC reasons
    """

    def wrap(f):
        name = lib.define(schema, alias_analysis)
        lib.impl(name, f)
        return f

    return wrap


@functools.singledispatch
def impl(qualname, types, func=None, *, lib=None):
    """Register an implementation for a device type for this operator.

    You may pass "default" for ``types`` to register this implementation as the
    default implementation for ALL device types.
    Please only use this if the implementation truly supports all device types;
    for example, this is true if it is a composition of built-in PyTorch operators.

    Some valid types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".

    Args:
        qualname (str): Should be a string that looks like "namespace::operator_name".
        types (str | Sequence[str]): The device types to register an impl to.
        lib (Optional[Library]): If provided, the lifetime of this registration
            will be tied to the lifetime of the Library object.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>>
        >>> # Define the operator
        >>> torch.library.define("mylib::mysin", "(Tensor x) -> Tensor")
        >>>
        >>> # Add implementations for the cpu device
        >>> @torch.library.impl("mylib::mysin", "cpu")
        >>> def f(x):
        >>>     return torch.from_numpy(np.sin(x.numpy()))
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.ops.mylib.mysin(x)
        >>> assert torch.allclose(y, x.sin())
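        >>>
        >>> # Alternatively, a "default" registration applies to all device types
        >>> # (a sketch; only appropriate when the implementation is truly device-agnostic)
        >>> @torch.library.impl("mylib::mysin", "default")
        >>> def _(x):
        >>>     return torch.sin(x)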
    """
    return _impl(qualname, types, func, lib=lib, disable_dynamo=False)


def _impl(qualname, types, func=None, *, lib=None, disable_dynamo=False):
    if isinstance(types, str):
        types = (types,)
    keys = set({})
    for typ in types:
        is_dispatch_key = torch._C._parse_dispatch_key(typ)
        if is_dispatch_key:
            # We also support passing a DispatchKey to impl. Please prefer using
            # the higher-level torch.library APIs and only pass DispatchKey to
            # torch.library.impl with caution (or even better, don't use this
            # option and file an issue on GitHub for what you need).
            # We don't advertise this to users because
            # it is very easy to shoot yourself in the foot.
            keys.add(typ)
        else:
            keys.add(_device_type_to_key(typ))

    def register(func):
        namespace, _ = torch._library.utils.parse_namespace(qualname)

        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        if disable_dynamo:

            @torch._disable_dynamo
            def func_no_dynamo(*args, **kwargs):
                return func(*args, **kwargs)

            for key in keys:
                use_lib.impl(qualname, func_no_dynamo, key)
        else:
            for key in keys:
                use_lib.impl(qualname, func, key)

    if func is None:
        return register
    else:
        register(func)


def _device_type_to_key(device_type: str) -> str:
    """Map a device type string (e.g. "cpu", "cuda") to the dispatch key used for registration."""
    if device_type == "default":
        # This is technically not correct, because although all device_type
        # DispatchKeys are included in CompositeExplicitAutograd,
        # not everything in CompositeExplicitAutograd is associated with a
        # device_type. I don't really care that much about the difference.
        return "CompositeExplicitAutograd"
    return torch._C._dispatch_key_for_device(device_type)


@impl.register
def _(lib: Library, name, dispatch_key=""):
    """Legacy torch.library.impl API. Kept around for BC"""

    def wrap(f):
        lib.impl(name, f, dispatch_key)
        return f

    return wrap


@deprecated(
    "`torch.library.impl_abstract` was renamed to `torch.library.register_fake`. Please use that "
    "instead; we will remove `torch.library.impl_abstract` in a future version of PyTorch.",
    category=FutureWarning,
)
def impl_abstract(qualname, func=None, *, lib=None, _stacklevel=1):
    r"""This API was renamed to :func:`torch.library.register_fake` in PyTorch 2.4.
    Please use that instead.
    """
    if func is not None:
        _stacklevel = _stacklevel + 1
    return register_fake(qualname, func, lib=lib, _stacklevel=_stacklevel)


_op_identifier = Union[
    str, "torch._ops.OpOverload", "torch._library.custom_ops.CustomOpDef"
]


def register_kernel(
    op: _op_identifier,
    device_types: device_types_t,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
):
    """Register an implementation for a device type for this operator.

    Some valid device_types are: "cpu", "cuda", "xla", "mps", "ipu", "xpu".
    This API may be used as a decorator.

    Args:
        op (str | OpOverload): The operator to register an impl to.
        func (Callable): The function to register as the implementation for
            the given device types.
        device_types (None | str | Sequence[str]): The device_types to register an impl to.
            If None, we will register to all device types -- please only use
            this option if your implementation is truly device-type-agnostic.

    Examples::
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> import torch
        >>> from torch import Tensor
        >>> from torch.library import custom_op
        >>> import numpy as np
        >>>
        >>> # Create a custom op that works on cpu
        >>> @custom_op("mylib::numpy_sin", mutates_args=(), device_types="cpu")
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np)
        >>>
        >>> # Add implementations for the cuda device
        >>> @torch.library.register_kernel("mylib::numpy_sin", "cuda")
        >>> def _(x):
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> x_cpu = torch.randn(3)
        >>> x_cuda = x_cpu.cuda()
        >>> assert torch.allclose(numpy_sin(x_cpu), x_cpu.sin())
        >>> assert torch.allclose(numpy_sin(x_cuda), x_cuda.sin())

    """

    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_kernel(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_kernel(device_types, func)
    assert isinstance(op, str)
    if device_types is None:
        device_types = "CompositeExplicitAutograd"

    return _impl(op, device_types, func, lib=lib, disable_dynamo=True)


def register_fake(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib: Optional[Library] = None,
    _stacklevel: int = 1,
):
    r"""Register a FakeTensor implementation ("fake impl") for this operator.

    Also sometimes known as a "meta kernel" or "abstract impl".

    A "FakeTensor implementation" specifies the behavior of this operator on
    Tensors that carry no data ("FakeTensor"). Given some input Tensors with
    certain properties (sizes/strides/storage_offset/device), it specifies
    what the properties of the output Tensors are.

    The FakeTensor implementation has the same signature as the operator.
    It is run for both FakeTensors and meta tensors. To write a FakeTensor
    implementation, assume that all Tensor inputs to the operator are
    regular CPU/CUDA/Meta tensors, but they do not have storage, and
    you are trying to return regular CPU/CUDA/Meta tensor(s) as output.
    The FakeTensor implementation must consist of only PyTorch operations
    (and may not directly access the storage or data of any input or
    intermediate Tensors).

    This API may be used as a decorator (see examples).

    For a detailed guide on custom ops, please see
    https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> # Example 1: an operator without data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_linear", mutates_args=())
        >>> def custom_linear(x: Tensor, weight: Tensor, bias: Tensor) -> Tensor:
        >>>     raise NotImplementedError("Implementation goes here")
        >>>
        >>> @torch.library.register_fake("mylib::custom_linear")
        >>> def _(x, weight, bias):
        >>>     assert x.dim() == 2
        >>>     assert weight.dim() == 2
        >>>     assert bias.dim() == 1
        >>>     assert x.shape[1] == weight.shape[1]
        >>>     assert weight.shape[0] == bias.shape[0]
        >>>     assert x.device == weight.device
        >>>
        >>>     return (x @ weight.t()) + bias
        >>>
        >>> with torch._subclasses.fake_tensor.FakeTensorMode():
        >>>     x = torch.randn(2, 3)
        >>>     w = torch.randn(3, 3)
        >>>     b = torch.randn(3)
        >>>     y = torch.ops.mylib.custom_linear(x, w, b)
        >>>
        >>> assert y.shape == (2, 3)
        >>>
        >>> # Example 2: an operator with data-dependent output shape
        >>> @torch.library.custom_op("mylib::custom_nonzero", mutates_args=())
        >>> def custom_nonzero(x: Tensor) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     res = np.stack(np.nonzero(x_np), axis=1)
        >>>     return torch.tensor(res, device=x.device)
        >>>
        >>> @torch.library.register_fake("mylib::custom_nonzero")
        >>> def _(x):
        >>>     # Number of nonzero-elements is data-dependent.
        >>>     # Since we cannot peek at the data in a fake impl,
        >>>     # we use the ctx object to construct a new symint that
        >>>     # represents the data-dependent size.
        >>>     ctx = torch.library.get_ctx()
        >>>     nnz = ctx.new_dynamic_size()
        >>>     shape = [nnz, x.dim()]
        >>>     result = x.new_empty(shape, dtype=torch.int64)
        >>>     return result
        >>>
        >>> from torch.fx.experimental.proxy_tensor import make_fx
        >>>
        >>> x = torch.tensor([0, 1, 2, 3, 4, 0])
        >>> trace = make_fx(torch.ops.mylib.custom_nonzero, tracing_mode="symbolic")(x)
        >>> trace.print_readable()
        >>>
        >>> assert torch.allclose(trace(x), torch.ops.mylib.custom_nonzero(x))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_fake(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        if func is None:
            return opdef.register_fake
        else:
            return opdef.register_fake(func)
    assert isinstance(op, str)

    stacklevel = _stacklevel

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_fake(op_name, func, _stacklevel=stacklevel + 1)
        return func

    if func is None:
        return register
    else:
        stacklevel += 1
        return register(func)


def register_autograd(
    op: _op_identifier,
    backward: Callable,
    /,
    *,
    setup_context: Optional[Callable] = None,
    lib=None,
) -> None:
    r"""Register a backward formula for this custom op.

    In order for an operator to work with autograd, you need to register
    a backward formula:
    1. You must tell us how to compute gradients during the backward pass
    by providing us a "backward" function.
    2. If you need any values from the forward to compute gradients, you can
    use `setup_context` to save values for backward.

    ``backward`` runs during the backward pass. It accepts ``(ctx, *grads)``:
    - ``grads`` is one or more gradients. The number of gradients matches
    the number of outputs of the operator.
    The ``ctx`` object is `the same ctx object <context_method_mixins>`_ used by
    :class:`torch.autograd.Function`. The semantics of ``backward`` are the
    same as :meth:`torch.autograd.Function.backward`.

    ``setup_context(ctx, inputs, output)`` runs during the forward pass.
    Please save quantities needed for backward onto the ``ctx`` object via
    either :meth:`torch.autograd.function.FunctionCtx.save_for_backward`
    or assigning them as attributes of ``ctx``. If your custom op has
    kwarg-only arguments, we expect the signature of ``setup_context``
    to be ``setup_context(ctx, inputs, keyword_only_inputs, output)``.

    Both ``setup_context`` and ``backward`` must be traceable. That is,
    they may not directly access :meth:`torch.Tensor.data_ptr` and they must
    not depend on or mutate global state. If you need a non-traceable backward,
    you can make it a separate custom_op that you call inside ``backward``.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>>
        >>> @torch.library.custom_op("mylib::numpy_sin", mutates_args=())
        >>> def numpy_sin(x: Tensor) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = np.sin(x_np)
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, output) -> None:
        >>>     x, = inputs
        >>>     ctx.save_for_backward(x)
        >>>
        >>> def backward(ctx, grad):
        >>>     x, = ctx.saved_tensors
        >>>     return grad * x.cos()
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_sin", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_sin(x)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, x.cos())
        >>>
        >>> # Example with a keyword-only arg
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, *, val: float) -> Tensor:
        >>>     x_np = x.cpu().numpy()
        >>>     y_np = x_np * val
        >>>     return torch.from_numpy(y_np).to(device=x.device)
        >>>
        >>> def setup_context(ctx, inputs, keyword_only_inputs, output) -> None:
        >>>     ctx.val = keyword_only_inputs["val"]
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.val
        >>>
        >>> torch.library.register_autograd(
        ...     "mylib::numpy_mul", backward, setup_context=setup_context
        ... )
        >>>
        >>> x = torch.randn(3, requires_grad=True)
        >>> y = numpy_mul(x, val=3.14)
        >>> (grad_x,) = torch.autograd.grad(y, x, torch.ones_like(y))
        >>> assert torch.allclose(grad_x, torch.full_like(x, 3.14))

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_autograd(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        opdef.register_autograd(backward, setup_context=setup_context)
        return

    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if not _library.utils.is_functional_schema(schema):
        raise RuntimeError(
            f"Cannot register autograd formula for non-functional operator "
            f"{op} with schema {schema}. Please create "
            f"a functional operator and register an autograd formula for that."
        )
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_autograd with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
" 919*da0073e9SAndroid Build Coastguard Worker f"Got: {schema}" 920*da0073e9SAndroid Build Coastguard Worker ) 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker info = _library.autograd.Info(backward, setup_context) 923*da0073e9SAndroid Build Coastguard Worker autograd_kernel = _library.autograd.make_autograd_impl(op, info) 924*da0073e9SAndroid Build Coastguard Worker namespace, opname = torch._library.utils.parse_namespace(qualname) 925*da0073e9SAndroid Build Coastguard Worker if lib is None: 926*da0073e9SAndroid Build Coastguard Worker lib = Library(namespace, "FRAGMENT") 927*da0073e9SAndroid Build Coastguard Worker _keep_alive.append(lib) 928*da0073e9SAndroid Build Coastguard Worker lib.impl(opname, autograd_kernel, "Autograd", with_keyset=True) 929*da0073e9SAndroid Build Coastguard Worker 930*da0073e9SAndroid Build Coastguard Worker 931*da0073e9SAndroid Build Coastguard Workerdef register_torch_dispatch( 932*da0073e9SAndroid Build Coastguard Worker op: _op_identifier, 933*da0073e9SAndroid Build Coastguard Worker torch_dispatch_class: Any, 934*da0073e9SAndroid Build Coastguard Worker func: Optional[Callable] = None, 935*da0073e9SAndroid Build Coastguard Worker /, 936*da0073e9SAndroid Build Coastguard Worker *, 937*da0073e9SAndroid Build Coastguard Worker lib: Optional[Library] = None, 938*da0073e9SAndroid Build Coastguard Worker): 939*da0073e9SAndroid Build Coastguard Worker r"""Registers a torch_dispatch rule for the given operator and ``torch_dispatch_class``. 940*da0073e9SAndroid Build Coastguard Worker 941*da0073e9SAndroid Build Coastguard Worker This allows for open registration to specify the behavior between the operator 942*da0073e9SAndroid Build Coastguard Worker and the ``torch_dispatch_class`` without needing to modify the ``torch_dispatch_class`` 943*da0073e9SAndroid Build Coastguard Worker or the operator directly. 944*da0073e9SAndroid Build Coastguard Worker 945*da0073e9SAndroid Build Coastguard Worker The ``torch_dispatch_class`` is either a Tensor subclass with ``__torch_dispatch__`` or a 946*da0073e9SAndroid Build Coastguard Worker TorchDispatchMode. 947*da0073e9SAndroid Build Coastguard Worker 948*da0073e9SAndroid Build Coastguard Worker If it is a Tensor subclass, we expect ``func`` to have the following signature: 949*da0073e9SAndroid Build Coastguard Worker ``(cls, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` 950*da0073e9SAndroid Build Coastguard Worker 951*da0073e9SAndroid Build Coastguard Worker If it is a TorchDispatchMode, we expect ``func`` to have the following signature: 952*da0073e9SAndroid Build Coastguard Worker ``(mode, func: OpOverload, types: Tuple[type, ...], args, kwargs) -> Any`` 953*da0073e9SAndroid Build Coastguard Worker 954*da0073e9SAndroid Build Coastguard Worker ``args`` and ``kwargs`` will have been normalized the same way they are 955*da0073e9SAndroid Build Coastguard Worker in ``__torch_dispatch__`` (see :ref:`torch-dispatch-calling-convention`). 

    Examples:

        >>> import torch
        >>>
        >>> @torch.library.custom_op("mylib::foo", mutates_args={})
        >>> def foo(x: torch.Tensor) -> torch.Tensor:
        >>>     return x.clone()
        >>>
        >>> class MyMode(torch.utils._python_dispatch.TorchDispatchMode):
        >>>     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        >>>         return func(*args, **kwargs)
        >>>
        >>> @torch.library.register_torch_dispatch("mylib::foo", MyMode)
        >>> def _(mode, func, types, args, kwargs):
        >>>     x, = args
        >>>     return x + 1
        >>>
        >>> x = torch.randn(3)
        >>> y = foo(x)
        >>> assert torch.allclose(y, x)
        >>>
        >>> with MyMode():
        >>>     y = foo(x)
        >>> assert torch.allclose(y, x + 1)

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(
            f"register_torch_dispatch(op): got unexpected type for op: {type(op)}"
        )
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_torch_dispatch(torch_dispatch_class, func)
    assert isinstance(op, str)

    def register(func):
        namespace, op_name = torch._library.utils.parse_namespace(op)
        if lib is None:
            use_lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(use_lib)
        else:
            use_lib = lib
        use_lib._register_torch_dispatch_rule(op_name, torch_dispatch_class, func)
        return func

    if func is None:
        return register
    else:
        return register(func)


def register_vmap(
    op: _op_identifier,
    func: Optional[Callable] = None,
    /,
    *,
    lib=None,
):
    r"""Register a vmap implementation to support :func:`torch.vmap` for this custom op.

    This API may be used as a decorator (see examples).

    In order for an operator to work with :func:`torch.vmap`, you may need to register a
    vmap implementation with the following signature:

    ``vmap_func(info, in_dims: Tuple[Optional[int]], *args, **kwargs)``,

    where ``*args`` and ``**kwargs`` are the arguments and kwargs for ``op``.
    We do not support kwarg-only Tensor args.

    It specifies how we compute the batched version of ``op`` given inputs with an additional
    dimension (specified by ``in_dims``).

    For each arg in ``args``, ``in_dims`` has a corresponding ``Optional[int]``. It is ``None``
    if the arg is not a Tensor or if the arg is not being vmapped over, otherwise, it is an integer
    specifying what dimension of the Tensor is being vmapped over.

    ``info`` is a collection of additional metadata that may be helpful:
    ``info.batch_size`` specifies the size of the dimension being vmapped over, while
    ``info.randomness`` is the ``randomness`` option that was passed to :func:`torch.vmap`.

    The return of the function ``func`` is a tuple of ``(output, out_dims)``. Similar to ``in_dims``,
    ``out_dims`` should be of the same structure as ``output`` and contain one ``out_dim``
    per output that specifies if the output has the vmapped dimension and what index it is in.

    Examples:
        >>> import torch
        >>> import numpy as np
        >>> from torch import Tensor
        >>> from typing import Tuple
        >>>
        >>> def to_numpy(tensor):
        >>>     return tensor.cpu().numpy()
        >>>
        >>> lib = torch.library.Library("mylib", "FRAGMENT")
        >>> @torch.library.custom_op("mylib::numpy_cube", mutates_args=())
        >>> def numpy_cube(x: Tensor) -> Tuple[Tensor, Tensor]:
        >>>     x_np = to_numpy(x)
        >>>     dx = torch.tensor(3 * x_np ** 2, device=x.device)
        >>>     return torch.tensor(x_np ** 3, device=x.device), dx
        >>>
        >>> def numpy_cube_vmap(info, in_dims, x):
        >>>     result = numpy_cube(x)
        >>>     return result, (in_dims[0], in_dims[0])
        >>>
        >>> torch.library.register_vmap(numpy_cube, numpy_cube_vmap)
        >>>
        >>> x = torch.randn(3)
        >>> torch.vmap(numpy_cube)(x)
        >>>
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: Tensor) -> Tensor:
        >>>     return torch.tensor(to_numpy(x) * to_numpy(y), device=x.device)
        >>>
        >>> @torch.library.register_vmap("mylib::numpy_mul")
        >>> def numpy_mul_vmap(info, in_dims, x, y):
        >>>     x_bdim, y_bdim = in_dims
        >>>     x = x.movedim(x_bdim, -1) if x_bdim is not None else x.unsqueeze(-1)
        >>>     y = y.movedim(y_bdim, -1) if y_bdim is not None else y.unsqueeze(-1)
        >>>     result = x * y
        >>>     result = result.movedim(-1, 0)
        >>>     return result, 0
        >>>
        >>>
        >>> x = torch.randn(3)
        >>> y = torch.randn(3)
        >>> torch.vmap(numpy_mul)(x, y)

    .. note::
        The vmap function should aim to preserve the semantics of the entire custom operator.
        That is, ``grad(vmap(op))`` should be replaceable with a ``grad(map(op))``.

        If your custom operator has any custom behavior in the backward pass, please
        keep this in mind.

    """
    if not isinstance(
        op, (str, torch._ops.OpOverload, torch._library.custom_ops.CustomOpDef)
    ):
        raise ValueError(f"register_vmap(op): got unexpected type for op: {type(op)}")
    if isinstance(op, torch._ops.OpOverload):
        op = op._name
    opdef = _maybe_get_opdef(op)
    if opdef is not None:
        return opdef.register_vmap(func)
    assert isinstance(op, str)
    qualname = op
    op = torch._library.utils.lookup_op(qualname)
    schema = op._schema
    if _library.utils.has_kwarg_only_tensors(schema):
        raise NotImplementedError(
            f"register_vmap with kwarg-only Tensor args. In the original "
            f"definition of the op, please make your tensors not kwarg-only. "
            f"Got: {schema}"
        )

    def register(func):
        nonlocal op, lib

        namespace, opname = torch._library.utils.parse_namespace(qualname)
        if lib is None:
            lib = Library(namespace, "FRAGMENT")
            _keep_alive.append(lib)

        from torch._functorch.autograd_function import custom_function_call_vmap_helper
        from torch._functorch.pyfunctorch import retrieve_current_functorch_interpreter

        def wrapped_func(keyset, *args, **kwargs):
            interpreter = retrieve_current_functorch_interpreter()
            return custom_function_call_vmap_helper(
                interpreter, func, op, *args, **kwargs
            )

        lib.impl(opname, wrapped_func, "FuncTorchBatched", with_keyset=True)

    if func is None:
        return register
    else:
        return register(func)


# If the op was defined in C++, then we want to make sure there was an
# m.set_python_module(module, ...) call and that the module is the
# same as the module that called torch.library.register_fake.
def _check_pystubs_once(func, qualname, actual_module_name):
    checked = False

    def inner(*args, **kwargs):
        nonlocal checked
        if checked:
            return func(*args, **kwargs)

        op = torch._library.utils.lookup_op(qualname)
        if op._defined_in_python:
            checked = True
            return func(*args, **kwargs)

        maybe_pystub = torch._C._dispatch_pystub(
            op._schema.name, op._schema.overload_name
        )
        if maybe_pystub is None:
            if torch._library.utils.requires_set_python_module():
                namespace = op.namespace
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' was defined in C++ and has a Python "
                    f"fake impl. In this situation, we require there to also be a "
                    f'companion C++ `m.set_python_module("{actual_module_name}")` '
                    f"call, but we could not find one. Please add that "
                    f"to the top of the C++ TORCH_LIBRARY({namespace}, ...) block the "
                    f"operator was registered in ({cpp_filename})"
                )
        else:
            pystub_module = maybe_pystub[0]
            if actual_module_name != pystub_module:
                cpp_filename = op._handle.debug()
                raise RuntimeError(
                    f"Operator '{qualname}' specified that its python fake impl "
                    f"is in the Python module '{pystub_module}' but it was actually found "
                    f"in '{actual_module_name}'. Please either move the fake impl "
                    f"or correct the m.set_python_module call ({cpp_filename})"
                )
        checked = True
        return func(*args, **kwargs)

    return inner


# NOTE [ctx inside the fake implementation]
# If a user has an operator with data-dependent output shape, then when writing
# a fake implementation they must query the current ctx and use methods on the
# ctx to construct a new unbacked symint.
#
# This is done via us setting the global_ctx_getter function every time a fake
# implementation is invoked.
def get_ctx() -> "torch._library.fake_impl.FakeImplCtx":
    """get_ctx() returns the current FakeImplCtx object.

    Calling ``get_ctx()`` is only valid inside of a fake impl
    (see :func:`torch.library.register_fake` for more usage details).
    """
    return torch._library.fake_impl.global_ctx_getter()


_OPCHECK_DEFAULT_UTILS = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)


def opcheck(
    op: Union[torch._ops.OpOverload, torch._ops.OpOverloadPacket, CustomOpDef],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = _OPCHECK_DEFAULT_UTILS,
    raise_exception: bool = True,
) -> Dict[str, str]:
    """Given an operator and some sample arguments, tests if the operator is
    registered correctly.

    That is, when you use the torch.library/TORCH_LIBRARY APIs to create a
    custom op, you specify metadata (e.g. mutability info) about the custom op
    and these APIs require that the functions you pass them satisfy certain
    properties (e.g. no data pointer access in the fake/meta/abstract kernel).
    ``opcheck`` tests these metadata and properties.

    Concretely, we test the following:

    - test_schema: If the schema matches the implementation of
      the operator. For example: if the schema specifies a Tensor is mutated,
      then we check the implementation mutates the Tensor. If the schema
      specifies that we return a new Tensor, then we check that the
      implementation returns a new Tensor (instead of an existing one or
      a view of an existing one).
    - test_autograd_registration: If the operator supports training
      (autograd), we check that its autograd formula is registered via
      torch.library.register_autograd or a manual registration to one
      or more DispatchKey::Autograd keys. Any other DispatchKey-based
      registrations may lead to undefined behavior.
    - test_faketensor: If the operator has a FakeTensor kernel
      (and if it is correct). The FakeTensor kernel is necessary
      (but not sufficient) for the operator to work with PyTorch compilation
      APIs (torch.compile/export/FX). We check that a FakeTensor kernel
      (also sometimes known as a meta kernel) was registered for the
      operator and that it is correct. This test takes the result of
      running the operator on real tensors and the result of running
      the operator on FakeTensors and checks that they have the same
      Tensor metadata (sizes/strides/dtype/device/etc).
    - test_aot_dispatch_dynamic: If the operator has correct behavior
      with PyTorch compilation APIs (torch.compile/export/FX).
      This checks that the outputs (and gradients, if applicable) are the
      same under eager-mode PyTorch and torch.compile.
      This test is a superset of ``test_faketensor`` and is an e2e test;
      other things it tests are that the operator supports
      functionalization and that the backward pass (if it exists) also
      supports FakeTensor and functionalization.

    For best results, please call ``opcheck`` multiple times with a
    representative set of inputs. If your operator supports
    autograd, please use ``opcheck`` with inputs with ``requires_grad = True``;
    if your operator supports multiple devices (e.g. CPU and CUDA), please
    use ``opcheck`` with inputs on all supported devices.

    Args:
        op: The operator. Must either be a function decorated with
            :func:`torch.library.custom_op` or an OpOverload/OpOverloadPacket
            found in torch.ops.* (e.g. torch.ops.aten.sin, torch.ops.mylib.foo)
        args: The args to the operator
        kwargs: The kwargs to the operator
        test_utils: Tests that we should run. Default: all of them.
            Example: ("test_schema", "test_faketensor")
        raise_exception: If we should raise an exception on the first
            error. If False, we will return a dict with information
            on if each test passed or not.

    .. warning::

        opcheck and :func:`torch.autograd.gradcheck` test different things;
        opcheck tests if your usage of torch.library APIs is correct while
        :func:`torch.autograd.gradcheck` tests if your autograd formula is
        mathematically correct. Use both to test custom ops that support
        gradient computation.

    Example:

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> @torch.library.custom_op("mylib::numpy_mul", mutates_args=())
        >>> def numpy_mul(x: Tensor, y: float) -> Tensor:
        >>>     x_np = x.numpy(force=True)
        >>>     z_np = x_np * y
        >>>     return torch.from_numpy(z_np).to(x.device)
        >>>
        >>> @numpy_mul.register_fake
        >>> def _(x, y):
        >>>     return torch.empty_like(x)
        >>>
        >>> def setup_context(ctx, inputs, output):
        >>>     x, y = inputs
        >>>     ctx.y = y
        >>>
        >>> def backward(ctx, grad):
        >>>     return grad * ctx.y, None
        >>>
        >>> numpy_mul.register_autograd(backward, setup_context=setup_context)
        >>>
        >>> sample_inputs = [
        >>>     (torch.randn(3), 3.14),
        >>>     (torch.randn(2, 3, device='cuda'), 2.718),
        >>>     (torch.randn(1, 10, requires_grad=True), 1.234),
        >>>     (torch.randn(64, 64, device='cuda', requires_grad=True), 90.18),
        >>> ]
        >>>
        >>> for args in sample_inputs:
        >>>     torch.library.opcheck(numpy_mul, args)

    """
    import torch.testing._internal.optests as optests

    return optests.opcheck(
        op, args, kwargs, test_utils=test_utils, raise_exception=raise_exception
    )