1# mypy: allow-untyped-defs 2import weakref 3 4import torch 5import torch.utils._pytree as pytree 6from torch._C import _ExcludeDispatchKeyGuard, DispatchKey, DispatchKeySet 7from torch._ops import OpOverload 8from torch.library import Library 9from torchgen.model import ( 10 BaseTy, 11 BaseType, 12 FunctionSchema, 13 OperatorName, 14 OptionalType, 15 SchemaKind, 16) 17 18from .autograd import autograd_not_implemented 19 20 21def register_functional_op( 22 lib: Library, 23 new_op_name: str, 24 mutable_op: OpOverload, 25) -> None: 26 """Given a mutable operator, registers the functional variant. 27 28 This API also correctly links the functional variant with the mutable 29 operator for the purposes of functionalization. 30 31 All of the new registrations are performed on the ``lib`` passed in. 32 33 Arguments: 34 lib (Library): Should be a torch.library.Library object that has 35 the same namespace as ``mutable_op``'s namespace. 36 lib will be used to register the new functional op as well 37 as a functionalization kernel for the ``mutable_op`` 38 If you don't have a library handy, use 39 ``torch.library.Library(ns, 'FRAGMENT')`` to construct one. 40 new_op_name (str): The name of the functional operator (without the 41 namespace). If no namespace, the new functional variant will be 42 accessible under ``torch.ops.{lib.ns}.new_op_name``. 43 mutable_op (OpOverload): The mutable custom operator. Note 44 that you may need to add a `.default` to it, like 45 `torch.ops.aten.abs_.default`. 46 47 """ 48 validate(mutable_op) 49 schema = functional_schema(new_op_name, mutable_op) 50 lib.define(schema) 51 52 functional_impl = construct_functional_impl(mutable_op) 53 lib.impl(new_op_name, functional_impl, 'CompositeExplicitAutograd') 54 55 functional_op = getattr(getattr(torch.ops, lib.ns), new_op_name).default 56 57 # There's no easy way for us to generate the autograd kernel, so we 58 # use autograd_not_implemented. Also, this makes it so that the user 59 # is unable to register an autograd formula themselves. This shouldn't 60 # be a problem if the user doesn't use the functional op direclty 61 # in their program, but we may need to revist this in the future. 62 lib.impl(new_op_name, autograd_not_implemented(functional_op), 'Autograd') 63 64 f_kernel = construct_functionalization_kernel(weakref.proxy(mutable_op), functional_op) 65 66 lib.impl(mutable_op, f_kernel, 'Functionalize') 67 68 69def construct_functional_impl(mutable_op): 70 def functional_impl(*args): 71 # Strategy: 72 # - clone args that would have been mutated 73 # - run mutable_op 74 # - return the cloned args as additional outputs 75 new_args = [] 76 extra_rets = [] 77 for is_write, arg in zip(mutable_args(mutable_op), args): 78 if is_write: 79 cloned = arg.clone() if arg is not None else None 80 new_args.append(cloned) 81 extra_rets.append(cloned) 82 else: 83 new_args.append(arg) 84 result = mutable_op(*new_args) 85 if result is None: 86 return tuple(extra_rets) 87 if isinstance(result, tuple): 88 return (*result, *extra_rets) 89 return (result, *extra_rets) 90 return functional_impl 91 92 93def construct_functionalization_kernel(mutable_op, functional_op): 94 def kernel(*args): 95 # There's nothing to be functionalized! 96 # We can still end up here because DispatchKey::Functionalize is a mode key 97 if pytree.tree_all_only(torch.Tensor, lambda x: not torch._is_functional_tensor(x), args): 98 with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): 99 return mutable_op(*args) 100 101 # NB: This differs from the codegen -- codegen handles cases where there 102 # are mixed FunctionalTensorWrapper and non-FunctionalTensorWrapper. 103 # This only really matters for XLA (mixed CPU-XLA tensors) and 104 # running functionalization without the PT2 stack (which guarantees to us that 105 # all tensors are FunctionalTensorWrapper). 106 if not pytree.tree_all_only(torch.Tensor, torch._is_functional_tensor, args): 107 raise RuntimeError("{mutable_op}: expected all args to be FunctionalTensorWrapper") 108 109 unwrapped_args = [] 110 for arg in args: 111 if isinstance(arg, torch.Tensor) and torch._is_functional_tensor(arg): 112 torch._sync(arg) 113 unwrapped = torch._from_functional_tensor(arg) 114 unwrapped_args.append(unwrapped) 115 else: 116 unwrapped_args.append(arg) 117 118 with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)): 119 output = functional_op(*unwrapped_args) 120 121 num_actual_output = len(mutable_op._schema.returns) 122 actual_output = pytree.tree_map( 123 torch._to_functional_tensor, output[:num_actual_output]) 124 125 new_values_to_propagate = output[num_actual_output:] 126 inputs_to_replace = [arg for is_write, arg in zip(mutable_args(mutable_op), args) 127 if is_write] 128 assert len(new_values_to_propagate) == len(inputs_to_replace) 129 for new_value, arg in zip(new_values_to_propagate, inputs_to_replace): 130 if (arg is None and new_value is None) or (arg is not None and new_value is not None): 131 continue 132 torch._C._propagate_xla_data(arg, new_value) 133 torch._C._replace_(arg, new_value) 134 torch._C._commit_update(arg) 135 torch._sync(arg) 136 137 if len(actual_output) == 1: 138 return actual_output[0] 139 elif len(actual_output) == 0: 140 return None 141 return actual_output 142 143 return kernel 144 145 146def validate(mutable_op: OpOverload): 147 if not isinstance(mutable_op, OpOverload): 148 raise TypeError( 149 f"register_functional_op(mutable_op): expected mutable_op to be instance of " 150 f"OpOverload but got {type(mutable_op)}") 151 152 # There are generally three types of "in-place" or "mutable" ops. 153 # Each of them have their own conventions: 154 # - inplace (first input modified in-place and returned as only output) 155 # - out= (some args modified in-place and returned as outputs) 156 # - mutable (some args modified in-place but none of those returned as outputs) 157 # In theory we can support all three, but we'll just support the last 158 # option right now for simplicity. 159 schema = FunctionSchema.parse(str(mutable_op._schema)) 160 if not schema.kind() == SchemaKind.mutable: 161 raise RuntimeError("Expected op to be mutable (as opposed to functional, inplace or out)") 162 for ret in schema.returns: 163 # construct_functionalization_kernel assumes this for simplicity 164 if ret.annotation is not None: 165 raise NotImplementedError( 166 "NYI: register_functional_op(op) where op returns a mutated or aliased value. " 167 "Please file an issue (and as a workaround, modify your operator to " 168 "not return the mutated value or aliases)") 169 for arg in schema.arguments.flat_all: 170 # construct_functionalization_kernel assumes this for simplicity 171 if arg.type.is_tensor_like() and ( 172 arg.type != BaseType(BaseTy.Tensor) 173 and arg.type != OptionalType(BaseType(BaseTy.Tensor)) 174 ): 175 raise NotImplementedError( 176 "NYI: register_functional_op(op) where op has a List[Tensor] input." 177 "Please file an issue.") 178 179 180def functional_schema(new_op_name, op: OpOverload): 181 schema = FunctionSchema.parse(str(op._schema)) 182 schema = schema.signature().with_name(OperatorName.parse(new_op_name)) 183 return str(schema) 184 185 186def mutable_args(op: OpOverload): 187 return tuple(False if arg.alias_info is None else arg.alias_info.is_write 188 for arg in op._schema.arguments) 189