1*da0073e9SAndroid Build Coastguard Worker# Owner(s): ["module: __torch_dispatch__"] 2*da0073e9SAndroid Build Coastguard Worker 3*da0073e9SAndroid Build Coastguard Workerimport logging 4*da0073e9SAndroid Build Coastguard Workerimport sys 5*da0073e9SAndroid Build Coastguard Workerimport tempfile 6*da0073e9SAndroid Build Coastguard Workerimport unittest 7*da0073e9SAndroid Build Coastguard Workerfrom copy import deepcopy 8*da0073e9SAndroid Build Coastguard Worker 9*da0073e9SAndroid Build Coastguard Workerimport torch 10*da0073e9SAndroid Build Coastguard Workerimport torch._dynamo 11*da0073e9SAndroid Build Coastguard Workerfrom torch import SymInt 12*da0073e9SAndroid Build Coastguard Workerfrom torch._C import DispatchKey, DispatchKeySet 13*da0073e9SAndroid Build Coastguard Workerfrom torch._custom_op.functional import register_functional_op 14*da0073e9SAndroid Build Coastguard Workerfrom torch._subclasses.fake_tensor import FakeTensorMode 15*da0073e9SAndroid Build Coastguard Workerfrom torch.cuda.jiterator import _create_jit_fn 16*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.proxy_tensor import make_fx 17*da0073e9SAndroid Build Coastguard Workerfrom torch.fx.experimental.symbolic_shapes import ShapeEnv 18*da0073e9SAndroid Build Coastguard Workerfrom torch.library import _scoped_library, fallthrough_kernel, impl, Library 19*da0073e9SAndroid Build Coastguard Workerfrom torch.multiprocessing.reductions import StorageWeakRef 20*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_device_type import ( 21*da0073e9SAndroid Build Coastguard Worker instantiate_device_type_tests, 22*da0073e9SAndroid Build Coastguard Worker ops, 23*da0073e9SAndroid Build Coastguard Worker) 24*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_methods_invocations import op_db 25*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.common_utils import ( 26*da0073e9SAndroid Build Coastguard Worker first_sample, 27*da0073e9SAndroid Build Coastguard Worker IS_WINDOWS, 28*da0073e9SAndroid Build Coastguard Worker run_tests, 29*da0073e9SAndroid Build Coastguard Worker TEST_WITH_ROCM, 30*da0073e9SAndroid Build Coastguard Worker TestCase, 31*da0073e9SAndroid Build Coastguard Worker) 32*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.custom_op_db import custom_op_db 33*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.logging_tensor import ( 34*da0073e9SAndroid Build Coastguard Worker capture_logs, 35*da0073e9SAndroid Build Coastguard Worker capture_logs_with_logging_tensor_mode, 36*da0073e9SAndroid Build Coastguard Worker log_input, 37*da0073e9SAndroid Build Coastguard Worker LoggingTensor, 38*da0073e9SAndroid Build Coastguard Worker LoggingTensorMode, 39*da0073e9SAndroid Build Coastguard Worker LoggingTensorReentrant, 40*da0073e9SAndroid Build Coastguard Worker) 41*da0073e9SAndroid Build Coastguard Workerfrom torch.testing._internal.two_tensor import TwoTensor 42*da0073e9SAndroid Build Coastguard Workerfrom torch.utils import _pytree as pytree 43*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._mode_utils import all_same_mode, no_dispatch 44*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._python_dispatch import ( 45*da0073e9SAndroid Build Coastguard Worker _get_current_dispatch_mode, 46*da0073e9SAndroid Build Coastguard Worker _get_current_dispatch_mode_stack, 47*da0073e9SAndroid Build Coastguard Worker is_in_torch_dispatch_mode, 48*da0073e9SAndroid Build Coastguard Worker TorchDispatchMode, 49*da0073e9SAndroid Build Coastguard Worker) 50*da0073e9SAndroid Build Coastguard Workerfrom torch.utils._pytree import tree_map, tree_map_only 51*da0073e9SAndroid Build Coastguard Worker 52*da0073e9SAndroid Build Coastguard Worker 53*da0073e9SAndroid Build Coastguard Worker# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda 54*da0073e9SAndroid Build Coastguard Workerdef _identity(x): 55*da0073e9SAndroid Build Coastguard Worker return x 56*da0073e9SAndroid Build Coastguard Worker 57*da0073e9SAndroid Build Coastguard Worker 58*da0073e9SAndroid Build Coastguard Workerclass TestDispatcherPythonBindings(TestCase): 59*da0073e9SAndroid Build Coastguard Worker def test_call_boxed(self) -> None: 60*da0073e9SAndroid Build Coastguard Worker sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "") 61*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 62*da0073e9SAndroid Build Coastguard Worker y = torch._C._dispatch_call_boxed(sin, x) 63*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y, x.sin()) 64*da0073e9SAndroid Build Coastguard Worker 65*da0073e9SAndroid Build Coastguard Worker 66*da0073e9SAndroid Build Coastguard Workerclass TestPythonRegistration(TestCase): 67*da0073e9SAndroid Build Coastguard Worker test_ns = "_test_python_registration" 68*da0073e9SAndroid Build Coastguard Worker 69*da0073e9SAndroid Build Coastguard Worker def tearDown(self): 70*da0073e9SAndroid Build Coastguard Worker if hasattr(torch.ops, self.test_ns): 71*da0073e9SAndroid Build Coastguard Worker del torch.ops._test_python_registration 72*da0073e9SAndroid Build Coastguard Worker 73*da0073e9SAndroid Build Coastguard Worker def test_fallback(self) -> None: 74*da0073e9SAndroid Build Coastguard Worker test_key = "TESTING_ONLY_GenericMode" 75*da0073e9SAndroid Build Coastguard Worker test_keyset = torch._C.DispatchKeySet(test_key) 76*da0073e9SAndroid Build Coastguard Worker include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset 77*da0073e9SAndroid Build Coastguard Worker exclude_to_set = torch._C._dispatch_tls_local_exclude_set() 78*da0073e9SAndroid Build Coastguard Worker 79*da0073e9SAndroid Build Coastguard Worker with _scoped_library("_", "IMPL") as my_lib: 80*da0073e9SAndroid Build Coastguard Worker expected_op = None 81*da0073e9SAndroid Build Coastguard Worker expected_args = None 82*da0073e9SAndroid Build Coastguard Worker expected_kwargs = None 83*da0073e9SAndroid Build Coastguard Worker # Use this out shape to make sure the result from our fallback 84*da0073e9SAndroid Build Coastguard Worker # is what is returned to the user 85*da0073e9SAndroid Build Coastguard Worker out_shape = None 86*da0073e9SAndroid Build Coastguard Worker 87*da0073e9SAndroid Build Coastguard Worker def my_fallback(op, *args, **kwargs): 88*da0073e9SAndroid Build Coastguard Worker # Disable our handler during checks and generating the output 89*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard( 90*da0073e9SAndroid Build Coastguard Worker include_to_set, exclude_to_set | test_keyset 91*da0073e9SAndroid Build Coastguard Worker ): 92*da0073e9SAndroid Build Coastguard Worker self.assertIs(op, expected_op) 93*da0073e9SAndroid Build Coastguard Worker self.assertEqual(args, expected_args) 94*da0073e9SAndroid Build Coastguard Worker self.assertEqual(kwargs, expected_kwargs) 95*da0073e9SAndroid Build Coastguard Worker # Return something specific 96*da0073e9SAndroid Build Coastguard Worker return torch.empty(out_shape) 97*da0073e9SAndroid Build Coastguard Worker 98*da0073e9SAndroid Build Coastguard Worker my_lib.fallback(my_fallback, test_key) 99*da0073e9SAndroid Build Coastguard Worker 100*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(2), torch.rand(2) 101*da0073e9SAndroid Build Coastguard Worker 102*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): 103*da0073e9SAndroid Build Coastguard Worker # Check a factory function 104*da0073e9SAndroid Build Coastguard Worker expected_op = torch.ops.aten.empty.memory_format 105*da0073e9SAndroid Build Coastguard Worker expected_args = ((2, 2),) 106*da0073e9SAndroid Build Coastguard Worker # Extra kwargs to bypass issues with default args in factory functions 107*da0073e9SAndroid Build Coastguard Worker expected_kwargs = { 108*da0073e9SAndroid Build Coastguard Worker "dtype": torch.float64, 109*da0073e9SAndroid Build Coastguard Worker "pin_memory": False, 110*da0073e9SAndroid Build Coastguard Worker "device": torch.device("cpu"), 111*da0073e9SAndroid Build Coastguard Worker } 112*da0073e9SAndroid Build Coastguard Worker out_shape = (3,) 113*da0073e9SAndroid Build Coastguard Worker out = torch.empty(*expected_args, **expected_kwargs) 114*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.size(), out_shape) 115*da0073e9SAndroid Build Coastguard Worker 116*da0073e9SAndroid Build Coastguard Worker # Check a regular function 117*da0073e9SAndroid Build Coastguard Worker expected_op = torch.ops.aten.add.Tensor 118*da0073e9SAndroid Build Coastguard Worker expected_args = (a, b) 119*da0073e9SAndroid Build Coastguard Worker expected_kwargs = {} 120*da0073e9SAndroid Build Coastguard Worker out_shape = (4,) 121*da0073e9SAndroid Build Coastguard Worker out = a + b 122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out.size(), out_shape) 123*da0073e9SAndroid Build Coastguard Worker 124*da0073e9SAndroid Build Coastguard Worker def test_fallback_keyset(self) -> None: 125*da0073e9SAndroid Build Coastguard Worker test_key_first = "TESTING_ONLY_GenericMode" 126*da0073e9SAndroid Build Coastguard Worker test_key_second = "TESTING_ONLY_GenericWrapper" 127*da0073e9SAndroid Build Coastguard Worker test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet( 128*da0073e9SAndroid Build Coastguard Worker test_key_second 129*da0073e9SAndroid Build Coastguard Worker ) 130*da0073e9SAndroid Build Coastguard Worker include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset 131*da0073e9SAndroid Build Coastguard Worker exclude_to_set = torch._C._dispatch_tls_local_exclude_set() 132*da0073e9SAndroid Build Coastguard Worker 133*da0073e9SAndroid Build Coastguard Worker with _scoped_library("_", "IMPL") as my_lib: 134*da0073e9SAndroid Build Coastguard Worker first_called = False 135*da0073e9SAndroid Build Coastguard Worker second_called = False 136*da0073e9SAndroid Build Coastguard Worker 137*da0073e9SAndroid Build Coastguard Worker def first_fallback(keyset, op, *args, **kwargs): 138*da0073e9SAndroid Build Coastguard Worker nonlocal first_called 139*da0073e9SAndroid Build Coastguard Worker if second_called: 140*da0073e9SAndroid Build Coastguard Worker # Recursive call 141*da0073e9SAndroid Build Coastguard Worker first_called = True 142*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard( 143*da0073e9SAndroid Build Coastguard Worker include_to_set, exclude_to_set | test_keyset 144*da0073e9SAndroid Build Coastguard Worker ): 145*da0073e9SAndroid Build Coastguard Worker return op(*args, **kwargs) 146*da0073e9SAndroid Build Coastguard Worker else: 147*da0073e9SAndroid Build Coastguard Worker # Redispatch down 148*da0073e9SAndroid Build Coastguard Worker keyset = keyset.remove(test_key_first) 149*da0073e9SAndroid Build Coastguard Worker return op.redispatch(keyset, *args, **kwargs) 150*da0073e9SAndroid Build Coastguard Worker 151*da0073e9SAndroid Build Coastguard Worker def second_fallback(op, *args, **kwargs): 152*da0073e9SAndroid Build Coastguard Worker nonlocal second_called 153*da0073e9SAndroid Build Coastguard Worker # Set to avoid infinite recursion 154*da0073e9SAndroid Build Coastguard Worker second_called = True 155*da0073e9SAndroid Build Coastguard Worker # New dispatcher call should hit the first callback again 156*da0073e9SAndroid Build Coastguard Worker self.assertFalse(first_called) 157*da0073e9SAndroid Build Coastguard Worker a, b = args 158*da0073e9SAndroid Build Coastguard Worker # Make a substraction here instead of add ! 159*da0073e9SAndroid Build Coastguard Worker c = a - b 160*da0073e9SAndroid Build Coastguard Worker self.assertTrue(first_called) 161*da0073e9SAndroid Build Coastguard Worker return c 162*da0073e9SAndroid Build Coastguard Worker 163*da0073e9SAndroid Build Coastguard Worker my_lib.fallback(first_fallback, test_key_first, with_keyset=True) 164*da0073e9SAndroid Build Coastguard Worker my_lib.fallback(second_fallback, test_key_second) 165*da0073e9SAndroid Build Coastguard Worker 166*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(2), torch.rand(2) 167*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): 168*da0073e9SAndroid Build Coastguard Worker c = a + b 169*da0073e9SAndroid Build Coastguard Worker 170*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, a - b) 171*da0073e9SAndroid Build Coastguard Worker self.assertTrue(first_called) 172*da0073e9SAndroid Build Coastguard Worker self.assertTrue(second_called) 173*da0073e9SAndroid Build Coastguard Worker 174*da0073e9SAndroid Build Coastguard Worker def test_fallback_fallthrough(self) -> None: 175*da0073e9SAndroid Build Coastguard Worker test_key_first = "TESTING_ONLY_GenericMode" 176*da0073e9SAndroid Build Coastguard Worker test_key_second = "TESTING_ONLY_GenericWrapper" 177*da0073e9SAndroid Build Coastguard Worker test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet( 178*da0073e9SAndroid Build Coastguard Worker test_key_second 179*da0073e9SAndroid Build Coastguard Worker ) 180*da0073e9SAndroid Build Coastguard Worker include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset 181*da0073e9SAndroid Build Coastguard Worker exclude_to_set = torch._C._dispatch_tls_local_exclude_set() 182*da0073e9SAndroid Build Coastguard Worker 183*da0073e9SAndroid Build Coastguard Worker with _scoped_library("_", "IMPL") as my_lib: 184*da0073e9SAndroid Build Coastguard Worker is_called = False 185*da0073e9SAndroid Build Coastguard Worker 186*da0073e9SAndroid Build Coastguard Worker def my_fallback(op, *args, **kwargs): 187*da0073e9SAndroid Build Coastguard Worker nonlocal is_called 188*da0073e9SAndroid Build Coastguard Worker is_called = True 189*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard( 190*da0073e9SAndroid Build Coastguard Worker include_to_set, exclude_to_set | test_keyset 191*da0073e9SAndroid Build Coastguard Worker ): 192*da0073e9SAndroid Build Coastguard Worker return op(*args, **kwargs) 193*da0073e9SAndroid Build Coastguard Worker 194*da0073e9SAndroid Build Coastguard Worker my_lib.fallback(torch.library.fallthrough_kernel, test_key_first) 195*da0073e9SAndroid Build Coastguard Worker my_lib.fallback(my_fallback, test_key_second) 196*da0073e9SAndroid Build Coastguard Worker 197*da0073e9SAndroid Build Coastguard Worker a, b = torch.rand(2), torch.rand(2) 198*da0073e9SAndroid Build Coastguard Worker with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set): 199*da0073e9SAndroid Build Coastguard Worker c = a + b 200*da0073e9SAndroid Build Coastguard Worker 201*da0073e9SAndroid Build Coastguard Worker self.assertEqual(c, a + b) 202*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_called) 203*da0073e9SAndroid Build Coastguard Worker 204*da0073e9SAndroid Build Coastguard Worker def test_override_aten_ops_with_multiple_libraries(self) -> None: 205*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 206*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib2: 207*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib1: 208*da0073e9SAndroid Build Coastguard Worker # Example 1 209*da0073e9SAndroid Build Coastguard Worker def my_neg(*args, **kwargs): 210*da0073e9SAndroid Build Coastguard Worker return args[0]._neg_view() 211*da0073e9SAndroid Build Coastguard Worker 212*da0073e9SAndroid Build Coastguard Worker # Now we are secretly making the operator a view op so autograd needs to know how 213*da0073e9SAndroid Build Coastguard Worker # to handle it 214*da0073e9SAndroid Build Coastguard Worker my_lib1.impl("neg", my_neg, "AutogradCPU") 215*da0073e9SAndroid Build Coastguard Worker 216*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.neg(x).is_neg()) 217*da0073e9SAndroid Build Coastguard Worker 218*da0073e9SAndroid Build Coastguard Worker # RuntimeError: impl("aten::neg", ...): 219*da0073e9SAndroid Build Coastguard Worker # Explicitly provided namespace (aten) in operator name does not match ... 220*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 221*da0073e9SAndroid Build Coastguard Worker RuntimeError, "operator name does not match namespace" 222*da0073e9SAndroid Build Coastguard Worker ): 223*da0073e9SAndroid Build Coastguard Worker with _scoped_library("foo", "DEF") as my_lib3: 224*da0073e9SAndroid Build Coastguard Worker my_lib3.define("neg(Tensor self) -> Tensor") 225*da0073e9SAndroid Build Coastguard Worker my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU") 226*da0073e9SAndroid Build Coastguard Worker 227*da0073e9SAndroid Build Coastguard Worker # Example 2 228*da0073e9SAndroid Build Coastguard Worker def my_mul(*args, **kwargs): 229*da0073e9SAndroid Build Coastguard Worker return torch.zeros_like(args[0]) 230*da0073e9SAndroid Build Coastguard Worker 231*da0073e9SAndroid Build Coastguard Worker # torch.ops.aten.mul.Tensor 232*da0073e9SAndroid Build Coastguard Worker my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor") 233*da0073e9SAndroid Build Coastguard Worker 234*da0073e9SAndroid Build Coastguard Worker y = torch._efficientzerotensor(2) 235*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.mul(x, y)._is_zerotensor()) 236*da0073e9SAndroid Build Coastguard Worker 237*da0073e9SAndroid Build Coastguard Worker # Assert that a user can't override the behavior of a (ns, op, dispatch_key) 238*da0073e9SAndroid Build Coastguard Worker # combination if someone overridden the behavior for the same before them 239*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 240*da0073e9SAndroid Build Coastguard Worker RuntimeError, "already a kernel registered from python" 241*da0073e9SAndroid Build Coastguard Worker ): 242*da0073e9SAndroid Build Coastguard Worker my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor") 243*da0073e9SAndroid Build Coastguard Worker 244*da0073e9SAndroid Build Coastguard Worker # Validate that lib2 is not affected by removing lib1 245*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.mul(x, y)._is_zerotensor()) 246*da0073e9SAndroid Build Coastguard Worker 247*da0073e9SAndroid Build Coastguard Worker # Validate that the old behavior is restored for neg and mul 248*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch.neg(x).is_neg()) 249*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.mul(x, y)._is_zerotensor()) 250*da0073e9SAndroid Build Coastguard Worker 251*da0073e9SAndroid Build Coastguard Worker def test_error_if_fn_not_callable(self): 252*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 253*da0073e9SAndroid Build Coastguard Worker TypeError, "Input function is required to be a callable" 254*da0073e9SAndroid Build Coastguard Worker ): 255*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 256*da0073e9SAndroid Build Coastguard Worker my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU") 257*da0073e9SAndroid Build Coastguard Worker 258*da0073e9SAndroid Build Coastguard Worker def test_finalizer(self): 259*da0073e9SAndroid Build Coastguard Worker impls_refcnt = sys.getrefcount(torch.library._impls) 260*da0073e9SAndroid Build Coastguard Worker lib = Library(self.test_ns, "FRAGMENT") # noqa: TOR901 261*da0073e9SAndroid Build Coastguard Worker lib.define("foo123(Tensor x) -> Tensor") 262*da0073e9SAndroid Build Coastguard Worker 263*da0073e9SAndroid Build Coastguard Worker # 1 for `lib`, 1 for sys.getrefcount 264*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(lib), 2) 265*da0073e9SAndroid Build Coastguard Worker # We gained an additional reference that gets cleared when the finalizer runs 266*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1) 267*da0073e9SAndroid Build Coastguard Worker # 1 for `lib` 268*da0073e9SAndroid Build Coastguard Worker # 1 for the finalizer 269*da0073e9SAndroid Build Coastguard Worker # 1 for sys.getrefcount 270*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(lib._op_impls), 3) 271*da0073e9SAndroid Build Coastguard Worker 272*da0073e9SAndroid Build Coastguard Worker def foo123(x): 273*da0073e9SAndroid Build Coastguard Worker pass 274*da0073e9SAndroid Build Coastguard Worker 275*da0073e9SAndroid Build Coastguard Worker lib.impl(f"{self.test_ns}::foo123", foo123, "CPU") 276*da0073e9SAndroid Build Coastguard Worker key = f"{self.test_ns}/foo123/CPU" 277*da0073e9SAndroid Build Coastguard Worker self.assertTrue(key in torch.library._impls) 278*da0073e9SAndroid Build Coastguard Worker 279*da0073e9SAndroid Build Coastguard Worker saved_op_impls = lib._op_impls 280*da0073e9SAndroid Build Coastguard Worker 281*da0073e9SAndroid Build Coastguard Worker # del will definitely work if the following passes 282*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(lib), 2) 283*da0073e9SAndroid Build Coastguard Worker del lib 284*da0073e9SAndroid Build Coastguard Worker 285*da0073e9SAndroid Build Coastguard Worker # 1 for saved_op_impls 286*da0073e9SAndroid Build Coastguard Worker # 1 for sys.getrefcount 287*da0073e9SAndroid Build Coastguard Worker # This function should be the last user of lib._op_impls: 288*da0073e9SAndroid Build Coastguard Worker # - lib should not have a reference anymore (it was del'ed) 289*da0073e9SAndroid Build Coastguard Worker # - lib's finalizer should not have a reference anymore 290*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(saved_op_impls), 2) 291*da0073e9SAndroid Build Coastguard Worker 292*da0073e9SAndroid Build Coastguard Worker self.assertTrue(key not in torch.library._impls) 293*da0073e9SAndroid Build Coastguard Worker 294*da0073e9SAndroid Build Coastguard Worker # lib's finalizer should not have a reference anymore 295*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt) 296*da0073e9SAndroid Build Coastguard Worker 297*da0073e9SAndroid Build Coastguard Worker def test_override_cpu_sum(self) -> None: 298*da0073e9SAndroid Build Coastguard Worker # Example 1 299*da0073e9SAndroid Build Coastguard Worker run = [False] 300*da0073e9SAndroid Build Coastguard Worker 301*da0073e9SAndroid Build Coastguard Worker def my_sum(*args, **kwargs): 302*da0073e9SAndroid Build Coastguard Worker run[0] = True 303*da0073e9SAndroid Build Coastguard Worker return args[0].clone() 304*da0073e9SAndroid Build Coastguard Worker 305*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib1: 306*da0073e9SAndroid Build Coastguard Worker my_lib1.impl("aten::sum", my_sum, "CPU") 307*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 308*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sum(x), x) 309*da0073e9SAndroid Build Coastguard Worker self.assertTrue(run[0]) 310*da0073e9SAndroid Build Coastguard Worker # Validate that the old behavior is restored for sum 311*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sum(x), torch.tensor(3)) 312*da0073e9SAndroid Build Coastguard Worker 313*da0073e9SAndroid Build Coastguard Worker def test_override_cuda_with_jiterator(self) -> None: 314*da0073e9SAndroid Build Coastguard Worker def override_where_cuda() -> None: 315*da0073e9SAndroid Build Coastguard Worker # Example 1: Invert the behavior of where's condition input 316*da0073e9SAndroid Build Coastguard Worker not_where_code_string = """ 317*da0073e9SAndroid Build Coastguard Worker template <typename T> T inverted_where(bool cond, T a, T b){ 318*da0073e9SAndroid Build Coastguard Worker return !cond ? a : b; 319*da0073e9SAndroid Build Coastguard Worker } 320*da0073e9SAndroid Build Coastguard Worker """ 321*da0073e9SAndroid Build Coastguard Worker jitted_where = _create_jit_fn(not_where_code_string) 322*da0073e9SAndroid Build Coastguard Worker 323*da0073e9SAndroid Build Coastguard Worker CALLED = [False] 324*da0073e9SAndroid Build Coastguard Worker 325*da0073e9SAndroid Build Coastguard Worker def inverted_where(*args, **kwargs): 326*da0073e9SAndroid Build Coastguard Worker CALLED[0] = True 327*da0073e9SAndroid Build Coastguard Worker return jitted_where(*args, **kwargs) 328*da0073e9SAndroid Build Coastguard Worker 329*da0073e9SAndroid Build Coastguard Worker # overriding where's cuda kernel with Jiterator generated kernel 330*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 331*da0073e9SAndroid Build Coastguard Worker my_lib.impl("aten::where.self", inverted_where, "CUDA") 332*da0073e9SAndroid Build Coastguard Worker 333*da0073e9SAndroid Build Coastguard Worker device = "cuda" 334*da0073e9SAndroid Build Coastguard Worker cond = torch.tensor( 335*da0073e9SAndroid Build Coastguard Worker [True, True, False], device=device, dtype=torch.bool 336*da0073e9SAndroid Build Coastguard Worker ) 337*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2, 3], device=device) 338*da0073e9SAndroid Build Coastguard Worker y = torch.tensor([-1, -2, -3], device=device) 339*da0073e9SAndroid Build Coastguard Worker 340*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3])) 341*da0073e9SAndroid Build Coastguard Worker self.assertTrue(CALLED[0]) 342*da0073e9SAndroid Build Coastguard Worker 343*da0073e9SAndroid Build Coastguard Worker # behavior restored after deregistration 344*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3])) 345*da0073e9SAndroid Build Coastguard Worker 346*da0073e9SAndroid Build Coastguard Worker def override_gelu_cuda() -> None: 347*da0073e9SAndroid Build Coastguard Worker # Example 2: Use relu to approximate gelu for faster compute 348*da0073e9SAndroid Build Coastguard Worker fastest_gelu_code_string = """ 349*da0073e9SAndroid Build Coastguard Worker template <typename T> T fast_gelu(T a){ 350*da0073e9SAndroid Build Coastguard Worker return a > 0 ? a : 0; 351*da0073e9SAndroid Build Coastguard Worker } 352*da0073e9SAndroid Build Coastguard Worker """ 353*da0073e9SAndroid Build Coastguard Worker jitted_gelu = _create_jit_fn(fastest_gelu_code_string) 354*da0073e9SAndroid Build Coastguard Worker 355*da0073e9SAndroid Build Coastguard Worker CALLED = [False] 356*da0073e9SAndroid Build Coastguard Worker 357*da0073e9SAndroid Build Coastguard Worker def fast_gelu(*args, **kwargs): 358*da0073e9SAndroid Build Coastguard Worker CALLED[0] = True 359*da0073e9SAndroid Build Coastguard Worker return jitted_gelu(*args, **kwargs) 360*da0073e9SAndroid Build Coastguard Worker 361*da0073e9SAndroid Build Coastguard Worker # overriding gelu's cuda kernel with Jiterator generated relu kernel 362*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 363*da0073e9SAndroid Build Coastguard Worker my_lib.impl("aten::gelu", fast_gelu, "CUDA") 364*da0073e9SAndroid Build Coastguard Worker 365*da0073e9SAndroid Build Coastguard Worker x = torch.rand([3, 3], device="cuda", dtype=torch.float) 366*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 367*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.gelu(x), torch.nn.functional.relu(x) 368*da0073e9SAndroid Build Coastguard Worker ) 369*da0073e9SAndroid Build Coastguard Worker self.assertTrue(CALLED[0]) 370*da0073e9SAndroid Build Coastguard Worker 371*da0073e9SAndroid Build Coastguard Worker # behavior restored after deregistration 372*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual( 373*da0073e9SAndroid Build Coastguard Worker torch.nn.functional.gelu(x), torch.nn.functional.relu(x) 374*da0073e9SAndroid Build Coastguard Worker ) 375*da0073e9SAndroid Build Coastguard Worker 376*da0073e9SAndroid Build Coastguard Worker def override_exp_cuda() -> None: 377*da0073e9SAndroid Build Coastguard Worker # Example 3: Preventing exp from exploding for float16 378*da0073e9SAndroid Build Coastguard Worker clipped_exp_code_string = """ 379*da0073e9SAndroid Build Coastguard Worker template <typename T> T clipped_exp(T a){ 380*da0073e9SAndroid Build Coastguard Worker return a > T(10.0) ? T(22026.4657948) : exp(a); 381*da0073e9SAndroid Build Coastguard Worker } 382*da0073e9SAndroid Build Coastguard Worker """ 383*da0073e9SAndroid Build Coastguard Worker jitted_exp = _create_jit_fn(clipped_exp_code_string) 384*da0073e9SAndroid Build Coastguard Worker 385*da0073e9SAndroid Build Coastguard Worker CALLED = [False] 386*da0073e9SAndroid Build Coastguard Worker 387*da0073e9SAndroid Build Coastguard Worker def clipped_exp(*args, **kwargs): 388*da0073e9SAndroid Build Coastguard Worker CALLED[0] = True 389*da0073e9SAndroid Build Coastguard Worker return jitted_exp(*args, **kwargs) 390*da0073e9SAndroid Build Coastguard Worker 391*da0073e9SAndroid Build Coastguard Worker # overriding exp's cuda kernel with clipped_exp kernel 392*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 393*da0073e9SAndroid Build Coastguard Worker my_lib.impl("aten::exp", clipped_exp, "CUDA") 394*da0073e9SAndroid Build Coastguard Worker 395*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16) 396*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 397*da0073e9SAndroid Build Coastguard Worker torch.exp(x), 398*da0073e9SAndroid Build Coastguard Worker torch.tensor([1.0, 22026.4657948], dtype=torch.float16), 399*da0073e9SAndroid Build Coastguard Worker ) 400*da0073e9SAndroid Build Coastguard Worker self.assertTrue(CALLED[0]) 401*da0073e9SAndroid Build Coastguard Worker 402*da0073e9SAndroid Build Coastguard Worker # behavior restored after deregistration 403*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 404*da0073e9SAndroid Build Coastguard Worker torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16) 405*da0073e9SAndroid Build Coastguard Worker ) 406*da0073e9SAndroid Build Coastguard Worker 407*da0073e9SAndroid Build Coastguard Worker def override_add_cuda() -> None: 408*da0073e9SAndroid Build Coastguard Worker # Example 4: simulate a hardware bug, where the adder is always off by 1 409*da0073e9SAndroid Build Coastguard Worker buggy_add_code_string = """ 410*da0073e9SAndroid Build Coastguard Worker template <typename T> T buggy_add(T a, T b){ 411*da0073e9SAndroid Build Coastguard Worker return a + b + T(1); 412*da0073e9SAndroid Build Coastguard Worker } 413*da0073e9SAndroid Build Coastguard Worker """ 414*da0073e9SAndroid Build Coastguard Worker jitted_add = _create_jit_fn(buggy_add_code_string) 415*da0073e9SAndroid Build Coastguard Worker 416*da0073e9SAndroid Build Coastguard Worker CALLED = [False] 417*da0073e9SAndroid Build Coastguard Worker 418*da0073e9SAndroid Build Coastguard Worker def buggy_add(*args, **kwargs): 419*da0073e9SAndroid Build Coastguard Worker CALLED[0] = True 420*da0073e9SAndroid Build Coastguard Worker return jitted_add(*args, **kwargs) 421*da0073e9SAndroid Build Coastguard Worker 422*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 423*da0073e9SAndroid Build Coastguard Worker my_lib.impl("aten::add.Tensor", buggy_add, "CUDA") 424*da0073e9SAndroid Build Coastguard Worker 425*da0073e9SAndroid Build Coastguard Worker x_cpu = torch.rand([3, 3], device="cpu") 426*da0073e9SAndroid Build Coastguard Worker y_cpu = torch.rand([3], device="cpu") 427*da0073e9SAndroid Build Coastguard Worker 428*da0073e9SAndroid Build Coastguard Worker x_cuda = x_cpu.cuda() 429*da0073e9SAndroid Build Coastguard Worker y_cuda = y_cpu.cuda() 430*da0073e9SAndroid Build Coastguard Worker 431*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1) 432*da0073e9SAndroid Build Coastguard Worker self.assertTrue(CALLED[0]) 433*da0073e9SAndroid Build Coastguard Worker 434*da0073e9SAndroid Build Coastguard Worker # behavior restored after deregistration 435*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu) 436*da0073e9SAndroid Build Coastguard Worker 437*da0073e9SAndroid Build Coastguard Worker if torch.cuda.is_available() and not TEST_WITH_ROCM: 438*da0073e9SAndroid Build Coastguard Worker override_where_cuda() 439*da0073e9SAndroid Build Coastguard Worker override_gelu_cuda() 440*da0073e9SAndroid Build Coastguard Worker override_exp_cuda() 441*da0073e9SAndroid Build Coastguard Worker override_add_cuda() 442*da0073e9SAndroid Build Coastguard Worker 443*da0073e9SAndroid Build Coastguard Worker def test_extend_library_with_dispatch_key_arg(self): 444*da0073e9SAndroid Build Coastguard Worker def my_sum(*args, **kwargs): 445*da0073e9SAndroid Build Coastguard Worker return args[0].clone() 446*da0073e9SAndroid Build Coastguard Worker 447*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1: 448*da0073e9SAndroid Build Coastguard Worker # RuntimeError: Explicitly provided dispatch key (Conjugate) is 449*da0073e9SAndroid Build Coastguard Worker # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block 450*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 451*da0073e9SAndroid Build Coastguard Worker RuntimeError, "inconsistent with the dispatch key" 452*da0073e9SAndroid Build Coastguard Worker ): 453*da0073e9SAndroid Build Coastguard Worker my_lib1.impl("sum", my_sum, "Conjugate") 454*da0073e9SAndroid Build Coastguard Worker my_lib1.impl("aten::sum", my_sum) 455*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 456*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.sum(x), x) 457*da0073e9SAndroid Build Coastguard Worker 458*da0073e9SAndroid Build Coastguard Worker def test_create_new_library(self) -> None: 459*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "DEF") as my_lib1: 460*da0073e9SAndroid Build Coastguard Worker my_lib1.define("sum(Tensor self) -> Tensor") 461*da0073e9SAndroid Build Coastguard Worker 462*da0073e9SAndroid Build Coastguard Worker # Example 1 463*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(my_lib1, "sum", "CPU") 464*da0073e9SAndroid Build Coastguard Worker def my_sum(*args, **kwargs): 465*da0073e9SAndroid Build Coastguard Worker return args[0].clone() 466*da0073e9SAndroid Build Coastguard Worker 467*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 468*da0073e9SAndroid Build Coastguard Worker op = getattr(torch.ops, self.test_ns).sum 469*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(x), x) 470*da0073e9SAndroid Build Coastguard Worker 471*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "IMPL") as my_lib2: 472*da0073e9SAndroid Build Coastguard Worker # Example 2 473*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(my_lib2, op.default, "ZeroTensor") 474*da0073e9SAndroid Build Coastguard Worker def my_sum_zt(*args, **kwargs): 475*da0073e9SAndroid Build Coastguard Worker if args[0]._is_zerotensor(): 476*da0073e9SAndroid Build Coastguard Worker return torch._efficientzerotensor(args[0].shape) 477*da0073e9SAndroid Build Coastguard Worker else: 478*da0073e9SAndroid Build Coastguard Worker return args[0].clone() 479*da0073e9SAndroid Build Coastguard Worker 480*da0073e9SAndroid Build Coastguard Worker y = torch._efficientzerotensor(3) 481*da0073e9SAndroid Build Coastguard Worker self.assertTrue(op(y)._is_zerotensor()) 482*da0073e9SAndroid Build Coastguard Worker self.assertEqual(op(x), x) 483*da0073e9SAndroid Build Coastguard Worker 484*da0073e9SAndroid Build Coastguard Worker def test_create_new_library_fragment_no_existing(self): 485*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as my_lib: 486*da0073e9SAndroid Build Coastguard Worker my_lib.define("sum2(Tensor self) -> Tensor") 487*da0073e9SAndroid Build Coastguard Worker 488*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(my_lib, "sum2", "CPU") 489*da0073e9SAndroid Build Coastguard Worker def my_sum(*args, **kwargs): 490*da0073e9SAndroid Build Coastguard Worker return args[0] 491*da0073e9SAndroid Build Coastguard Worker 492*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 493*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x) 494*da0073e9SAndroid Build Coastguard Worker 495*da0073e9SAndroid Build Coastguard Worker def test_create_new_library_fragment_with_existing(self): 496*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "DEF") as my_lib1: 497*da0073e9SAndroid Build Coastguard Worker # Create a fragment 498*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2: 499*da0073e9SAndroid Build Coastguard Worker my_lib2.define("sum4(Tensor self) -> Tensor") 500*da0073e9SAndroid Build Coastguard Worker 501*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(my_lib2, "sum4", "CPU") 502*da0073e9SAndroid Build Coastguard Worker def my_sum4(*args, **kwargs): 503*da0073e9SAndroid Build Coastguard Worker return args[0] 504*da0073e9SAndroid Build Coastguard Worker 505*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 506*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x) 507*da0073e9SAndroid Build Coastguard Worker 508*da0073e9SAndroid Build Coastguard Worker # Create another fragment 509*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3: 510*da0073e9SAndroid Build Coastguard Worker my_lib3.define("sum3(Tensor self) -> Tensor") 511*da0073e9SAndroid Build Coastguard Worker 512*da0073e9SAndroid Build Coastguard Worker @torch.library.impl(my_lib3, "sum3", "CPU") 513*da0073e9SAndroid Build Coastguard Worker def my_sum3(*args, **kwargs): 514*da0073e9SAndroid Build Coastguard Worker return args[0] 515*da0073e9SAndroid Build Coastguard Worker 516*da0073e9SAndroid Build Coastguard Worker x = torch.tensor([1, 2]) 517*da0073e9SAndroid Build Coastguard Worker self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x) 518*da0073e9SAndroid Build Coastguard Worker 519*da0073e9SAndroid Build Coastguard Worker @unittest.skipIf(IS_WINDOWS, "Skipped under Windows") 520*da0073e9SAndroid Build Coastguard Worker def test_alias_analysis(self): 521*da0073e9SAndroid Build Coastguard Worker def test_helper(alias_analysis=""): 522*da0073e9SAndroid Build Coastguard Worker my_lib1 = Library(self.test_ns, "DEF") # noqa: TOR901 523*da0073e9SAndroid Build Coastguard Worker 524*da0073e9SAndroid Build Coastguard Worker called = [0] 525*da0073e9SAndroid Build Coastguard Worker 526*da0073e9SAndroid Build Coastguard Worker @torch.library.define( 527*da0073e9SAndroid Build Coastguard Worker my_lib1, "_op() -> None", alias_analysis=alias_analysis 528*da0073e9SAndroid Build Coastguard Worker ) 529*da0073e9SAndroid Build Coastguard Worker def _op(*args, **kwargs): 530*da0073e9SAndroid Build Coastguard Worker called[0] += 1 531*da0073e9SAndroid Build Coastguard Worker 532*da0073e9SAndroid Build Coastguard Worker @torch.jit.script 533*da0073e9SAndroid Build Coastguard Worker def _test(): 534*da0073e9SAndroid Build Coastguard Worker torch.ops._test_python_registration._op() 535*da0073e9SAndroid Build Coastguard Worker 536*da0073e9SAndroid Build Coastguard Worker assert "_test_python_registration::_op" in str(_test.graph) 537*da0073e9SAndroid Build Coastguard Worker 538*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 539*da0073e9SAndroid Build Coastguard Worker test_helper("") # alias_analysis="FROM_SCHEMA" 540*da0073e9SAndroid Build Coastguard Worker 541*da0073e9SAndroid Build Coastguard Worker test_helper("CONSERVATIVE") 542*da0073e9SAndroid Build Coastguard Worker 543*da0073e9SAndroid Build Coastguard Worker def test_error_for_unsupported_ns_or_kind(self) -> None: 544*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "Unsupported kind"): 545*da0073e9SAndroid Build Coastguard Worker my_lib1 = Library("myns", "BLA") # noqa: TOR901 546*da0073e9SAndroid Build Coastguard Worker 547*da0073e9SAndroid Build Coastguard Worker for kind in ("DEF", "FRAGMENT"): 548*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ValueError, "reserved namespace"): 549*da0073e9SAndroid Build Coastguard Worker my_lib1 = Library("prim", kind) # noqa: TOR901 550*da0073e9SAndroid Build Coastguard Worker 551*da0073e9SAndroid Build Coastguard Worker def test_returning_symint(self) -> None: 552*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 553*da0073e9SAndroid Build Coastguard Worker fake_tensor_mode = FakeTensorMode(shape_env=shape_env) 554*da0073e9SAndroid Build Coastguard Worker 555*da0073e9SAndroid Build Coastguard Worker ft = fake_tensor_mode.from_tensor(torch.rand(2, 3)) 556*da0073e9SAndroid Build Coastguard Worker 557*da0073e9SAndroid Build Coastguard Worker s0, s1 = ft.shape 558*da0073e9SAndroid Build Coastguard Worker 559*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "DEF") as tlib: 560*da0073e9SAndroid Build Coastguard Worker tlib.define("sqsum(SymInt a, SymInt b) -> SymInt") 561*da0073e9SAndroid Build Coastguard Worker 562*da0073e9SAndroid Build Coastguard Worker @impl(tlib, "sqsum", "CompositeExplicitAutograd") 563*da0073e9SAndroid Build Coastguard Worker def sqsum(a: SymInt, b: SymInt): 564*da0073e9SAndroid Build Coastguard Worker return a * a + b * b 565*da0073e9SAndroid Build Coastguard Worker 566*da0073e9SAndroid Build Coastguard Worker out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1) 567*da0073e9SAndroid Build Coastguard Worker out_val = shape_env.evaluate_expr(out.node.expr) 568*da0073e9SAndroid Build Coastguard Worker self.assertEqual(out_val, 13) 569*da0073e9SAndroid Build Coastguard Worker 570*da0073e9SAndroid Build Coastguard Worker def test_register_functional_op_error_cases(self): 571*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 572*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "instance of OpOverload"): 573*da0073e9SAndroid Build Coastguard Worker register_functional_op(lib, "abs", torch.ops.aten.abs_) 574*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"): 575*da0073e9SAndroid Build Coastguard Worker register_functional_op(lib, "abs", torch.ops.aten.abs_.default) 576*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"): 577*da0073e9SAndroid Build Coastguard Worker register_functional_op(lib, "abs", torch.ops.aten.abs.out) 578*da0073e9SAndroid Build Coastguard Worker 579*da0073e9SAndroid Build Coastguard Worker schemas = [ 580*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!)[] y) -> ()", 581*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)", 582*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))", 583*da0073e9SAndroid Build Coastguard Worker ] 584*da0073e9SAndroid Build Coastguard Worker 585*da0073e9SAndroid Build Coastguard Worker for schema in schemas: 586*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 587*da0073e9SAndroid Build Coastguard Worker lib.define(schema) 588*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "NYI"): 589*da0073e9SAndroid Build Coastguard Worker register_functional_op( 590*da0073e9SAndroid Build Coastguard Worker lib, 591*da0073e9SAndroid Build Coastguard Worker "foo_functional", 592*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 593*da0073e9SAndroid Build Coastguard Worker ) 594*da0073e9SAndroid Build Coastguard Worker 595*da0073e9SAndroid Build Coastguard Worker def _check_is_functional_variant(self, mutable_op, functional_op, args): 596*da0073e9SAndroid Build Coastguard Worker # functional op should not mutate 597*da0073e9SAndroid Build Coastguard Worker cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 598*da0073e9SAndroid Build Coastguard Worker functional_result = functional_op(*cloned_args) 599*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cloned_args, args) 600*da0073e9SAndroid Build Coastguard Worker 601*da0073e9SAndroid Build Coastguard Worker # check functional_result includes mutable_result 602*da0073e9SAndroid Build Coastguard Worker mutable_result = mutable_op(*cloned_args) 603*da0073e9SAndroid Build Coastguard Worker if mutable_result is None: 604*da0073e9SAndroid Build Coastguard Worker flat_mutable_result = [] 605*da0073e9SAndroid Build Coastguard Worker else: 606*da0073e9SAndroid Build Coastguard Worker flat_mutable_result = pytree.tree_leaves(mutable_result) 607*da0073e9SAndroid Build Coastguard Worker flat_functional_result = pytree.tree_leaves(functional_result) 608*da0073e9SAndroid Build Coastguard Worker assert len(flat_functional_result) > len(flat_mutable_result) 609*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 610*da0073e9SAndroid Build Coastguard Worker flat_functional_result[: len(flat_mutable_result)], flat_mutable_result 611*da0073e9SAndroid Build Coastguard Worker ) 612*da0073e9SAndroid Build Coastguard Worker 613*da0073e9SAndroid Build Coastguard Worker # check rest of functional_result is the mutated args 614*da0073e9SAndroid Build Coastguard Worker mutated_args = [ 615*da0073e9SAndroid Build Coastguard Worker maybe_mutated_arg 616*da0073e9SAndroid Build Coastguard Worker for maybe_mutated_arg, arg in zip(cloned_args, args) 617*da0073e9SAndroid Build Coastguard Worker if not ( 618*da0073e9SAndroid Build Coastguard Worker maybe_mutated_arg is not None 619*da0073e9SAndroid Build Coastguard Worker and arg is not None 620*da0073e9SAndroid Build Coastguard Worker and torch.allclose(maybe_mutated_arg, arg) 621*da0073e9SAndroid Build Coastguard Worker ) 622*da0073e9SAndroid Build Coastguard Worker ] 623*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 624*da0073e9SAndroid Build Coastguard Worker flat_functional_result[len(flat_mutable_result) :], mutated_args 625*da0073e9SAndroid Build Coastguard Worker ) 626*da0073e9SAndroid Build Coastguard Worker 627*da0073e9SAndroid Build Coastguard Worker # check that functionalization kernel was indeed registered 628*da0073e9SAndroid Build Coastguard Worker def fn(*args): 629*da0073e9SAndroid Build Coastguard Worker cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args) 630*da0073e9SAndroid Build Coastguard Worker mutable_op(*cloned_args) 631*da0073e9SAndroid Build Coastguard Worker return cloned_args 632*da0073e9SAndroid Build Coastguard Worker 633*da0073e9SAndroid Build Coastguard Worker gm = make_fx(torch.func.functionalize(fn))(*args) 634*da0073e9SAndroid Build Coastguard Worker has_functional_op = False 635*da0073e9SAndroid Build Coastguard Worker for node in gm.graph.nodes: 636*da0073e9SAndroid Build Coastguard Worker self.assertFalse(node.target is mutable_op) 637*da0073e9SAndroid Build Coastguard Worker if node.target is functional_op: 638*da0073e9SAndroid Build Coastguard Worker has_functional_op = True 639*da0073e9SAndroid Build Coastguard Worker self.assertTrue(has_functional_op) 640*da0073e9SAndroid Build Coastguard Worker 641*da0073e9SAndroid Build Coastguard Worker def test_register_functional_op_no_returns(self): 642*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 643*da0073e9SAndroid Build Coastguard Worker lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()") 644*da0073e9SAndroid Build Coastguard Worker 645*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w): 646*da0073e9SAndroid Build Coastguard Worker y.fill_(3.14) 647*da0073e9SAndroid Build Coastguard Worker w.fill_(2.71) 648*da0073e9SAndroid Build Coastguard Worker 649*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 650*da0073e9SAndroid Build Coastguard Worker register_functional_op( 651*da0073e9SAndroid Build Coastguard Worker lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default 652*da0073e9SAndroid Build Coastguard Worker ) 653*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 654*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 655*da0073e9SAndroid Build Coastguard Worker z = torch.randn([]) 656*da0073e9SAndroid Build Coastguard Worker w = torch.randn([]) 657*da0073e9SAndroid Build Coastguard Worker self._check_is_functional_variant( 658*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 659*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo_functional.default, 660*da0073e9SAndroid Build Coastguard Worker (x, y, z, w), 661*da0073e9SAndroid Build Coastguard Worker ) 662*da0073e9SAndroid Build Coastguard Worker 663*da0073e9SAndroid Build Coastguard Worker def test_register_functional_op_with_optional(self): 664*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 665*da0073e9SAndroid Build Coastguard Worker lib.define( 666*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!) y, Tensor (b!) z, Tensor(c!)? w) -> ()" 667*da0073e9SAndroid Build Coastguard Worker ) 668*da0073e9SAndroid Build Coastguard Worker 669*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w): 670*da0073e9SAndroid Build Coastguard Worker y.fill_(3.14) 671*da0073e9SAndroid Build Coastguard Worker z.fill_(2.71) 672*da0073e9SAndroid Build Coastguard Worker if w is not None: 673*da0073e9SAndroid Build Coastguard Worker w.fill_(1.618) 674*da0073e9SAndroid Build Coastguard Worker 675*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 676*da0073e9SAndroid Build Coastguard Worker register_functional_op( 677*da0073e9SAndroid Build Coastguard Worker lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default 678*da0073e9SAndroid Build Coastguard Worker ) 679*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 680*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 681*da0073e9SAndroid Build Coastguard Worker z = torch.randn([]) 682*da0073e9SAndroid Build Coastguard Worker w = torch.randn([]) 683*da0073e9SAndroid Build Coastguard Worker self._check_is_functional_variant( 684*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 685*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo_functional.default, 686*da0073e9SAndroid Build Coastguard Worker (x, y, z, w), 687*da0073e9SAndroid Build Coastguard Worker ) 688*da0073e9SAndroid Build Coastguard Worker self._check_is_functional_variant( 689*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 690*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo_functional.default, 691*da0073e9SAndroid Build Coastguard Worker (x, y, z, None), 692*da0073e9SAndroid Build Coastguard Worker ) 693*da0073e9SAndroid Build Coastguard Worker 694*da0073e9SAndroid Build Coastguard Worker def test_register_functional_op_one_return(self): 695*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 696*da0073e9SAndroid Build Coastguard Worker lib.define( 697*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor" 698*da0073e9SAndroid Build Coastguard Worker ) 699*da0073e9SAndroid Build Coastguard Worker 700*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w): 701*da0073e9SAndroid Build Coastguard Worker y.fill_(3.14) 702*da0073e9SAndroid Build Coastguard Worker w.fill_(2.71) 703*da0073e9SAndroid Build Coastguard Worker z.fill_(0.99) 704*da0073e9SAndroid Build Coastguard Worker return x.clone() 705*da0073e9SAndroid Build Coastguard Worker 706*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 707*da0073e9SAndroid Build Coastguard Worker register_functional_op( 708*da0073e9SAndroid Build Coastguard Worker lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default 709*da0073e9SAndroid Build Coastguard Worker ) 710*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 711*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 712*da0073e9SAndroid Build Coastguard Worker z = torch.randn([]) 713*da0073e9SAndroid Build Coastguard Worker w = torch.randn([]) 714*da0073e9SAndroid Build Coastguard Worker self._check_is_functional_variant( 715*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 716*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo_functional.default, 717*da0073e9SAndroid Build Coastguard Worker (x, y, z, w), 718*da0073e9SAndroid Build Coastguard Worker ) 719*da0073e9SAndroid Build Coastguard Worker 720*da0073e9SAndroid Build Coastguard Worker def test_register_functional_op_multiple_returns(self): 721*da0073e9SAndroid Build Coastguard Worker with _scoped_library(self.test_ns, "FRAGMENT") as lib: 722*da0073e9SAndroid Build Coastguard Worker lib.define( 723*da0073e9SAndroid Build Coastguard Worker "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)" 724*da0073e9SAndroid Build Coastguard Worker ) 725*da0073e9SAndroid Build Coastguard Worker 726*da0073e9SAndroid Build Coastguard Worker def foo_impl(x, y, z, w): 727*da0073e9SAndroid Build Coastguard Worker y.fill_(3.14) 728*da0073e9SAndroid Build Coastguard Worker w.fill_(2.71) 729*da0073e9SAndroid Build Coastguard Worker return x.clone(), z.clone() 730*da0073e9SAndroid Build Coastguard Worker 731*da0073e9SAndroid Build Coastguard Worker lib.impl("foo", foo_impl, "CPU") 732*da0073e9SAndroid Build Coastguard Worker register_functional_op( 733*da0073e9SAndroid Build Coastguard Worker lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default 734*da0073e9SAndroid Build Coastguard Worker ) 735*da0073e9SAndroid Build Coastguard Worker 736*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 737*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 738*da0073e9SAndroid Build Coastguard Worker z = torch.randn([]) 739*da0073e9SAndroid Build Coastguard Worker w = torch.randn([]) 740*da0073e9SAndroid Build Coastguard Worker self._check_is_functional_variant( 741*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo.default, 742*da0073e9SAndroid Build Coastguard Worker getattr(torch.ops, self.test_ns).foo_functional.default, 743*da0073e9SAndroid Build Coastguard Worker (x, y, z, w), 744*da0073e9SAndroid Build Coastguard Worker ) 745*da0073e9SAndroid Build Coastguard Worker 746*da0073e9SAndroid Build Coastguard Worker def test_register_fallthrough(self): 747*da0073e9SAndroid Build Coastguard Worker with _scoped_library("aten", "IMPL") as my_lib: 748*da0073e9SAndroid Build Coastguard Worker my_lib.impl("mm", fallthrough_kernel, "AutocastCPU") 749*da0073e9SAndroid Build Coastguard Worker 750*da0073e9SAndroid Build Coastguard Worker a = torch.randn(2, 3, device="cpu", dtype=torch.float32) 751*da0073e9SAndroid Build Coastguard Worker b = torch.randn(3, 2, device="cpu", dtype=torch.float32) 752*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type="cpu", dtype=torch.bfloat16): 753*da0073e9SAndroid Build Coastguard Worker # dtype for mm should be float32 since we registered a fallthrough 754*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(a, b).dtype, torch.float32) 755*da0073e9SAndroid Build Coastguard Worker # ops that don't have a fallthrough registered should not be affected 756*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16) 757*da0073e9SAndroid Build Coastguard Worker 758*da0073e9SAndroid Build Coastguard Worker with torch.autocast(device_type="cpu", dtype=torch.bfloat16): 759*da0073e9SAndroid Build Coastguard Worker # default behavior should have been restored 760*da0073e9SAndroid Build Coastguard Worker self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16) 761*da0073e9SAndroid Build Coastguard Worker 762*da0073e9SAndroid Build Coastguard Worker 763*da0073e9SAndroid Build Coastguard Workerclass TestPythonDispatch(TestCase): 764*da0073e9SAndroid Build Coastguard Worker def test_basic(self) -> None: 765*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 766*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.tensor([3.0]), requires_grad=True) 767*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 768*da0073e9SAndroid Build Coastguard Worker y = x * x 769*da0073e9SAndroid Build Coastguard Worker saved_x = y.grad_fn._saved_self 770*da0073e9SAndroid Build Coastguard Worker grad_y = LoggingTensor(torch.tensor([1.0])) 771*da0073e9SAndroid Build Coastguard Worker log_input("grad_y", grad_y) 772*da0073e9SAndroid Build Coastguard Worker (g,) = torch.autograd.grad((y,), (x,), (grad_y,)) 773*da0073e9SAndroid Build Coastguard Worker 774*da0073e9SAndroid Build Coastguard Worker self.assertEqual(g.elem, torch.tensor([6.0])) 775*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 776*da0073e9SAndroid Build Coastguard Worker self.assertEqual(saved_x, x) 777*da0073e9SAndroid Build Coastguard Worker self.assertEqual(saved_x._version, x._version) 778*da0073e9SAndroid Build Coastguard Worker x.add_(2) 779*da0073e9SAndroid Build Coastguard Worker self.assertEqual(saved_x, x) 780*da0073e9SAndroid Build Coastguard Worker # TODO: figure out why broken 781*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(saved_x._version, x._version) 782*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 783*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 784*da0073e9SAndroid Build Coastguard Worker """\ 785*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 786*da0073e9SAndroid Build Coastguard Worker$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0) 787*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = input('grad_y') 788*da0073e9SAndroid Build Coastguard Worker$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0) 789*da0073e9SAndroid Build Coastguard Worker$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0) 790*da0073e9SAndroid Build Coastguard Worker$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""", 791*da0073e9SAndroid Build Coastguard Worker ) 792*da0073e9SAndroid Build Coastguard Worker 793*da0073e9SAndroid Build Coastguard Worker def test_out(self) -> None: 794*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 795*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 796*da0073e9SAndroid Build Coastguard Worker y = LoggingTensor(torch.zeros(1)) 797*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 798*da0073e9SAndroid Build Coastguard Worker log_input("y", y) 799*da0073e9SAndroid Build Coastguard Worker torch.abs(x, out=y) 800*da0073e9SAndroid Build Coastguard Worker 801*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.elem, torch.ones(1)) 802*da0073e9SAndroid Build Coastguard Worker # TODO: arguably this shouldn't pass and we should complain 803*da0073e9SAndroid Build Coastguard Worker # that out isn't a kwarg 804*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 805*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 806*da0073e9SAndroid Build Coastguard Worker """\ 807*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 808*da0073e9SAndroid Build Coastguard Worker$1: f32[1] = input('y') 809*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""", 810*da0073e9SAndroid Build Coastguard Worker ) 811*da0073e9SAndroid Build Coastguard Worker 812*da0073e9SAndroid Build Coastguard Worker def test_kwarg_only(self) -> None: 813*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 814*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 815*da0073e9SAndroid Build Coastguard Worker y = LoggingTensor(torch.ones(1, 1)) 816*da0073e9SAndroid Build Coastguard Worker z = LoggingTensor(torch.ones(1)) 817*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 818*da0073e9SAndroid Build Coastguard Worker log_input("y", y) 819*da0073e9SAndroid Build Coastguard Worker log_input("z", z) 820*da0073e9SAndroid Build Coastguard Worker torch.addmv(x, y, z) 821*da0073e9SAndroid Build Coastguard Worker torch.addmv(x, y, z, beta=1) 822*da0073e9SAndroid Build Coastguard Worker torch.addmv(x, y, z, beta=2) 823*da0073e9SAndroid Build Coastguard Worker torch.addmv(x, y, z, alpha=2) 824*da0073e9SAndroid Build Coastguard Worker torch.addmv(x, y, z, beta=2, alpha=2) 825*da0073e9SAndroid Build Coastguard Worker 826*da0073e9SAndroid Build Coastguard Worker # The expectation is that beta/alpha don't show up when they're 827*da0073e9SAndroid Build Coastguard Worker # defaulted. This is even if the user explicitly specified it. 828*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 829*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 830*da0073e9SAndroid Build Coastguard Worker """\ 831*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 832*da0073e9SAndroid Build Coastguard Worker$1: f32[1, 1] = input('y') 833*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = input('z') 834*da0073e9SAndroid Build Coastguard Worker$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2) 835*da0073e9SAndroid Build Coastguard Worker$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2) 836*da0073e9SAndroid Build Coastguard Worker$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2) 837*da0073e9SAndroid Build Coastguard Worker$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2) 838*da0073e9SAndroid Build Coastguard Worker$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""", 839*da0073e9SAndroid Build Coastguard Worker ) 840*da0073e9SAndroid Build Coastguard Worker 841*da0073e9SAndroid Build Coastguard Worker def test_kwarg_only_and_positional_default(self) -> None: 842*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 843*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 844*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 845*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._foobar(x) 846*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._foobar(x, False) 847*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._foobar(x, arg3=False) 848*da0073e9SAndroid Build Coastguard Worker torch.ops.aten._foobar(x, False, arg3=False) 849*da0073e9SAndroid Build Coastguard Worker 850*da0073e9SAndroid Build Coastguard Worker # What we are testing here is that we omit arg2 851*da0073e9SAndroid Build Coastguard Worker # if it is defaulted, even if a kwarg is set 852*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 853*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 854*da0073e9SAndroid Build Coastguard Worker """\ 855*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 856*da0073e9SAndroid Build Coastguard Worker$1: f32[1] = torch._ops.aten._foobar.default($0) 857*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = torch._ops.aten._foobar.default($0, False) 858*da0073e9SAndroid Build Coastguard Worker$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False) 859*da0073e9SAndroid Build Coastguard Worker$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""", 860*da0073e9SAndroid Build Coastguard Worker ) 861*da0073e9SAndroid Build Coastguard Worker 862*da0073e9SAndroid Build Coastguard Worker def test_produce_real_type(self) -> None: 863*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 864*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(2, 2)) 865*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 866*da0073e9SAndroid Build Coastguard Worker x.to(dtype=torch.double) # non-optional dtype 867*da0073e9SAndroid Build Coastguard Worker torch.cumprod(x, 0, dtype=torch.double) # optional dtype 868*da0073e9SAndroid Build Coastguard Worker x[:, 1].contiguous( 869*da0073e9SAndroid Build Coastguard Worker memory_format=torch.contiguous_format 870*da0073e9SAndroid Build Coastguard Worker ) # optional memory format 871*da0073e9SAndroid Build Coastguard Worker # There doesn't appear to be any layout signatures which are 872*da0073e9SAndroid Build Coastguard Worker # triggerable using tensor subclasses (need to use a mode) 873*da0073e9SAndroid Build Coastguard Worker 874*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 875*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 876*da0073e9SAndroid Build Coastguard Worker """\ 877*da0073e9SAndroid Build Coastguard Worker$0: f32[2, 2] = input('x') 878*da0073e9SAndroid Build Coastguard Worker$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64) 879*da0073e9SAndroid Build Coastguard Worker$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64) 880*da0073e9SAndroid Build Coastguard Worker$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807) 881*da0073e9SAndroid Build Coastguard Worker$4: f32[2] = torch._ops.aten.select.int($3, 1, 1) 882*da0073e9SAndroid Build Coastguard Worker$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""", 883*da0073e9SAndroid Build Coastguard Worker ) 884*da0073e9SAndroid Build Coastguard Worker 885*da0073e9SAndroid Build Coastguard Worker def test_optional_tensor_list(self) -> None: 886*da0073e9SAndroid Build Coastguard Worker def weird(xs): 887*da0073e9SAndroid Build Coastguard Worker print("woof") 888*da0073e9SAndroid Build Coastguard Worker return torch.empty(()) 889*da0073e9SAndroid Build Coastguard Worker 890*da0073e9SAndroid Build Coastguard Worker with _scoped_library("my_lib", "DEF") as my_lib: 891*da0073e9SAndroid Build Coastguard Worker my_lib.define("weird(Tensor?[] self) -> Tensor") 892*da0073e9SAndroid Build Coastguard Worker my_lib.impl("weird", weird, "CPU") 893*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 894*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(2, 2)) 895*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 896*da0073e9SAndroid Build Coastguard Worker torch.ops.my_lib.weird.default([None, x]) 897*da0073e9SAndroid Build Coastguard Worker 898*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 899*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 900*da0073e9SAndroid Build Coastguard Worker """\ 901*da0073e9SAndroid Build Coastguard Worker$0: f32[2, 2] = input('x') 902*da0073e9SAndroid Build Coastguard Worker$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""", 903*da0073e9SAndroid Build Coastguard Worker ) 904*da0073e9SAndroid Build Coastguard Worker 905*da0073e9SAndroid Build Coastguard Worker def test_list_ret(self) -> None: 906*da0073e9SAndroid Build Coastguard Worker # test all sequence types are permissible returns 907*da0073e9SAndroid Build Coastguard Worker for list_type in (list, tuple): 908*da0073e9SAndroid Build Coastguard Worker 909*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 910*da0073e9SAndroid Build Coastguard Worker @staticmethod 911*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 912*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 913*da0073e9SAndroid Build Coastguard Worker 914*da0073e9SAndroid Build Coastguard Worker @classmethod 915*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 916*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.split: 917*da0073e9SAndroid Build Coastguard Worker with no_dispatch(): 918*da0073e9SAndroid Build Coastguard Worker return list_type(torch.split(*args)) 919*da0073e9SAndroid Build Coastguard Worker else: 920*da0073e9SAndroid Build Coastguard Worker raise AssertionError(f"unrecognized func: {func}") 921*da0073e9SAndroid Build Coastguard Worker 922*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 923*da0073e9SAndroid Build Coastguard Worker torch.split(A(torch.tensor([0, 1])), 2), 924*da0073e9SAndroid Build Coastguard Worker torch.split(torch.tensor([0, 1]), 2), 925*da0073e9SAndroid Build Coastguard Worker ) 926*da0073e9SAndroid Build Coastguard Worker 927*da0073e9SAndroid Build Coastguard Worker def test_invalid_ret(self) -> None: 928*da0073e9SAndroid Build Coastguard Worker # test invalid return gets reasonable error message 929*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 930*da0073e9SAndroid Build Coastguard Worker @staticmethod 931*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 932*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 933*da0073e9SAndroid Build Coastguard Worker 934*da0073e9SAndroid Build Coastguard Worker @classmethod 935*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 936*da0073e9SAndroid Build Coastguard Worker return "arf" 937*da0073e9SAndroid Build Coastguard Worker 938*da0073e9SAndroid Build Coastguard Worker # Wobbles depending on NDEBUG mode of pybind11 939*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 940*da0073e9SAndroid Build Coastguard Worker RuntimeError, 941*da0073e9SAndroid Build Coastguard Worker "Unable to cast", 942*da0073e9SAndroid Build Coastguard Worker lambda: A(torch.zeros(1)).neg(), 943*da0073e9SAndroid Build Coastguard Worker ) 944*da0073e9SAndroid Build Coastguard Worker self.assertRaisesRegex( 945*da0073e9SAndroid Build Coastguard Worker RuntimeError, 946*da0073e9SAndroid Build Coastguard Worker "Unable to cast", 947*da0073e9SAndroid Build Coastguard Worker lambda: A(torch.zeros(1)).detach(), 948*da0073e9SAndroid Build Coastguard Worker ) 949*da0073e9SAndroid Build Coastguard Worker 950*da0073e9SAndroid Build Coastguard Worker def test_detach_appears_twice_when_called_once(self) -> None: 951*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 952*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.tensor([3.0]), requires_grad=True) 953*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 954*da0073e9SAndroid Build Coastguard Worker x.detach() 955*da0073e9SAndroid Build Coastguard Worker # FIXME: We actually want this to emit a single detach. However, 956*da0073e9SAndroid Build Coastguard Worker # it currently emits two, for reasons unclear to us. Leaving 957*da0073e9SAndroid Build Coastguard Worker # this test here to make sure we don't regress even further (it 958*da0073e9SAndroid Build Coastguard Worker # would be bad if calling .detach() once emits 3+ detaches). 959*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 960*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 961*da0073e9SAndroid Build Coastguard Worker """\ 962*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 963*da0073e9SAndroid Build Coastguard Worker$1: f32[1] = torch._ops.aten.detach.default($0) 964*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = torch._ops.aten.detach.default($1)""", 965*da0073e9SAndroid Build Coastguard Worker ) 966*da0073e9SAndroid Build Coastguard Worker 967*da0073e9SAndroid Build Coastguard Worker def test_storage(self) -> None: 968*da0073e9SAndroid Build Coastguard Worker # For now, just make sure it doesn't crash. Ideally, we should 969*da0073e9SAndroid Build Coastguard Worker # return some virtual storage that is safe to work with 970*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 971*da0073e9SAndroid Build Coastguard Worker storage = x.untyped_storage() 972*da0073e9SAndroid Build Coastguard Worker self.assertRaises(RuntimeError, lambda: storage.data_ptr()) 973*da0073e9SAndroid Build Coastguard Worker 974*da0073e9SAndroid Build Coastguard Worker def test_make_wrapper_subclass_noalloc(self) -> None: 975*da0073e9SAndroid Build Coastguard Worker # This is ludicrously big (8TB) and this should pass because wrapper 976*da0073e9SAndroid Build Coastguard Worker # subclasses don't allocate 977*da0073e9SAndroid Build Coastguard Worker torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,)) 978*da0073e9SAndroid Build Coastguard Worker 979*da0073e9SAndroid Build Coastguard Worker def test_version(self) -> None: 980*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 981*da0073e9SAndroid Build Coastguard Worker prev_vc = x._version 982*da0073e9SAndroid Build Coastguard Worker x.detach().add_(2) 983*da0073e9SAndroid Build Coastguard Worker cur_vc = x._version 984*da0073e9SAndroid Build Coastguard Worker self.assertNotEqual(prev_vc, cur_vc) 985*da0073e9SAndroid Build Coastguard Worker x.data.add_(2) 986*da0073e9SAndroid Build Coastguard Worker self.assertEqual(cur_vc, x._version) 987*da0073e9SAndroid Build Coastguard Worker 988*da0073e9SAndroid Build Coastguard Worker def test_subclass_priority(self) -> None: 989*da0073e9SAndroid Build Coastguard Worker class ErrorA(RuntimeError): 990*da0073e9SAndroid Build Coastguard Worker pass 991*da0073e9SAndroid Build Coastguard Worker 992*da0073e9SAndroid Build Coastguard Worker class ErrorB(RuntimeError): 993*da0073e9SAndroid Build Coastguard Worker pass 994*da0073e9SAndroid Build Coastguard Worker 995*da0073e9SAndroid Build Coastguard Worker # The big tests for code coverage are test_precedence_semantics in 996*da0073e9SAndroid Build Coastguard Worker # test_overrides.py; this is just to make sure it is wired up at all 997*da0073e9SAndroid Build Coastguard Worker # correctly for __torch_dispatch__ 998*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 999*da0073e9SAndroid Build Coastguard Worker @staticmethod 1000*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1001*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1002*da0073e9SAndroid Build Coastguard Worker 1003*da0073e9SAndroid Build Coastguard Worker @classmethod 1004*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1005*da0073e9SAndroid Build Coastguard Worker raise ErrorA 1006*da0073e9SAndroid Build Coastguard Worker 1007*da0073e9SAndroid Build Coastguard Worker class B(A): 1008*da0073e9SAndroid Build Coastguard Worker @staticmethod 1009*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1010*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1011*da0073e9SAndroid Build Coastguard Worker 1012*da0073e9SAndroid Build Coastguard Worker @classmethod 1013*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1014*da0073e9SAndroid Build Coastguard Worker raise ErrorB 1015*da0073e9SAndroid Build Coastguard Worker 1016*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1017*da0073e9SAndroid Build Coastguard Worker ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1))) 1018*da0073e9SAndroid Build Coastguard Worker ) 1019*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1020*da0073e9SAndroid Build Coastguard Worker ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1))) 1021*da0073e9SAndroid Build Coastguard Worker ) 1022*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1023*da0073e9SAndroid Build Coastguard Worker ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1))) 1024*da0073e9SAndroid Build Coastguard Worker ) 1025*da0073e9SAndroid Build Coastguard Worker self.assertRaises( 1026*da0073e9SAndroid Build Coastguard Worker ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1))) 1027*da0073e9SAndroid Build Coastguard Worker ) 1028*da0073e9SAndroid Build Coastguard Worker 1029*da0073e9SAndroid Build Coastguard Worker def test_format(self) -> None: 1030*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1)) 1031*da0073e9SAndroid Build Coastguard Worker s1 = str(x) 1032*da0073e9SAndroid Build Coastguard Worker s2 = repr(x) 1033*da0073e9SAndroid Build Coastguard Worker s3 = f"{x}" 1034*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""") 1035*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1, s2) 1036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s1, s3) 1037*da0073e9SAndroid Build Coastguard Worker 1038*da0073e9SAndroid Build Coastguard Worker def test_custom_autograd(self) -> None: 1039*da0073e9SAndroid Build Coastguard Worker escape = [None] 1040*da0073e9SAndroid Build Coastguard Worker 1041*da0073e9SAndroid Build Coastguard Worker class Square(torch.autograd.Function): 1042*da0073e9SAndroid Build Coastguard Worker @staticmethod 1043*da0073e9SAndroid Build Coastguard Worker def forward(ctx, x): 1044*da0073e9SAndroid Build Coastguard Worker y = x**2 1045*da0073e9SAndroid Build Coastguard Worker ctx.save_for_backward(x) 1046*da0073e9SAndroid Build Coastguard Worker return y 1047*da0073e9SAndroid Build Coastguard Worker 1048*da0073e9SAndroid Build Coastguard Worker @staticmethod 1049*da0073e9SAndroid Build Coastguard Worker def backward(ctx, grad_output): 1050*da0073e9SAndroid Build Coastguard Worker assert isinstance(grad_output, LoggingTensor) 1051*da0073e9SAndroid Build Coastguard Worker (x,) = ctx.saved_tensors 1052*da0073e9SAndroid Build Coastguard Worker assert isinstance(x, LoggingTensor) 1053*da0073e9SAndroid Build Coastguard Worker escape[0] = x 1054*da0073e9SAndroid Build Coastguard Worker return grad_output * 2 * x 1055*da0073e9SAndroid Build Coastguard Worker 1056*da0073e9SAndroid Build Coastguard Worker with capture_logs() as logs: 1057*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.ones(1), requires_grad=True) 1058*da0073e9SAndroid Build Coastguard Worker log_input("x", x) 1059*da0073e9SAndroid Build Coastguard Worker x.grad = LoggingTensor(torch.zeros(1)) 1060*da0073e9SAndroid Build Coastguard Worker log_input("x.grad", x.grad) 1061*da0073e9SAndroid Build Coastguard Worker y = Square.apply(x) 1062*da0073e9SAndroid Build Coastguard Worker grad_output = LoggingTensor(torch.ones(1)) 1063*da0073e9SAndroid Build Coastguard Worker log_input("grad_output", grad_output) 1064*da0073e9SAndroid Build Coastguard Worker y.backward(grad_output) 1065*da0073e9SAndroid Build Coastguard Worker 1066*da0073e9SAndroid Build Coastguard Worker with torch.no_grad(): 1067*da0073e9SAndroid Build Coastguard Worker self.assertEqual(escape[0], x) 1068*da0073e9SAndroid Build Coastguard Worker self.assertEqual(escape[0]._version, x._version) 1069*da0073e9SAndroid Build Coastguard Worker # TODO: figure out why x.requires_grad = False doesn't 1070*da0073e9SAndroid Build Coastguard Worker # trigger an error for LoggingTensor 1071*da0073e9SAndroid Build Coastguard Worker x.add_(2) 1072*da0073e9SAndroid Build Coastguard Worker self.assertEqual(escape[0], x) 1073*da0073e9SAndroid Build Coastguard Worker # TODO: figure out why this is broken 1074*da0073e9SAndroid Build Coastguard Worker # self.assertEqual(escape[0]._version, x._version) 1075*da0073e9SAndroid Build Coastguard Worker 1076*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1077*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1078*da0073e9SAndroid Build Coastguard Worker """\ 1079*da0073e9SAndroid Build Coastguard Worker$0: f32[1] = input('x') 1080*da0073e9SAndroid Build Coastguard Worker$1: f32[1] = input('x.grad') 1081*da0073e9SAndroid Build Coastguard Worker$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2) 1082*da0073e9SAndroid Build Coastguard Worker$3: f32[1] = input('grad_output') 1083*da0073e9SAndroid Build Coastguard Worker$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2) 1084*da0073e9SAndroid Build Coastguard Worker$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0) 1085*da0073e9SAndroid Build Coastguard Worker$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""", 1086*da0073e9SAndroid Build Coastguard Worker ) 1087*da0073e9SAndroid Build Coastguard Worker 1088*da0073e9SAndroid Build Coastguard Worker def test_subclass_creation(self): 1089*da0073e9SAndroid Build Coastguard Worker # Make sure these statements runs without error 1090*da0073e9SAndroid Build Coastguard Worker # In particular checking that when internal detach returns 1091*da0073e9SAndroid Build Coastguard Worker # subclasses, these are cleanly overwritten. 1092*da0073e9SAndroid Build Coastguard Worker class Foo(torch.Tensor): 1093*da0073e9SAndroid Build Coastguard Worker pass 1094*da0073e9SAndroid Build Coastguard Worker 1095*da0073e9SAndroid Build Coastguard Worker err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor" 1096*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 1097*da0073e9SAndroid Build Coastguard Worker a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2))) 1098*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 1099*da0073e9SAndroid Build Coastguard Worker b = LoggingTensor(torch.rand(2)).as_subclass(Foo) 1100*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, err_msg): 1101*da0073e9SAndroid Build Coastguard Worker Foo(LoggingTensor(torch.rand(2))) 1102*da0073e9SAndroid Build Coastguard Worker 1103*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"): 1104*da0073e9SAndroid Build Coastguard Worker torch.Tensor._make_wrapper_subclass(Foo, (2, 2)) 1105*da0073e9SAndroid Build Coastguard Worker 1106*da0073e9SAndroid Build Coastguard Worker def test_new_ones(self) -> None: 1107*da0073e9SAndroid Build Coastguard Worker class MyTensor(torch.Tensor): 1108*da0073e9SAndroid Build Coastguard Worker @classmethod 1109*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1110*da0073e9SAndroid Build Coastguard Worker return MyTensor(3) 1111*da0073e9SAndroid Build Coastguard Worker 1112*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor) 1113*da0073e9SAndroid Build Coastguard Worker 1114*da0073e9SAndroid Build Coastguard Worker def test_like(self) -> None: 1115*da0073e9SAndroid Build Coastguard Worker class MyTensor(torch.Tensor): 1116*da0073e9SAndroid Build Coastguard Worker @classmethod 1117*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1118*da0073e9SAndroid Build Coastguard Worker return MyTensor(3) 1119*da0073e9SAndroid Build Coastguard Worker 1120*da0073e9SAndroid Build Coastguard Worker for f in ["empty", "ones", "rand", "randn", "zeros"]: 1121*da0073e9SAndroid Build Coastguard Worker f_name = f + "_like" 1122*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor) 1123*da0073e9SAndroid Build Coastguard Worker 1124*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor) 1125*da0073e9SAndroid Build Coastguard Worker self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor) 1126*da0073e9SAndroid Build Coastguard Worker 1127*da0073e9SAndroid Build Coastguard Worker def test_make_fx_with_subclass(self) -> None: 1128*da0073e9SAndroid Build Coastguard Worker def f(x, y): 1129*da0073e9SAndroid Build Coastguard Worker # Returns (TwoTensor, Tensor) 1130*da0073e9SAndroid Build Coastguard Worker return x * y, y + y 1131*da0073e9SAndroid Build Coastguard Worker 1132*da0073e9SAndroid Build Coastguard Worker x_a = torch.zeros(4) 1133*da0073e9SAndroid Build Coastguard Worker x_b = torch.zeros(4) 1134*da0073e9SAndroid Build Coastguard Worker y = torch.ones(4) 1135*da0073e9SAndroid Build Coastguard Worker 1136*da0073e9SAndroid Build Coastguard Worker # make_fx() is not responsible for unwrapping tensor subclass inputs, 1137*da0073e9SAndroid Build Coastguard Worker # so we do it manually here. 1138*da0073e9SAndroid Build Coastguard Worker # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling 1139*da0073e9SAndroid Build Coastguard Worker # convention as f(*args). Unwrapping tensor subclass inputs can potentially change 1140*da0073e9SAndroid Build Coastguard Worker # the number of input args to the graph, breaking that assumption 1141*da0073e9SAndroid Build Coastguard Worker def f_to_trace(x_a, x_b, y): 1142*da0073e9SAndroid Build Coastguard Worker x = TwoTensor(x_a, x_b) 1143*da0073e9SAndroid Build Coastguard Worker out1, out2 = f(x, y) 1144*da0073e9SAndroid Build Coastguard Worker out1_unwrapped_attrs, _ = out1.__tensor_flatten__() 1145*da0073e9SAndroid Build Coastguard Worker return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2) 1146*da0073e9SAndroid Build Coastguard Worker 1147*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y) 1148*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1149*da0073e9SAndroid Build Coastguard Worker fx_g.code, 1150*da0073e9SAndroid Build Coastguard Worker """\ 1151*da0073e9SAndroid Build Coastguard Worker 1152*da0073e9SAndroid Build Coastguard Worker 1153*da0073e9SAndroid Build Coastguard Worker 1154*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_a_1, x_b_1, y_1): 1155*da0073e9SAndroid Build Coastguard Worker mul = torch.ops.aten.mul.Tensor(x_a_1, y_1); x_a_1 = None 1156*da0073e9SAndroid Build Coastguard Worker mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1); x_b_1 = None 1157*da0073e9SAndroid Build Coastguard Worker add = torch.ops.aten.add.Tensor(y_1, y_1); y_1 = None 1158*da0073e9SAndroid Build Coastguard Worker return (mul, mul_1, add) 1159*da0073e9SAndroid Build Coastguard Worker """, 1160*da0073e9SAndroid Build Coastguard Worker ) 1161*da0073e9SAndroid Build Coastguard Worker 1162*da0073e9SAndroid Build Coastguard Worker # See https://github.com/pytorch/pytorch/issues/117794 1163*da0073e9SAndroid Build Coastguard Worker def test_return_and_correct_aliasing_gives_correct_stride(self): 1164*da0073e9SAndroid Build Coastguard Worker t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2)) 1165*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, 2) 1166*da0073e9SAndroid Build Coastguard Worker # slicing should result in the same stride for TwoTensor as a dense tensor would give 1167*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t[:, 0].stride(), x[:, 0].stride()) 1168*da0073e9SAndroid Build Coastguard Worker 1169*da0073e9SAndroid Build Coastguard Worker def test_make_wrapper_subclass_propagates_metadata(self) -> None: 1170*da0073e9SAndroid Build Coastguard Worker class WrapperTensor(torch.Tensor): 1171*da0073e9SAndroid Build Coastguard Worker elem: torch.Tensor 1172*da0073e9SAndroid Build Coastguard Worker 1173*da0073e9SAndroid Build Coastguard Worker __slots__ = ["elem"] 1174*da0073e9SAndroid Build Coastguard Worker 1175*da0073e9SAndroid Build Coastguard Worker @staticmethod 1176*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1177*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1178*da0073e9SAndroid Build Coastguard Worker cls, 1179*da0073e9SAndroid Build Coastguard Worker elem.size(), 1180*da0073e9SAndroid Build Coastguard Worker dtype=elem.dtype, 1181*da0073e9SAndroid Build Coastguard Worker layout=elem.layout, 1182*da0073e9SAndroid Build Coastguard Worker device=elem.device, 1183*da0073e9SAndroid Build Coastguard Worker requires_grad=elem.requires_grad, 1184*da0073e9SAndroid Build Coastguard Worker strides=elem.stride(), 1185*da0073e9SAndroid Build Coastguard Worker storage_offset=elem.storage_offset(), 1186*da0073e9SAndroid Build Coastguard Worker ) 1187*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1188*da0073e9SAndroid Build Coastguard Worker return r 1189*da0073e9SAndroid Build Coastguard Worker 1190*da0073e9SAndroid Build Coastguard Worker @classmethod 1191*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1192*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("NYI") 1193*da0073e9SAndroid Build Coastguard Worker 1194*da0073e9SAndroid Build Coastguard Worker # non-contiguous strides, non-zero storage offset 1195*da0073e9SAndroid Build Coastguard Worker x = torch.randn(4, 6).t().diagonal(offset=2) 1196*da0073e9SAndroid Build Coastguard Worker y = WrapperTensor(x) 1197*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.size(), x.size()) 1198*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.stride(), x.stride()) 1199*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.storage_offset(), x.storage_offset()) 1200*da0073e9SAndroid Build Coastguard Worker 1201*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_serializes(self) -> None: 1202*da0073e9SAndroid Build Coastguard Worker with tempfile.TemporaryFile() as f: 1203*da0073e9SAndroid Build Coastguard Worker # purposefully use int64 to test non-default dtype 1204*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.randperm(3)) 1205*da0073e9SAndroid Build Coastguard Worker torch.save(x, f) 1206*da0073e9SAndroid Build Coastguard Worker f.seek(0) 1207*da0073e9SAndroid Build Coastguard Worker with torch.serialization.safe_globals([LoggingTensor]): 1208*da0073e9SAndroid Build Coastguard Worker x_loaded = torch.load(f) 1209*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(x_loaded) is type(x)) 1210*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_loaded) 1211*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.elem, x_loaded.elem) 1212*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x is x_loaded) 1213*da0073e9SAndroid Build Coastguard Worker 1214*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_wrapper_subclass(self) -> None: 1215*da0073e9SAndroid Build Coastguard Worker # purposefully use int64 to test non-default dtype 1216*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.randperm(3)) 1217*da0073e9SAndroid Build Coastguard Worker x_copy = deepcopy(x) 1218*da0073e9SAndroid Build Coastguard Worker self.assertTrue(type(x_copy) is type(x)) 1219*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, x_copy) 1220*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.elem, x_copy.elem) 1221*da0073e9SAndroid Build Coastguard Worker self.assertFalse(x is x_copy) 1222*da0073e9SAndroid Build Coastguard Worker 1223*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_wrapper_subclass_with_clone_returning_different_type( 1224*da0073e9SAndroid Build Coastguard Worker self, 1225*da0073e9SAndroid Build Coastguard Worker ) -> None: 1226*da0073e9SAndroid Build Coastguard Worker class MyWrapperTensor(torch.Tensor): 1227*da0073e9SAndroid Build Coastguard Worker elem: torch.Tensor 1228*da0073e9SAndroid Build Coastguard Worker 1229*da0073e9SAndroid Build Coastguard Worker __slots__ = ["elem"] 1230*da0073e9SAndroid Build Coastguard Worker 1231*da0073e9SAndroid Build Coastguard Worker @staticmethod 1232*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1233*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1234*da0073e9SAndroid Build Coastguard Worker cls, 1235*da0073e9SAndroid Build Coastguard Worker elem.size(), 1236*da0073e9SAndroid Build Coastguard Worker dtype=elem.dtype, 1237*da0073e9SAndroid Build Coastguard Worker layout=elem.layout, 1238*da0073e9SAndroid Build Coastguard Worker device=elem.device, 1239*da0073e9SAndroid Build Coastguard Worker requires_grad=elem.requires_grad, 1240*da0073e9SAndroid Build Coastguard Worker strides=elem.stride(), 1241*da0073e9SAndroid Build Coastguard Worker storage_offset=elem.storage_offset(), 1242*da0073e9SAndroid Build Coastguard Worker ) 1243*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1244*da0073e9SAndroid Build Coastguard Worker return r 1245*da0073e9SAndroid Build Coastguard Worker 1246*da0073e9SAndroid Build Coastguard Worker @classmethod 1247*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1248*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket.__name__ == "clone": 1249*da0073e9SAndroid Build Coastguard Worker # Return a plain tensor from clone(). 1250*da0073e9SAndroid Build Coastguard Worker return args[0].elem.clone() 1251*da0073e9SAndroid Build Coastguard Worker raise RuntimeError("NYI") 1252*da0073e9SAndroid Build Coastguard Worker 1253*da0073e9SAndroid Build Coastguard Worker # NB: The default Tensor.__torch_function__ implementation called for deepcopy 1254*da0073e9SAndroid Build Coastguard Worker # disables __torch_function__ by the time we get to clone(), so there is no need to 1255*da0073e9SAndroid Build Coastguard Worker # explicitly disable __torch_function__ for this subclass. 1256*da0073e9SAndroid Build Coastguard Worker 1257*da0073e9SAndroid Build Coastguard Worker x = MyWrapperTensor(torch.randn(3)) 1258*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1259*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1260*da0073e9SAndroid Build Coastguard Worker "for which cloning returns another instance of the same subclass", 1261*da0073e9SAndroid Build Coastguard Worker ): 1262*da0073e9SAndroid Build Coastguard Worker x_copy = deepcopy(x) 1263*da0073e9SAndroid Build Coastguard Worker 1264*da0073e9SAndroid Build Coastguard Worker def test_deepcopy_non_wrapper_subclass(self) -> None: 1265*da0073e9SAndroid Build Coastguard Worker # Ensure correct error is thrown for common error cases. 1266*da0073e9SAndroid Build Coastguard Worker class SubTensorError1(torch.Tensor): 1267*da0073e9SAndroid Build Coastguard Worker # Default implementation of new_empty() returns a plain tensor. 1268*da0073e9SAndroid Build Coastguard Worker pass 1269*da0073e9SAndroid Build Coastguard Worker 1270*da0073e9SAndroid Build Coastguard Worker class SubTensorError2(torch.Tensor): 1271*da0073e9SAndroid Build Coastguard Worker # new_empty() incorrectly returns a different type (i.e. a plain tensor). 1272*da0073e9SAndroid Build Coastguard Worker def new_empty(self, shape): 1273*da0073e9SAndroid Build Coastguard Worker return torch.Tensor(shape) 1274*da0073e9SAndroid Build Coastguard Worker 1275*da0073e9SAndroid Build Coastguard Worker for error_cls in [SubTensorError1, SubTensorError2]: 1276*da0073e9SAndroid Build Coastguard Worker x = error_cls(3) 1277*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1278*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1279*da0073e9SAndroid Build Coastguard Worker "for which that function returns another instance of the same subclass", 1280*da0073e9SAndroid Build Coastguard Worker ): 1281*da0073e9SAndroid Build Coastguard Worker x_copy = deepcopy(x) 1282*da0073e9SAndroid Build Coastguard Worker 1283*da0073e9SAndroid Build Coastguard Worker # Ensure a correctly implemented new_empty() causes deepcopy() to work. 1284*da0073e9SAndroid Build Coastguard Worker class SubTensorSuccess(torch.Tensor): 1285*da0073e9SAndroid Build Coastguard Worker def new_empty(self, shape): 1286*da0073e9SAndroid Build Coastguard Worker return type(self)(shape) 1287*da0073e9SAndroid Build Coastguard Worker 1288*da0073e9SAndroid Build Coastguard Worker x = SubTensorSuccess(3) 1289*da0073e9SAndroid Build Coastguard Worker x_copy = deepcopy(x) 1290*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(x_copy), type(x)) 1291*da0073e9SAndroid Build Coastguard Worker 1292*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_extra_dispatch_keys(self) -> None: 1293*da0073e9SAndroid Build Coastguard Worker class ExtraKeysTensor(torch.Tensor): 1294*da0073e9SAndroid Build Coastguard Worker @staticmethod 1295*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1296*da0073e9SAndroid Build Coastguard Worker # NB: only the non-kwarg overload of _make_wrapper_subclass supports 1297*da0073e9SAndroid Build Coastguard Worker # extra dispatch keys. We probably want to unify the two APIs 1298*da0073e9SAndroid Build Coastguard Worker # in the future. 1299*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 1300*da0073e9SAndroid Build Coastguard Worker cls, 1301*da0073e9SAndroid Build Coastguard Worker elem.size(), 1302*da0073e9SAndroid Build Coastguard Worker elem.stride(), 1303*da0073e9SAndroid Build Coastguard Worker elem.storage_offset(), 1304*da0073e9SAndroid Build Coastguard Worker torch.contiguous_format, 1305*da0073e9SAndroid Build Coastguard Worker elem.dtype, 1306*da0073e9SAndroid Build Coastguard Worker elem.layout, 1307*da0073e9SAndroid Build Coastguard Worker elem.device, 1308*da0073e9SAndroid Build Coastguard Worker False, 1309*da0073e9SAndroid Build Coastguard Worker False, 1310*da0073e9SAndroid Build Coastguard Worker None, 1311*da0073e9SAndroid Build Coastguard Worker False, 1312*da0073e9SAndroid Build Coastguard Worker False, 1313*da0073e9SAndroid Build Coastguard Worker DispatchKeySet(DispatchKey.NestedTensor), 1314*da0073e9SAndroid Build Coastguard Worker ) 1315*da0073e9SAndroid Build Coastguard Worker return r 1316*da0073e9SAndroid Build Coastguard Worker 1317*da0073e9SAndroid Build Coastguard Worker @classmethod 1318*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1319*da0073e9SAndroid Build Coastguard Worker pass 1320*da0073e9SAndroid Build Coastguard Worker 1321*da0073e9SAndroid Build Coastguard Worker x = ExtraKeysTensor(torch.randn(3)) 1322*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor)) 1323*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 1324*da0073e9SAndroid Build Coastguard Worker torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor) 1325*da0073e9SAndroid Build Coastguard Worker ) 1326*da0073e9SAndroid Build Coastguard Worker 1327*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_multiprocessing_preserves_dtype(self): 1328*da0073e9SAndroid Build Coastguard Worker # a and b have dtype of int64, which is purposefully different from the default 1329*da0073e9SAndroid Build Coastguard Worker # assumed by _make_wrapper_subclass(). 1330*da0073e9SAndroid Build Coastguard Worker a = torch.randperm(5) 1331*da0073e9SAndroid Build Coastguard Worker b = torch.randperm(5) 1332*da0073e9SAndroid Build Coastguard Worker data = TwoTensor(a, b) 1333*da0073e9SAndroid Build Coastguard Worker expected_dtype = data.dtype 1334*da0073e9SAndroid Build Coastguard Worker 1335*da0073e9SAndroid Build Coastguard Worker loader = torch.utils.data.DataLoader( 1336*da0073e9SAndroid Build Coastguard Worker [data, data], 1337*da0073e9SAndroid Build Coastguard Worker batch_size=2, 1338*da0073e9SAndroid Build Coastguard Worker num_workers=2, 1339*da0073e9SAndroid Build Coastguard Worker collate_fn=_identity, 1340*da0073e9SAndroid Build Coastguard Worker ) 1341*da0073e9SAndroid Build Coastguard Worker for batch in loader: 1342*da0073e9SAndroid Build Coastguard Worker self.assertEqual(batch[0].dtype, expected_dtype) 1343*da0073e9SAndroid Build Coastguard Worker 1344*da0073e9SAndroid Build Coastguard Worker def test_index_put_where_only_index_is_subclass(self) -> None: 1345*da0073e9SAndroid Build Coastguard Worker called_funcs = [] 1346*da0073e9SAndroid Build Coastguard Worker 1347*da0073e9SAndroid Build Coastguard Worker class MyTensor(torch.Tensor): 1348*da0073e9SAndroid Build Coastguard Worker elem: torch.Tensor 1349*da0073e9SAndroid Build Coastguard Worker __slots__ = ["elem"] 1350*da0073e9SAndroid Build Coastguard Worker 1351*da0073e9SAndroid Build Coastguard Worker @staticmethod 1352*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1353*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( 1354*da0073e9SAndroid Build Coastguard Worker cls, 1355*da0073e9SAndroid Build Coastguard Worker elem.size(), 1356*da0073e9SAndroid Build Coastguard Worker dtype=elem.dtype, 1357*da0073e9SAndroid Build Coastguard Worker layout=elem.layout, 1358*da0073e9SAndroid Build Coastguard Worker device=elem.device, 1359*da0073e9SAndroid Build Coastguard Worker requires_grad=elem.requires_grad, 1360*da0073e9SAndroid Build Coastguard Worker ) 1361*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1362*da0073e9SAndroid Build Coastguard Worker return r 1363*da0073e9SAndroid Build Coastguard Worker 1364*da0073e9SAndroid Build Coastguard Worker @classmethod 1365*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1366*da0073e9SAndroid Build Coastguard Worker called_funcs.append(func) 1367*da0073e9SAndroid Build Coastguard Worker return MyTensor(torch.tensor(3)) 1368*da0073e9SAndroid Build Coastguard Worker 1369*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, 3) 1370*da0073e9SAndroid Build Coastguard Worker idxs = (MyTensor(torch.tensor(0)),) 1371*da0073e9SAndroid Build Coastguard Worker v = torch.randn(1) 1372*da0073e9SAndroid Build Coastguard Worker res = x.index_put_(idxs, v) 1373*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default]) 1374*da0073e9SAndroid Build Coastguard Worker 1375*da0073e9SAndroid Build Coastguard Worker def test_torch_dispatch_mode_basic(self) -> None: 1376*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs: 1377*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1378*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1379*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1380*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1381*da0073e9SAndroid Build Coastguard Worker """\ 1382*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""", 1383*da0073e9SAndroid Build Coastguard Worker ) 1384*da0073e9SAndroid Build Coastguard Worker 1385*da0073e9SAndroid Build Coastguard Worker def test_torch_dispatch_mode_unrelated_tensors(self) -> None: 1386*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 1387*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 1388*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs: 1389*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1390*da0073e9SAndroid Build Coastguard Worker x + y 1391*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1392*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)""" 1393*da0073e9SAndroid Build Coastguard Worker ) 1394*da0073e9SAndroid Build Coastguard Worker 1395*da0073e9SAndroid Build Coastguard Worker def test_nested_push_logging_tensor_mode(self): 1396*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 1397*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 1398*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs: 1399*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1400*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1401*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1402*da0073e9SAndroid Build Coastguard Worker x + y 1403*da0073e9SAndroid Build Coastguard Worker 1404*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1405*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1406*da0073e9SAndroid Build Coastguard Worker """\ 1407*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1408*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1409*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2) 1410*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1411*da0073e9SAndroid Build Coastguard Worker ) 1412*da0073e9SAndroid Build Coastguard Worker 1413*da0073e9SAndroid Build Coastguard Worker def test_capture_logs_with_torch_dispatch_mode(self): 1414*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 1415*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 1416*da0073e9SAndroid Build Coastguard Worker with capture_logs_with_logging_tensor_mode() as logs: 1417*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1418*da0073e9SAndroid Build Coastguard Worker x + y 1419*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1420*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1421*da0073e9SAndroid Build Coastguard Worker """\ 1422*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1423*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1424*da0073e9SAndroid Build Coastguard Worker ) 1425*da0073e9SAndroid Build Coastguard Worker 1426*da0073e9SAndroid Build Coastguard Worker x = torch.randn([]) 1427*da0073e9SAndroid Build Coastguard Worker y = torch.randn([]) 1428*da0073e9SAndroid Build Coastguard Worker 1429*da0073e9SAndroid Build Coastguard Worker with capture_logs_with_logging_tensor_mode() as logs1: 1430*da0073e9SAndroid Build Coastguard Worker with capture_logs_with_logging_tensor_mode() as logs2: 1431*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1432*da0073e9SAndroid Build Coastguard Worker x + y 1433*da0073e9SAndroid Build Coastguard Worker 1434*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1435*da0073e9SAndroid Build Coastguard Worker "\n".join(logs2), 1436*da0073e9SAndroid Build Coastguard Worker """\ 1437*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1438*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1439*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2) 1440*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""", 1441*da0073e9SAndroid Build Coastguard Worker ) 1442*da0073e9SAndroid Build Coastguard Worker 1443*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logs1, logs2) 1444*da0073e9SAndroid Build Coastguard Worker 1445*da0073e9SAndroid Build Coastguard Worker def test_torch_dispatch_mode_subclass_priority(self) -> None: 1446*da0073e9SAndroid Build Coastguard Worker class ErrorA(RuntimeError): 1447*da0073e9SAndroid Build Coastguard Worker pass 1448*da0073e9SAndroid Build Coastguard Worker 1449*da0073e9SAndroid Build Coastguard Worker class ErrorB(RuntimeError): 1450*da0073e9SAndroid Build Coastguard Worker pass 1451*da0073e9SAndroid Build Coastguard Worker 1452*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 1453*da0073e9SAndroid Build Coastguard Worker @staticmethod 1454*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1455*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1456*da0073e9SAndroid Build Coastguard Worker 1457*da0073e9SAndroid Build Coastguard Worker @classmethod 1458*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1459*da0073e9SAndroid Build Coastguard Worker with AMode(): 1460*da0073e9SAndroid Build Coastguard Worker raise ErrorA 1461*da0073e9SAndroid Build Coastguard Worker 1462*da0073e9SAndroid Build Coastguard Worker class B(A): 1463*da0073e9SAndroid Build Coastguard Worker @staticmethod 1464*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1465*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1466*da0073e9SAndroid Build Coastguard Worker 1467*da0073e9SAndroid Build Coastguard Worker @classmethod 1468*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1469*da0073e9SAndroid Build Coastguard Worker with BMode(): 1470*da0073e9SAndroid Build Coastguard Worker func(*args, **kwargs) 1471*da0073e9SAndroid Build Coastguard Worker 1472*da0073e9SAndroid Build Coastguard Worker class AMode(TorchDispatchMode): 1473*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1474*da0073e9SAndroid Build Coastguard Worker raise ErrorA 1475*da0073e9SAndroid Build Coastguard Worker 1476*da0073e9SAndroid Build Coastguard Worker class BMode(TorchDispatchMode): 1477*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1478*da0073e9SAndroid Build Coastguard Worker raise ErrorB 1479*da0073e9SAndroid Build Coastguard Worker 1480*da0073e9SAndroid Build Coastguard Worker a = A(torch.empty(1)) 1481*da0073e9SAndroid Build Coastguard Worker b = B(torch.empty(1)) 1482*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorA): 1483*da0073e9SAndroid Build Coastguard Worker a + a 1484*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorB): 1485*da0073e9SAndroid Build Coastguard Worker a + b 1486*da0073e9SAndroid Build Coastguard Worker 1487*da0073e9SAndroid Build Coastguard Worker # B has precedence over A due to the subclass relationship yet 1488*da0073e9SAndroid Build Coastguard Worker # modes take precedence over arguments 1489*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorA): 1490*da0073e9SAndroid Build Coastguard Worker with AMode(): 1491*da0073e9SAndroid Build Coastguard Worker b + b 1492*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorB): 1493*da0073e9SAndroid Build Coastguard Worker with BMode(): 1494*da0073e9SAndroid Build Coastguard Worker a + a 1495*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorB): 1496*da0073e9SAndroid Build Coastguard Worker with BMode(): 1497*da0073e9SAndroid Build Coastguard Worker a + b 1498*da0073e9SAndroid Build Coastguard Worker 1499*da0073e9SAndroid Build Coastguard Worker def test_mode_with_make_subclass(self): 1500*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 1501*da0073e9SAndroid Build Coastguard Worker @staticmethod 1502*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1503*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1504*da0073e9SAndroid Build Coastguard Worker 1505*da0073e9SAndroid Build Coastguard Worker class BasicMode(TorchDispatchMode): 1506*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1507*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1508*da0073e9SAndroid Build Coastguard Worker 1509*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3) 1510*da0073e9SAndroid Build Coastguard Worker with BasicMode(): 1511*da0073e9SAndroid Build Coastguard Worker y = SubTensor(x) 1512*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(y, SubTensor) 1513*da0073e9SAndroid Build Coastguard Worker 1514*da0073e9SAndroid Build Coastguard Worker def test_torch_dispatch_mode_respects_no_dispatch(self) -> None: 1515*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs1: 1516*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1517*da0073e9SAndroid Build Coastguard Worker torch.ones([2, 3]) 1518*da0073e9SAndroid Build Coastguard Worker with no_dispatch(): 1519*da0073e9SAndroid Build Coastguard Worker torch.ones([2, 3]) 1520*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs2: 1521*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode(): 1522*da0073e9SAndroid Build Coastguard Worker torch.ones([2, 3]) 1523*da0073e9SAndroid Build Coastguard Worker self.assertEqual(logs1, logs2) 1524*da0073e9SAndroid Build Coastguard Worker 1525*da0073e9SAndroid Build Coastguard Worker def test_shallow_copy_and_detach(self) -> None: 1526*da0073e9SAndroid Build Coastguard Worker seen = set() 1527*da0073e9SAndroid Build Coastguard Worker test_case = self 1528*da0073e9SAndroid Build Coastguard Worker 1529*da0073e9SAndroid Build Coastguard Worker class TestMode(TorchDispatchMode): 1530*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1531*da0073e9SAndroid Build Coastguard Worker tree_map_only( 1532*da0073e9SAndroid Build Coastguard Worker torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs) 1533*da0073e9SAndroid Build Coastguard Worker ) 1534*da0073e9SAndroid Build Coastguard Worker if kwargs is None: 1535*da0073e9SAndroid Build Coastguard Worker kwargs = {} 1536*da0073e9SAndroid Build Coastguard Worker r = func(*args, **kwargs) 1537*da0073e9SAndroid Build Coastguard Worker tree_map_only(torch.Tensor, lambda t: seen.add(t), r) 1538*da0073e9SAndroid Build Coastguard Worker return r 1539*da0073e9SAndroid Build Coastguard Worker 1540*da0073e9SAndroid Build Coastguard Worker with TestMode(): 1541*da0073e9SAndroid Build Coastguard Worker x = torch.randn(3, requires_grad=True) 1542*da0073e9SAndroid Build Coastguard Worker loss = (x * x).sum() 1543*da0073e9SAndroid Build Coastguard Worker loss.backward() 1544*da0073e9SAndroid Build Coastguard Worker 1545*da0073e9SAndroid Build Coastguard Worker def test_exception_handling(self): 1546*da0073e9SAndroid Build Coastguard Worker class A(torch.Tensor): 1547*da0073e9SAndroid Build Coastguard Worker @staticmethod 1548*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1549*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1550*da0073e9SAndroid Build Coastguard Worker 1551*da0073e9SAndroid Build Coastguard Worker class AMode(TorchDispatchMode): 1552*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1553*da0073e9SAndroid Build Coastguard Worker if func.__name__ == "randn.default": 1554*da0073e9SAndroid Build Coastguard Worker raise RuntimeError 1555*da0073e9SAndroid Build Coastguard Worker return A(torch.zeros(())) 1556*da0073e9SAndroid Build Coastguard Worker 1557*da0073e9SAndroid Build Coastguard Worker with AMode(): 1558*da0073e9SAndroid Build Coastguard Worker try: 1559*da0073e9SAndroid Build Coastguard Worker torch.randn(()) 1560*da0073e9SAndroid Build Coastguard Worker except RuntimeError: 1561*da0073e9SAndroid Build Coastguard Worker pass 1562*da0073e9SAndroid Build Coastguard Worker self.assertTrue(isinstance(torch.zeros(()), A)) 1563*da0073e9SAndroid Build Coastguard Worker 1564*da0073e9SAndroid Build Coastguard Worker def test_with_mode_created_separately(self): 1565*da0073e9SAndroid Build Coastguard Worker class ErrorA(RuntimeError): 1566*da0073e9SAndroid Build Coastguard Worker pass 1567*da0073e9SAndroid Build Coastguard Worker 1568*da0073e9SAndroid Build Coastguard Worker class A(TorchDispatchMode): 1569*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1570*da0073e9SAndroid Build Coastguard Worker raise ErrorA 1571*da0073e9SAndroid Build Coastguard Worker 1572*da0073e9SAndroid Build Coastguard Worker x = A() 1573*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(ErrorA): 1574*da0073e9SAndroid Build Coastguard Worker with x: 1575*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1576*da0073e9SAndroid Build Coastguard Worker 1577*da0073e9SAndroid Build Coastguard Worker def test_with_nested_modes(self): 1578*da0073e9SAndroid Build Coastguard Worker class ErrorA(RuntimeError): 1579*da0073e9SAndroid Build Coastguard Worker def __init__(self, msg): 1580*da0073e9SAndroid Build Coastguard Worker super().__init__(msg) 1581*da0073e9SAndroid Build Coastguard Worker 1582*da0073e9SAndroid Build Coastguard Worker class A(TorchDispatchMode): 1583*da0073e9SAndroid Build Coastguard Worker def __init__(self, msg): 1584*da0073e9SAndroid Build Coastguard Worker self.msg = msg 1585*da0073e9SAndroid Build Coastguard Worker 1586*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1587*da0073e9SAndroid Build Coastguard Worker raise ErrorA(self.msg) 1588*da0073e9SAndroid Build Coastguard Worker 1589*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(ErrorA, "layer2"): 1590*da0073e9SAndroid Build Coastguard Worker with A("layer1"): 1591*da0073e9SAndroid Build Coastguard Worker with A("layer2"): 1592*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1593*da0073e9SAndroid Build Coastguard Worker 1594*da0073e9SAndroid Build Coastguard Worker def test_make_subclass_with_modes(self): 1595*da0073e9SAndroid Build Coastguard Worker class ModeTensor(torch.Tensor): 1596*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, mode): 1597*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 1598*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1599*da0073e9SAndroid Build Coastguard Worker r.mode = mode 1600*da0073e9SAndroid Build Coastguard Worker return r 1601*da0073e9SAndroid Build Coastguard Worker 1602*da0073e9SAndroid Build Coastguard Worker @classmethod 1603*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1604*da0073e9SAndroid Build Coastguard Worker raise NotImplementedError("Shouldn't be here") 1605*da0073e9SAndroid Build Coastguard Worker 1606*da0073e9SAndroid Build Coastguard Worker class Mode(TorchDispatchMode): 1607*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1608*da0073e9SAndroid Build Coastguard Worker def unwrap(e): 1609*da0073e9SAndroid Build Coastguard Worker if isinstance(e, ModeTensor): 1610*da0073e9SAndroid Build Coastguard Worker return e.elem 1611*da0073e9SAndroid Build Coastguard Worker else: 1612*da0073e9SAndroid Build Coastguard Worker return e 1613*da0073e9SAndroid Build Coastguard Worker 1614*da0073e9SAndroid Build Coastguard Worker def wrap(t): 1615*da0073e9SAndroid Build Coastguard Worker if isinstance(t, torch.Tensor): 1616*da0073e9SAndroid Build Coastguard Worker return ModeTensor(t, self) 1617*da0073e9SAndroid Build Coastguard Worker else: 1618*da0073e9SAndroid Build Coastguard Worker return t 1619*da0073e9SAndroid Build Coastguard Worker 1620*da0073e9SAndroid Build Coastguard Worker return wrap(func(*tuple(unwrap(a) for a in args), **kwargs)) 1621*da0073e9SAndroid Build Coastguard Worker 1622*da0073e9SAndroid Build Coastguard Worker class BasicMode(TorchDispatchMode): 1623*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1624*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1625*da0073e9SAndroid Build Coastguard Worker 1626*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(4.0) 1627*da0073e9SAndroid Build Coastguard Worker with Mode(): 1628*da0073e9SAndroid Build Coastguard Worker y = x + x 1629*da0073e9SAndroid Build Coastguard Worker z = y + y 1630*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(y, ModeTensor) 1631*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(z, ModeTensor) 1632*da0073e9SAndroid Build Coastguard Worker 1633*da0073e9SAndroid Build Coastguard Worker with Mode(): 1634*da0073e9SAndroid Build Coastguard Worker with BasicMode(): # we can't nest two modes that call make_subclass because it only accepts vanilla tensors 1635*da0073e9SAndroid Build Coastguard Worker y = x + x 1636*da0073e9SAndroid Build Coastguard Worker z = y + y 1637*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(y, ModeTensor) 1638*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(z, ModeTensor) 1639*da0073e9SAndroid Build Coastguard Worker 1640*da0073e9SAndroid Build Coastguard Worker assert self.assertRaisesRegex( 1641*da0073e9SAndroid Build Coastguard Worker RuntimeError, 1642*da0073e9SAndroid Build Coastguard Worker "subclass Mode but.* associated to a python object of type Mode", 1643*da0073e9SAndroid Build Coastguard Worker ) 1644*da0073e9SAndroid Build Coastguard Worker 1645*da0073e9SAndroid Build Coastguard Worker def test_notimplemented_mode(self): 1646*da0073e9SAndroid Build Coastguard Worker sub_count = 0 1647*da0073e9SAndroid Build Coastguard Worker 1648*da0073e9SAndroid Build Coastguard Worker class PoliteMode(TorchDispatchMode): 1649*da0073e9SAndroid Build Coastguard Worker def __init__(self) -> None: 1650*da0073e9SAndroid Build Coastguard Worker self.pre_count = 0 1651*da0073e9SAndroid Build Coastguard Worker self.post_count = 0 1652*da0073e9SAndroid Build Coastguard Worker 1653*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1654*da0073e9SAndroid Build Coastguard Worker self.pre_count += 1 1655*da0073e9SAndroid Build Coastguard Worker if any(t is not torch.Tensor for t in types): 1656*da0073e9SAndroid Build Coastguard Worker return NotImplemented 1657*da0073e9SAndroid Build Coastguard Worker self.post_count += 1 1658*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 1659*da0073e9SAndroid Build Coastguard Worker 1660*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 1661*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1662*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass(cls, elem.shape) 1663*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1664*da0073e9SAndroid Build Coastguard Worker return r 1665*da0073e9SAndroid Build Coastguard Worker 1666*da0073e9SAndroid Build Coastguard Worker @classmethod 1667*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1668*da0073e9SAndroid Build Coastguard Worker nonlocal sub_count 1669*da0073e9SAndroid Build Coastguard Worker sub_count += 1 1670*da0073e9SAndroid Build Coastguard Worker 1671*da0073e9SAndroid Build Coastguard Worker def unwrap(t): 1672*da0073e9SAndroid Build Coastguard Worker if isinstance(t, SubTensor): 1673*da0073e9SAndroid Build Coastguard Worker return t.elem 1674*da0073e9SAndroid Build Coastguard Worker else: 1675*da0073e9SAndroid Build Coastguard Worker return t 1676*da0073e9SAndroid Build Coastguard Worker 1677*da0073e9SAndroid Build Coastguard Worker return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 1678*da0073e9SAndroid Build Coastguard Worker 1679*da0073e9SAndroid Build Coastguard Worker a = SubTensor(torch.randn(2)) 1680*da0073e9SAndroid Build Coastguard Worker with PoliteMode() as mode: 1681*da0073e9SAndroid Build Coastguard Worker a.abs() 1682*da0073e9SAndroid Build Coastguard Worker 1683*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mode.pre_count, 2) 1684*da0073e9SAndroid Build Coastguard Worker self.assertEqual(mode.post_count, 1) 1685*da0073e9SAndroid Build Coastguard Worker self.assertEqual(sub_count, 1) 1686*da0073e9SAndroid Build Coastguard Worker 1687*da0073e9SAndroid Build Coastguard Worker # make sure this doesn't error 1688*da0073e9SAndroid Build Coastguard Worker with PoliteMode(): 1689*da0073e9SAndroid Build Coastguard Worker with PoliteMode(): 1690*da0073e9SAndroid Build Coastguard Worker a.abs() 1691*da0073e9SAndroid Build Coastguard Worker 1692*da0073e9SAndroid Build Coastguard Worker def test_nesting_same_mode(self): 1693*da0073e9SAndroid Build Coastguard Worker # If the pushed mode is the same instance as the current mode, we allow pushing an already active mode. 1694*da0073e9SAndroid Build Coastguard Worker 1695*da0073e9SAndroid Build Coastguard Worker with capture_logs(is_mode=True) as logs: 1696*da0073e9SAndroid Build Coastguard Worker with LoggingTensorMode() as reenabled: 1697*da0073e9SAndroid Build Coastguard Worker with reenabled: 1698*da0073e9SAndroid Build Coastguard Worker torch.empty([]) 1699*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 1700*da0073e9SAndroid Build Coastguard Worker "\n".join(logs), 1701*da0073e9SAndroid Build Coastguard Worker """\ 1702*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False) 1703*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""", 1704*da0073e9SAndroid Build Coastguard Worker ) 1705*da0073e9SAndroid Build Coastguard Worker 1706*da0073e9SAndroid Build Coastguard Worker def test_error_using_class_method_on_mode(self): 1707*da0073e9SAndroid Build Coastguard Worker class A(TorchDispatchMode): 1708*da0073e9SAndroid Build Coastguard Worker @classmethod 1709*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1710*da0073e9SAndroid Build Coastguard Worker return func(args, kwargs) 1711*da0073e9SAndroid Build Coastguard Worker 1712*da0073e9SAndroid Build Coastguard Worker x = torch.tensor(5.0) 1713*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1714*da0073e9SAndroid Build Coastguard Worker RuntimeError, "classmethod is not supported, please make it a plain method" 1715*da0073e9SAndroid Build Coastguard Worker ): 1716*da0073e9SAndroid Build Coastguard Worker with A(): 1717*da0073e9SAndroid Build Coastguard Worker x + x 1718*da0073e9SAndroid Build Coastguard Worker 1719*da0073e9SAndroid Build Coastguard Worker def test_get_cur_mode(self): 1720*da0073e9SAndroid Build Coastguard Worker class A(TorchDispatchMode): 1721*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1722*da0073e9SAndroid Build Coastguard Worker pass 1723*da0073e9SAndroid Build Coastguard Worker 1724*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode(), None) 1725*da0073e9SAndroid Build Coastguard Worker 1726*da0073e9SAndroid Build Coastguard Worker with A() as mode1: 1727*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode(), mode1) 1728*da0073e9SAndroid Build Coastguard Worker 1729*da0073e9SAndroid Build Coastguard Worker with mode1: 1730*da0073e9SAndroid Build Coastguard Worker with A() as mode2: 1731*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode(), mode2) 1732*da0073e9SAndroid Build Coastguard Worker 1733*da0073e9SAndroid Build Coastguard Worker def test_get_mode_stack(self): 1734*da0073e9SAndroid Build Coastguard Worker class A(TorchDispatchMode): 1735*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1736*da0073e9SAndroid Build Coastguard Worker pass 1737*da0073e9SAndroid Build Coastguard Worker 1738*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode_stack(), []) 1739*da0073e9SAndroid Build Coastguard Worker 1740*da0073e9SAndroid Build Coastguard Worker with A() as mode1: 1741*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode_stack(), [mode1]) 1742*da0073e9SAndroid Build Coastguard Worker 1743*da0073e9SAndroid Build Coastguard Worker with mode1: 1744*da0073e9SAndroid Build Coastguard Worker with A() as mode2: 1745*da0073e9SAndroid Build Coastguard Worker self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2]) 1746*da0073e9SAndroid Build Coastguard Worker 1747*da0073e9SAndroid Build Coastguard Worker def test_all_same_mode(self): 1748*da0073e9SAndroid Build Coastguard Worker x = LoggingTensorMode() 1749*da0073e9SAndroid Build Coastguard Worker y = LoggingTensorMode() 1750*da0073e9SAndroid Build Coastguard Worker self.assertTrue(all_same_mode([x, x, x])) 1751*da0073e9SAndroid Build Coastguard Worker self.assertFalse(all_same_mode([x, None])) 1752*da0073e9SAndroid Build Coastguard Worker self.assertFalse(all_same_mode([x, y])) 1753*da0073e9SAndroid Build Coastguard Worker 1754*da0073e9SAndroid Build Coastguard Worker def test_mode_detection(self): 1755*da0073e9SAndroid Build Coastguard Worker class InfraMode(TorchDispatchMode): 1756*da0073e9SAndroid Build Coastguard Worker @classmethod 1757*da0073e9SAndroid Build Coastguard Worker def is_infra_mode(cls): 1758*da0073e9SAndroid Build Coastguard Worker return True 1759*da0073e9SAndroid Build Coastguard Worker 1760*da0073e9SAndroid Build Coastguard Worker class NonInfraMode(TorchDispatchMode): 1761*da0073e9SAndroid Build Coastguard Worker pass 1762*da0073e9SAndroid Build Coastguard Worker 1763*da0073e9SAndroid Build Coastguard Worker with InfraMode(): 1764*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode()) 1765*da0073e9SAndroid Build Coastguard Worker self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1766*da0073e9SAndroid Build Coastguard Worker with NonInfraMode(): 1767*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode()) 1768*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False)) 1769*da0073e9SAndroid Build Coastguard Worker with InfraMode(): 1770*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode()) 1771*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 1772*da0073e9SAndroid Build Coastguard Worker is_in_torch_dispatch_mode(include_infra_modes=False) 1773*da0073e9SAndroid Build Coastguard Worker ) 1774*da0073e9SAndroid Build Coastguard Worker 1775*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode()) 1776*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False)) 1777*da0073e9SAndroid Build Coastguard Worker self.assertTrue(is_in_torch_dispatch_mode()) 1778*da0073e9SAndroid Build Coastguard Worker self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1779*da0073e9SAndroid Build Coastguard Worker 1780*da0073e9SAndroid Build Coastguard Worker self.assertFalse(is_in_torch_dispatch_mode()) 1781*da0073e9SAndroid Build Coastguard Worker self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False)) 1782*da0073e9SAndroid Build Coastguard Worker 1783*da0073e9SAndroid Build Coastguard Worker def test_tolist_numpy_with_torch_dispatch_mode(self) -> None: 1784*da0073e9SAndroid Build Coastguard Worker x = LoggingTensor(torch.tensor([2.0, 3.0])) 1785*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1786*da0073e9SAndroid Build Coastguard Worker RuntimeError, "is not supported for tensor subclasses." 1787*da0073e9SAndroid Build Coastguard Worker ): 1788*da0073e9SAndroid Build Coastguard Worker x.tolist() 1789*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex( 1790*da0073e9SAndroid Build Coastguard Worker RuntimeError, "is not supported for tensor subclasses." 1791*da0073e9SAndroid Build Coastguard Worker ): 1792*da0073e9SAndroid Build Coastguard Worker x.numpy() 1793*da0073e9SAndroid Build Coastguard Worker with self.assertRaises(AssertionError): 1794*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x, None) 1795*da0073e9SAndroid Build Coastguard Worker 1796*da0073e9SAndroid Build Coastguard Worker def test_record_stream(self) -> None: 1797*da0073e9SAndroid Build Coastguard Worker class TestMode(TorchDispatchMode): 1798*da0073e9SAndroid Build Coastguard Worker def __init__(self, testcase): 1799*da0073e9SAndroid Build Coastguard Worker self.testcase = testcase 1800*da0073e9SAndroid Build Coastguard Worker 1801*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1802*da0073e9SAndroid Build Coastguard Worker self.testcase.assertEqual(func.name(), "aten::record_stream") 1803*da0073e9SAndroid Build Coastguard Worker self.testcase.assertIsInstance(args[0], torch.Tensor) 1804*da0073e9SAndroid Build Coastguard Worker self.testcase.assertIsInstance(args[1], torch.Stream) 1805*da0073e9SAndroid Build Coastguard Worker self.testcase.assertEqual(args[1].stream_id, 1) 1806*da0073e9SAndroid Build Coastguard Worker self.testcase.assertEqual(args[1].device_index, 2) 1807*da0073e9SAndroid Build Coastguard Worker self.testcase.assertEqual(args[1].device_type, 3) 1808*da0073e9SAndroid Build Coastguard Worker 1809*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5.0) 1810*da0073e9SAndroid Build Coastguard Worker s = torch.Stream(stream_id=1, device_index=2, device_type=3) 1811*da0073e9SAndroid Build Coastguard Worker with TestMode(self): 1812*da0073e9SAndroid Build Coastguard Worker t.record_stream(s) 1813*da0073e9SAndroid Build Coastguard Worker 1814*da0073e9SAndroid Build Coastguard Worker def test_return_stream(self) -> None: 1815*da0073e9SAndroid Build Coastguard Worker with _scoped_library("test_return_stream", "DEF") as l_def: 1816*da0073e9SAndroid Build Coastguard Worker l_def.define("return_stream(Tensor self) -> Stream") 1817*da0073e9SAndroid Build Coastguard Worker with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl: 1818*da0073e9SAndroid Build Coastguard Worker l_impl.impl( 1819*da0073e9SAndroid Build Coastguard Worker "return_stream", 1820*da0073e9SAndroid Build Coastguard Worker lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2), 1821*da0073e9SAndroid Build Coastguard Worker ) 1822*da0073e9SAndroid Build Coastguard Worker 1823*da0073e9SAndroid Build Coastguard Worker class TestMode(TorchDispatchMode): 1824*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(self, func, types, args=(), kwargs=None): 1825*da0073e9SAndroid Build Coastguard Worker return torch.Stream(stream_id=1, device_index=2, device_type=3) 1826*da0073e9SAndroid Build Coastguard Worker 1827*da0073e9SAndroid Build Coastguard Worker t = torch.tensor(5.0) 1828*da0073e9SAndroid Build Coastguard Worker s = torch.ops.test_return_stream.return_stream(t) 1829*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(s, torch.Stream) 1830*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.stream_id, 0) 1831*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.device_index, 1) 1832*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.device_type, 2) 1833*da0073e9SAndroid Build Coastguard Worker 1834*da0073e9SAndroid Build Coastguard Worker with TestMode(): 1835*da0073e9SAndroid Build Coastguard Worker s = torch.ops.test_return_stream.return_stream(t) 1836*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(s, torch.Stream) 1837*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.stream_id, 1) 1838*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.device_index, 2) 1839*da0073e9SAndroid Build Coastguard Worker self.assertEqual(s.device_type, 3) 1840*da0073e9SAndroid Build Coastguard Worker 1841*da0073e9SAndroid Build Coastguard Worker def test_subclass_autograd_device_check(self) -> None: 1842*da0073e9SAndroid Build Coastguard Worker class NonWrapperSubclass(torch.Tensor): 1843*da0073e9SAndroid Build Coastguard Worker elem: torch.Tensor 1844*da0073e9SAndroid Build Coastguard Worker 1845*da0073e9SAndroid Build Coastguard Worker __slots__ = ["elem"] 1846*da0073e9SAndroid Build Coastguard Worker 1847*da0073e9SAndroid Build Coastguard Worker @staticmethod 1848*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1849*da0073e9SAndroid Build Coastguard Worker # Wrong device here! 1850*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_subclass( 1851*da0073e9SAndroid Build Coastguard Worker cls, elem.to("meta"), elem.requires_grad 1852*da0073e9SAndroid Build Coastguard Worker ) 1853*da0073e9SAndroid Build Coastguard Worker # ...the real tensor is held as an element on the tensor. 1854*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1855*da0073e9SAndroid Build Coastguard Worker return r 1856*da0073e9SAndroid Build Coastguard Worker 1857*da0073e9SAndroid Build Coastguard Worker @classmethod 1858*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1859*da0073e9SAndroid Build Coastguard Worker def unwrap(e): 1860*da0073e9SAndroid Build Coastguard Worker return e.elem if isinstance(e, NonWrapperSubclass) else e 1861*da0073e9SAndroid Build Coastguard Worker 1862*da0073e9SAndroid Build Coastguard Worker def wrap(e): 1863*da0073e9SAndroid Build Coastguard Worker return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e 1864*da0073e9SAndroid Build Coastguard Worker 1865*da0073e9SAndroid Build Coastguard Worker rs = tree_map( 1866*da0073e9SAndroid Build Coastguard Worker wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 1867*da0073e9SAndroid Build Coastguard Worker ) 1868*da0073e9SAndroid Build Coastguard Worker logging.getLogger("NonWrapperSubclass").info( 1869*da0073e9SAndroid Build Coastguard Worker f"{func.__module__}.{func.__name__}", # noqa: G004 1870*da0073e9SAndroid Build Coastguard Worker args, 1871*da0073e9SAndroid Build Coastguard Worker kwargs, 1872*da0073e9SAndroid Build Coastguard Worker rs, 1873*da0073e9SAndroid Build Coastguard Worker ) 1874*da0073e9SAndroid Build Coastguard Worker return rs 1875*da0073e9SAndroid Build Coastguard Worker 1876*da0073e9SAndroid Build Coastguard Worker x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True)) 1877*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2, requires_grad=True) 1878*da0073e9SAndroid Build Coastguard Worker z = x * y 1879*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(z, NonWrapperSubclass) 1880*da0073e9SAndroid Build Coastguard Worker z.sum().backward(torch.tensor(1)) 1881*da0073e9SAndroid Build Coastguard Worker self.assertEqual(x.grad, y) 1882*da0073e9SAndroid Build Coastguard Worker self.assertEqual(y.grad, x) 1883*da0073e9SAndroid Build Coastguard Worker 1884*da0073e9SAndroid Build Coastguard Worker def test_none_wrapping(self): 1885*da0073e9SAndroid Build Coastguard Worker # A Tensor subclass that returns None when doing add 1886*da0073e9SAndroid Build Coastguard Worker # See LoggingTensor above for more details on the subclass 1887*da0073e9SAndroid Build Coastguard Worker class SubclassWithNone(torch.Tensor): 1888*da0073e9SAndroid Build Coastguard Worker @staticmethod 1889*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem, *args, **kwargs): 1890*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( 1891*da0073e9SAndroid Build Coastguard Worker cls, 1892*da0073e9SAndroid Build Coastguard Worker elem.size(), 1893*da0073e9SAndroid Build Coastguard Worker dtype=elem.dtype, 1894*da0073e9SAndroid Build Coastguard Worker layout=elem.layout, 1895*da0073e9SAndroid Build Coastguard Worker device=elem.device, 1896*da0073e9SAndroid Build Coastguard Worker requires_grad=elem.requires_grad, 1897*da0073e9SAndroid Build Coastguard Worker ) 1898*da0073e9SAndroid Build Coastguard Worker r.elem = elem 1899*da0073e9SAndroid Build Coastguard Worker return r 1900*da0073e9SAndroid Build Coastguard Worker 1901*da0073e9SAndroid Build Coastguard Worker @classmethod 1902*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1903*da0073e9SAndroid Build Coastguard Worker def unwrap(e): 1904*da0073e9SAndroid Build Coastguard Worker return e.elem if isinstance(e, SubclassWithNone) else e 1905*da0073e9SAndroid Build Coastguard Worker 1906*da0073e9SAndroid Build Coastguard Worker def wrap(e): 1907*da0073e9SAndroid Build Coastguard Worker return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e 1908*da0073e9SAndroid Build Coastguard Worker 1909*da0073e9SAndroid Build Coastguard Worker rs = tree_map( 1910*da0073e9SAndroid Build Coastguard Worker wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)) 1911*da0073e9SAndroid Build Coastguard Worker ) 1912*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket.__name__ == "add": 1913*da0073e9SAndroid Build Coastguard Worker return None 1914*da0073e9SAndroid Build Coastguard Worker else: 1915*da0073e9SAndroid Build Coastguard Worker return rs 1916*da0073e9SAndroid Build Coastguard Worker 1917*da0073e9SAndroid Build Coastguard Worker x = SubclassWithNone(torch.rand(2)) 1918*da0073e9SAndroid Build Coastguard Worker # Make sure both run without error 1919*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(x * 2, SubclassWithNone) 1920*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(x + 2) 1921*da0073e9SAndroid Build Coastguard Worker 1922*da0073e9SAndroid Build Coastguard Worker x.requires_grad_() 1923*da0073e9SAndroid Build Coastguard Worker out = x.acos().sum() 1924*da0073e9SAndroid Build Coastguard Worker 1925*da0073e9SAndroid Build Coastguard Worker # The backward of acos does add then rsqrt so here we make sure that the 1926*da0073e9SAndroid Build Coastguard Worker # undefined Tensor generated by the user code is nicely handled. 1927*da0073e9SAndroid Build Coastguard Worker # If acos formula changes in the future, this can be replaced by any other 1928*da0073e9SAndroid Build Coastguard Worker # function that does add then something in the backward in a composite way 1929*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "but got None"): 1930*da0073e9SAndroid Build Coastguard Worker out.backward() 1931*da0073e9SAndroid Build Coastguard Worker 1932*da0073e9SAndroid Build Coastguard Worker def test_storage_can_be_converted_to_python_object(self): 1933*da0073e9SAndroid Build Coastguard Worker s = torch.Storage() 1934*da0073e9SAndroid Build Coastguard Worker z = LoggingTensor(torch.empty([])) 1935*da0073e9SAndroid Build Coastguard Worker z.set_(s) 1936*da0073e9SAndroid Build Coastguard Worker 1937*da0073e9SAndroid Build Coastguard Worker def test_autograd_in_attr(self): 1938*da0073e9SAndroid Build Coastguard Worker # We want the wrapped Tensor to require gradients! 1939*da0073e9SAndroid Build Coastguard Worker true_t = torch.rand(2, requires_grad=True) 1940*da0073e9SAndroid Build Coastguard Worker t = LoggingTensorReentrant(true_t) 1941*da0073e9SAndroid Build Coastguard Worker 1942*da0073e9SAndroid Build Coastguard Worker out = t + 2 1943*da0073e9SAndroid Build Coastguard Worker 1944*da0073e9SAndroid Build Coastguard Worker self.assertFalse(out.requires_grad) 1945*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(out.grad_fn) 1946*da0073e9SAndroid Build Coastguard Worker 1947*da0073e9SAndroid Build Coastguard Worker self.assertTrue(out.elem.requires_grad) 1948*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(out.elem.grad_fn) 1949*da0073e9SAndroid Build Coastguard Worker 1950*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(RuntimeError, "does not require grad"): 1951*da0073e9SAndroid Build Coastguard Worker out.sum().backward() 1952*da0073e9SAndroid Build Coastguard Worker 1953*da0073e9SAndroid Build Coastguard Worker out.elem.sum().backward() 1954*da0073e9SAndroid Build Coastguard Worker 1955*da0073e9SAndroid Build Coastguard Worker self.assertIsNone(t.grad) 1956*da0073e9SAndroid Build Coastguard Worker self.assertIsNotNone(t.elem.grad) 1957*da0073e9SAndroid Build Coastguard Worker 1958*da0073e9SAndroid Build Coastguard Worker def test_dispatch_super_call(self): 1959*da0073e9SAndroid Build Coastguard Worker called = [] 1960*da0073e9SAndroid Build Coastguard Worker 1961*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 1962*da0073e9SAndroid Build Coastguard Worker @staticmethod 1963*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1964*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem) 1965*da0073e9SAndroid Build Coastguard Worker 1966*da0073e9SAndroid Build Coastguard Worker @classmethod 1967*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1968*da0073e9SAndroid Build Coastguard Worker called.append(func) 1969*da0073e9SAndroid Build Coastguard Worker return super().__torch_dispatch__(func, types, args, kwargs) 1970*da0073e9SAndroid Build Coastguard Worker 1971*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2) 1972*da0073e9SAndroid Build Coastguard Worker y = torch.randn(2) 1973*da0073e9SAndroid Build Coastguard Worker self.assertEqual(SubTensor(x) + SubTensor(y), x + y) 1974*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, [torch.ops.aten.add.Tensor]) 1975*da0073e9SAndroid Build Coastguard Worker 1976*da0073e9SAndroid Build Coastguard Worker def test_dispatch_super_call_list_arg(self): 1977*da0073e9SAndroid Build Coastguard Worker called = [] 1978*da0073e9SAndroid Build Coastguard Worker 1979*da0073e9SAndroid Build Coastguard Worker class SubTensorWithListArg(torch.Tensor): 1980*da0073e9SAndroid Build Coastguard Worker @staticmethod 1981*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1982*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem) 1983*da0073e9SAndroid Build Coastguard Worker 1984*da0073e9SAndroid Build Coastguard Worker @classmethod 1985*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 1986*da0073e9SAndroid Build Coastguard Worker called.append(func) 1987*da0073e9SAndroid Build Coastguard Worker return super().__torch_dispatch__(func, types, list(args), kwargs) 1988*da0073e9SAndroid Build Coastguard Worker 1989*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2) 1990*da0073e9SAndroid Build Coastguard Worker self.assertEqual(SubTensorWithListArg(x).neg(), x.neg()) 1991*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, [torch.ops.aten.neg.default]) 1992*da0073e9SAndroid Build Coastguard Worker 1993*da0073e9SAndroid Build Coastguard Worker def test_dispatch_super_dont_autograd(self): 1994*da0073e9SAndroid Build Coastguard Worker called = [] 1995*da0073e9SAndroid Build Coastguard Worker 1996*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 1997*da0073e9SAndroid Build Coastguard Worker @staticmethod 1998*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 1999*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, elem, elem.requires_grad) 2000*da0073e9SAndroid Build Coastguard Worker 2001*da0073e9SAndroid Build Coastguard Worker @classmethod 2002*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2003*da0073e9SAndroid Build Coastguard Worker called.append(func) 2004*da0073e9SAndroid Build Coastguard Worker # This argument still requires grad because it was passed 2005*da0073e9SAndroid Build Coastguard Worker # through directly... 2006*da0073e9SAndroid Build Coastguard Worker self.assertTrue(args[0].requires_grad) 2007*da0073e9SAndroid Build Coastguard Worker r = super().__torch_dispatch__(func, types, args, kwargs) 2008*da0073e9SAndroid Build Coastguard Worker # But the output better not require grad, because that means 2009*da0073e9SAndroid Build Coastguard Worker # you did autograd again in torch dispatch (oops) 2010*da0073e9SAndroid Build Coastguard Worker self.assertFalse(r.requires_grad) 2011*da0073e9SAndroid Build Coastguard Worker return r 2012*da0073e9SAndroid Build Coastguard Worker 2013*da0073e9SAndroid Build Coastguard Worker x = SubTensor(torch.randn(2, requires_grad=True)) 2014*da0073e9SAndroid Build Coastguard Worker x.neg() 2015*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, [torch.ops.aten.neg.default]) 2016*da0073e9SAndroid Build Coastguard Worker 2017*da0073e9SAndroid Build Coastguard Worker def test_set_data(self): 2018*da0073e9SAndroid Build Coastguard Worker called = 0 2019*da0073e9SAndroid Build Coastguard Worker 2020*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 2021*da0073e9SAndroid Build Coastguard Worker @classmethod 2022*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2023*da0073e9SAndroid Build Coastguard Worker nonlocal called 2024*da0073e9SAndroid Build Coastguard Worker called += 1 2025*da0073e9SAndroid Build Coastguard Worker return super().__torch_dispatch__(func, types, args, kwargs) 2026*da0073e9SAndroid Build Coastguard Worker 2027*da0073e9SAndroid Build Coastguard Worker x = SubTensor(torch.empty(2)) 2028*da0073e9SAndroid Build Coastguard Worker x.data 2029*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 1) 2030*da0073e9SAndroid Build Coastguard Worker x.data = torch.empty(2) 2031*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 1) 2032*da0073e9SAndroid Build Coastguard Worker x.data 2033*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 2) 2034*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(x), SubTensor) 2035*da0073e9SAndroid Build Coastguard Worker x.set_(torch.empty(2)) 2036*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 3) 2037*da0073e9SAndroid Build Coastguard Worker x.data 2038*da0073e9SAndroid Build Coastguard Worker self.assertEqual(called, 4) 2039*da0073e9SAndroid Build Coastguard Worker self.assertIs(type(x), SubTensor) 2040*da0073e9SAndroid Build Coastguard Worker 2041*da0073e9SAndroid Build Coastguard Worker def test_construct_int_tensor(self): 2042*da0073e9SAndroid Build Coastguard Worker class SubTensor(torch.Tensor): 2043*da0073e9SAndroid Build Coastguard Worker pass 2044*da0073e9SAndroid Build Coastguard Worker 2045*da0073e9SAndroid Build Coastguard Worker # should not fail 2046*da0073e9SAndroid Build Coastguard Worker SubTensor(torch.zeros(2, dtype=torch.int)) 2047*da0073e9SAndroid Build Coastguard Worker 2048*da0073e9SAndroid Build Coastguard Worker def test_multiple_ops_subclass(self): 2049*da0073e9SAndroid Build Coastguard Worker # This is a Direct Subclass, don't do that! 2050*da0073e9SAndroid Build Coastguard Worker class MySubclass(torch.Tensor): 2051*da0073e9SAndroid Build Coastguard Worker @staticmethod 2052*da0073e9SAndroid Build Coastguard Worker def __new__(cls, elem): 2053*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_subclass(cls, elem) 2054*da0073e9SAndroid Build Coastguard Worker return r 2055*da0073e9SAndroid Build Coastguard Worker 2056*da0073e9SAndroid Build Coastguard Worker @classmethod 2057*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2058*da0073e9SAndroid Build Coastguard Worker with no_dispatch(): 2059*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 2060*da0073e9SAndroid Build Coastguard Worker 2061*da0073e9SAndroid Build Coastguard Worker x = MySubclass(torch.rand(2, 2, dtype=torch.complex64)) 2062*da0073e9SAndroid Build Coastguard Worker y = x.conj() 2063*da0073e9SAndroid Build Coastguard Worker # Details of the bug that this tests for: 2064*da0073e9SAndroid Build Coastguard Worker # Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU} 2065*da0073e9SAndroid Build Coastguard Worker # There are a few calls to the dispatcher that are going to happen here: 2066*da0073e9SAndroid Build Coastguard Worker # - call_exp: User calling exp on y 2067*da0073e9SAndroid Build Coastguard Worker # - PythonTLSSnapshot: records the TLS on entry and redispatch 2068*da0073e9SAndroid Build Coastguard Worker # - AutogradCPU: no input requires grad, so does nothing and redispatch 2069*da0073e9SAndroid Build Coastguard Worker # - Conjugate: no special implementation for exp: use the fallback that 2070*da0073e9SAndroid Build Coastguard Worker # first clone the Tensor (to materialize the conj) then redispatch 2071*da0073e9SAndroid Build Coastguard Worker # - call_clone: conjugate fallback calling clone on y 2072*da0073e9SAndroid Build Coastguard Worker # - PythonTLSSnapshot: records the TLS on entry and redispatch 2073*da0073e9SAndroid Build Coastguard Worker # - (AutogradCPU: skipped as autograd added itself to the exclude set above) 2074*da0073e9SAndroid Build Coastguard Worker # - Conjugate: special implementation for clone: just skip this key 2075*da0073e9SAndroid Build Coastguard Worker # - Python: Reset the TLS based on the snapshot above and call the user implementation (this 2076*da0073e9SAndroid Build Coastguard Worker # actually calls into the dispatcher again but since we disable both our keys 2077*da0073e9SAndroid Build Coastguard Worker # before, not detailed here) 2078*da0073e9SAndroid Build Coastguard Worker # - exit Python: restore the TLS and exit 2079*da0073e9SAndroid Build Coastguard Worker # - exit Conjugate: nothing was inplace so just exit 2080*da0073e9SAndroid Build Coastguard Worker # - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty 2081*da0073e9SAndroid Build Coastguard Worker # - Python: Reset the TLS again based on the snapshot. <- this used to fail 2082*da0073e9SAndroid Build Coastguard Worker # - More steps.... 2083*da0073e9SAndroid Build Coastguard Worker y.exp() 2084*da0073e9SAndroid Build Coastguard Worker 2085*da0073e9SAndroid Build Coastguard Worker @staticmethod 2086*da0073e9SAndroid Build Coastguard Worker def subclass_helper(cls, data, use_wrapper_subclass, **kwargs): 2087*da0073e9SAndroid Build Coastguard Worker if use_wrapper_subclass: 2088*da0073e9SAndroid Build Coastguard Worker kwargs["device"] = data.device 2089*da0073e9SAndroid Build Coastguard Worker kwargs["dtype"] = data.dtype 2090*da0073e9SAndroid Build Coastguard Worker kwargs["layout"] = data.layout 2091*da0073e9SAndroid Build Coastguard Worker kwargs["requires_grad"] = True 2092*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined] 2093*da0073e9SAndroid Build Coastguard Worker else: 2094*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_subclass(cls, data, True, **kwargs) 2095*da0073e9SAndroid Build Coastguard Worker 2096*da0073e9SAndroid Build Coastguard Worker def test_is_contiguous_slow_path(self): 2097*da0073e9SAndroid Build Coastguard Worker data = torch.randn(3, 3) 2098*da0073e9SAndroid Build Coastguard Worker contiguous_data = data.clone() 2099*da0073e9SAndroid Build Coastguard Worker not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2)) 2100*da0073e9SAndroid Build Coastguard Worker 2101*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True, False]: 2102*da0073e9SAndroid Build Coastguard Worker 2103*da0073e9SAndroid Build Coastguard Worker class ExampleTensor1(torch.Tensor): 2104*da0073e9SAndroid Build Coastguard Worker @staticmethod 2105*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2106*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2107*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2108*da0073e9SAndroid Build Coastguard Worker ) 2109*da0073e9SAndroid Build Coastguard Worker 2110*da0073e9SAndroid Build Coastguard Worker @classmethod 2111*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2112*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2113*da0073e9SAndroid Build Coastguard Worker 2114*da0073e9SAndroid Build Coastguard Worker class ExampleTensor2(torch.Tensor): 2115*da0073e9SAndroid Build Coastguard Worker @staticmethod 2116*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2117*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2118*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2119*da0073e9SAndroid Build Coastguard Worker ) 2120*da0073e9SAndroid Build Coastguard Worker 2121*da0073e9SAndroid Build Coastguard Worker @classmethod 2122*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2123*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.is_contiguous: 2124*da0073e9SAndroid Build Coastguard Worker return contiguous_data.is_contiguous() 2125*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2126*da0073e9SAndroid Build Coastguard Worker 2127*da0073e9SAndroid Build Coastguard Worker class ExampleTensor3(torch.Tensor): 2128*da0073e9SAndroid Build Coastguard Worker @staticmethod 2129*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2130*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2131*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2132*da0073e9SAndroid Build Coastguard Worker ) 2133*da0073e9SAndroid Build Coastguard Worker 2134*da0073e9SAndroid Build Coastguard Worker @classmethod 2135*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2136*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.is_contiguous: 2137*da0073e9SAndroid Build Coastguard Worker return not_contiguous_data.is_contiguous() 2138*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2139*da0073e9SAndroid Build Coastguard Worker 2140*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'" 2141*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass) 2142*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2143*da0073e9SAndroid Build Coastguard Worker e.is_contiguous() 2144*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2145*da0073e9SAndroid Build Coastguard Worker e.contiguous() 2146*da0073e9SAndroid Build Coastguard Worker 2147*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass) 2148*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.is_contiguous(), True) 2149*da0073e9SAndroid Build Coastguard Worker e.contiguous() # this will just return the original TensorImpl since is_contiguous = True 2150*da0073e9SAndroid Build Coastguard Worker 2151*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for" 2152*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass) 2153*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.is_contiguous(), False) 2154*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2155*da0073e9SAndroid Build Coastguard Worker e.contiguous() 2156*da0073e9SAndroid Build Coastguard Worker 2157*da0073e9SAndroid Build Coastguard Worker def test_fancy_strides(self): 2158*da0073e9SAndroid Build Coastguard Worker calls = [] 2159*da0073e9SAndroid Build Coastguard Worker 2160*da0073e9SAndroid Build Coastguard Worker class ExampleTensor(torch.Tensor): 2161*da0073e9SAndroid Build Coastguard Worker @staticmethod 2162*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data): 2163*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2164*da0073e9SAndroid Build Coastguard Worker cls, data, False, dispatch_sizes_strides_policy="strides" 2165*da0073e9SAndroid Build Coastguard Worker ) 2166*da0073e9SAndroid Build Coastguard Worker 2167*da0073e9SAndroid Build Coastguard Worker @classmethod 2168*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2169*da0073e9SAndroid Build Coastguard Worker if func in [ 2170*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.is_contiguous.default, 2171*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.is_contiguous.memory_format, 2172*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.is_strides_like_format.default, 2173*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.is_non_overlapping_and_dense.default, 2174*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.stride.default, 2175*da0073e9SAndroid Build Coastguard Worker ]: 2176*da0073e9SAndroid Build Coastguard Worker calls.append((func, list(args)[1:])) 2177*da0073e9SAndroid Build Coastguard Worker return None 2178*da0073e9SAndroid Build Coastguard Worker with no_dispatch(): 2179*da0073e9SAndroid Build Coastguard Worker return func(*args, **kwargs) 2180*da0073e9SAndroid Build Coastguard Worker 2181*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor(torch.randn(2, 2)) 2182*da0073e9SAndroid Build Coastguard Worker self.assertFalse(e.is_contiguous(memory_format=torch.channels_last)) 2183*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2184*da0073e9SAndroid Build Coastguard Worker calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])] 2185*da0073e9SAndroid Build Coastguard Worker ) 2186*da0073e9SAndroid Build Coastguard Worker calls.clear() 2187*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 2188*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.is_strides_like_format.default(e, torch.channels_last) 2189*da0073e9SAndroid Build Coastguard Worker ) 2190*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2191*da0073e9SAndroid Build Coastguard Worker calls, 2192*da0073e9SAndroid Build Coastguard Worker [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])], 2193*da0073e9SAndroid Build Coastguard Worker ) 2194*da0073e9SAndroid Build Coastguard Worker calls.clear() 2195*da0073e9SAndroid Build Coastguard Worker self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e)) 2196*da0073e9SAndroid Build Coastguard Worker self.assertEqual( 2197*da0073e9SAndroid Build Coastguard Worker calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])] 2198*da0073e9SAndroid Build Coastguard Worker ) 2199*da0073e9SAndroid Build Coastguard Worker 2200*da0073e9SAndroid Build Coastguard Worker def test_device_slowpath(self): 2201*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True]: 2202*da0073e9SAndroid Build Coastguard Worker 2203*da0073e9SAndroid Build Coastguard Worker class ExampleTensor1(torch.Tensor): 2204*da0073e9SAndroid Build Coastguard Worker @staticmethod 2205*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2206*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2207*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_device=True 2208*da0073e9SAndroid Build Coastguard Worker ) 2209*da0073e9SAndroid Build Coastguard Worker 2210*da0073e9SAndroid Build Coastguard Worker @classmethod 2211*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2212*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2213*da0073e9SAndroid Build Coastguard Worker 2214*da0073e9SAndroid Build Coastguard Worker class ExampleTensor2(torch.Tensor): 2215*da0073e9SAndroid Build Coastguard Worker @staticmethod 2216*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2217*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2218*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_device=True 2219*da0073e9SAndroid Build Coastguard Worker ) 2220*da0073e9SAndroid Build Coastguard Worker 2221*da0073e9SAndroid Build Coastguard Worker @classmethod 2222*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2223*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.prim.device: 2224*da0073e9SAndroid Build Coastguard Worker return torch.device("meta") 2225*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2226*da0073e9SAndroid Build Coastguard Worker 2227*da0073e9SAndroid Build Coastguard Worker class ExampleTensor3(torch.Tensor): 2228*da0073e9SAndroid Build Coastguard Worker @staticmethod 2229*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2230*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2231*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_device=True 2232*da0073e9SAndroid Build Coastguard Worker ) 2233*da0073e9SAndroid Build Coastguard Worker 2234*da0073e9SAndroid Build Coastguard Worker @classmethod 2235*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2236*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.prim.device: 2237*da0073e9SAndroid Build Coastguard Worker return torch.device("meta") 2238*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2239*da0073e9SAndroid Build Coastguard Worker 2240*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'" 2241*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2242*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass) 2243*da0073e9SAndroid Build Coastguard Worker e.device() 2244*da0073e9SAndroid Build Coastguard Worker 2245*da0073e9SAndroid Build Coastguard Worker ten = torch.rand([1]) 2246*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass) 2247*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.device.type, "meta") 2248*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ten.type_as(e).device.type, "meta") 2249*da0073e9SAndroid Build Coastguard Worker 2250*da0073e9SAndroid Build Coastguard Worker e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass) 2251*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.device.type, "meta") 2252*da0073e9SAndroid Build Coastguard Worker self.assertEqual(ten.type_as(e).device.type, "meta") 2253*da0073e9SAndroid Build Coastguard Worker 2254*da0073e9SAndroid Build Coastguard Worker def test_dim_slowpath(self): 2255*da0073e9SAndroid Build Coastguard Worker data = torch.randn(3, 3) 2256*da0073e9SAndroid Build Coastguard Worker 2257*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True, False]: 2258*da0073e9SAndroid Build Coastguard Worker 2259*da0073e9SAndroid Build Coastguard Worker class DimNotImplementedTensor(torch.Tensor): 2260*da0073e9SAndroid Build Coastguard Worker @staticmethod 2261*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2262*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2263*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2264*da0073e9SAndroid Build Coastguard Worker ) 2265*da0073e9SAndroid Build Coastguard Worker 2266*da0073e9SAndroid Build Coastguard Worker @classmethod 2267*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2268*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2269*da0073e9SAndroid Build Coastguard Worker 2270*da0073e9SAndroid Build Coastguard Worker class DimImplementedTensor(torch.Tensor): 2271*da0073e9SAndroid Build Coastguard Worker @staticmethod 2272*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2273*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2274*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2275*da0073e9SAndroid Build Coastguard Worker ) 2276*da0073e9SAndroid Build Coastguard Worker 2277*da0073e9SAndroid Build Coastguard Worker @classmethod 2278*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2279*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.dim: 2280*da0073e9SAndroid Build Coastguard Worker return data.dim() 2281*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2282*da0073e9SAndroid Build Coastguard Worker 2283*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'" 2284*da0073e9SAndroid Build Coastguard Worker e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass) 2285*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2286*da0073e9SAndroid Build Coastguard Worker e.dim() 2287*da0073e9SAndroid Build Coastguard Worker 2288*da0073e9SAndroid Build Coastguard Worker t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass) 2289*da0073e9SAndroid Build Coastguard Worker self.assertEqual(t.dim(), 2) 2290*da0073e9SAndroid Build Coastguard Worker 2291*da0073e9SAndroid Build Coastguard Worker def test_maybe_tuple_bug(self): 2292*da0073e9SAndroid Build Coastguard Worker class T(torch.Tensor): 2293*da0073e9SAndroid Build Coastguard Worker @classmethod 2294*da0073e9SAndroid Build Coastguard Worker def __torch_function__(cls, *args, **kwargs): 2295*da0073e9SAndroid Build Coastguard Worker pass 2296*da0073e9SAndroid Build Coastguard Worker 2297*da0073e9SAndroid Build Coastguard Worker a = torch.rand(3) 2298*da0073e9SAndroid Build Coastguard Worker 2299*da0073e9SAndroid Build Coastguard Worker a[[T(), T()]] 2300*da0073e9SAndroid Build Coastguard Worker 2301*da0073e9SAndroid Build Coastguard Worker def test_standard_is_not_subclass(self): 2302*da0073e9SAndroid Build Coastguard Worker # https://github.com/pytorch/pytorch/issues/79079 2303*da0073e9SAndroid Build Coastguard Worker self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0))) 2304*da0073e9SAndroid Build Coastguard Worker 2305*da0073e9SAndroid Build Coastguard Worker def test_sym_sizes_strides_slow_path(self): 2306*da0073e9SAndroid Build Coastguard Worker class TestTensor(torch.Tensor): 2307*da0073e9SAndroid Build Coastguard Worker @staticmethod 2308*da0073e9SAndroid Build Coastguard Worker def __new__(cls, *args, **kwargs): 2309*da0073e9SAndroid Build Coastguard Worker r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] 2310*da0073e9SAndroid Build Coastguard Worker cls, (0,), dispatch_sizes_strides_policy="sizes" 2311*da0073e9SAndroid Build Coastguard Worker ) 2312*da0073e9SAndroid Build Coastguard Worker return r 2313*da0073e9SAndroid Build Coastguard Worker 2314*da0073e9SAndroid Build Coastguard Worker @classmethod 2315*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args=(), kwargs=None): 2316*da0073e9SAndroid Build Coastguard Worker if func in ( 2317*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sym_size.default, 2318*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.sym_stride.default, 2319*da0073e9SAndroid Build Coastguard Worker ): 2320*da0073e9SAndroid Build Coastguard Worker from torch._dynamo.source import ConstantSource 2321*da0073e9SAndroid Build Coastguard Worker from torch.fx.experimental.symbolic_shapes import ( 2322*da0073e9SAndroid Build Coastguard Worker DimDynamic, 2323*da0073e9SAndroid Build Coastguard Worker ShapeEnv, 2324*da0073e9SAndroid Build Coastguard Worker ) 2325*da0073e9SAndroid Build Coastguard Worker 2326*da0073e9SAndroid Build Coastguard Worker shape_env = ShapeEnv() 2327*da0073e9SAndroid Build Coastguard Worker si = shape_env.create_symintnode( 2328*da0073e9SAndroid Build Coastguard Worker shape_env.create_symbol( 2329*da0073e9SAndroid Build Coastguard Worker 123, 2330*da0073e9SAndroid Build Coastguard Worker source=ConstantSource("abc"), 2331*da0073e9SAndroid Build Coastguard Worker dynamic_dim=DimDynamic.DUCK, 2332*da0073e9SAndroid Build Coastguard Worker constraint_dim=None, 2333*da0073e9SAndroid Build Coastguard Worker ), 2334*da0073e9SAndroid Build Coastguard Worker hint=123, 2335*da0073e9SAndroid Build Coastguard Worker ) 2336*da0073e9SAndroid Build Coastguard Worker return (si,) 2337*da0073e9SAndroid Build Coastguard Worker 2338*da0073e9SAndroid Build Coastguard Worker t = TestTensor() 2339*da0073e9SAndroid Build Coastguard Worker si = t.size()[0] 2340*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(si, torch.SymInt) 2341*da0073e9SAndroid Build Coastguard Worker si = t.stride()[0] 2342*da0073e9SAndroid Build Coastguard Worker self.assertIsInstance(si, torch.SymInt) 2343*da0073e9SAndroid Build Coastguard Worker 2344*da0073e9SAndroid Build Coastguard Worker def test_strides_slow_path(self): 2345*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True, False]: 2346*da0073e9SAndroid Build Coastguard Worker 2347*da0073e9SAndroid Build Coastguard Worker class StridesNotImplemented(torch.Tensor): 2348*da0073e9SAndroid Build Coastguard Worker @staticmethod 2349*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2350*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2351*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2352*da0073e9SAndroid Build Coastguard Worker ) 2353*da0073e9SAndroid Build Coastguard Worker 2354*da0073e9SAndroid Build Coastguard Worker @classmethod 2355*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2356*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2357*da0073e9SAndroid Build Coastguard Worker 2358*da0073e9SAndroid Build Coastguard Worker class StridesCustomReturn(torch.Tensor): 2359*da0073e9SAndroid Build Coastguard Worker @staticmethod 2360*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2361*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2362*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2363*da0073e9SAndroid Build Coastguard Worker ) 2364*da0073e9SAndroid Build Coastguard Worker 2365*da0073e9SAndroid Build Coastguard Worker @classmethod 2366*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2367*da0073e9SAndroid Build Coastguard Worker if func == torch.ops.aten.sym_stride.default: 2368*da0073e9SAndroid Build Coastguard Worker return (4, 2) 2369*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2370*da0073e9SAndroid Build Coastguard Worker 2371*da0073e9SAndroid Build Coastguard Worker class StridesDefaultReturn(torch.Tensor): 2372*da0073e9SAndroid Build Coastguard Worker @staticmethod 2373*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2374*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2375*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="strides" 2376*da0073e9SAndroid Build Coastguard Worker ) 2377*da0073e9SAndroid Build Coastguard Worker 2378*da0073e9SAndroid Build Coastguard Worker @classmethod 2379*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2380*da0073e9SAndroid Build Coastguard Worker if func == torch.ops.aten.sym_stride.default: 2381*da0073e9SAndroid Build Coastguard Worker return None 2382*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2383*da0073e9SAndroid Build Coastguard Worker 2384*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'" 2385*da0073e9SAndroid Build Coastguard Worker e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass) 2386*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2387*da0073e9SAndroid Build Coastguard Worker e.stride() 2388*da0073e9SAndroid Build Coastguard Worker 2389*da0073e9SAndroid Build Coastguard Worker e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2390*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.stride(), (4, 2)) 2391*da0073e9SAndroid Build Coastguard Worker 2392*da0073e9SAndroid Build Coastguard Worker e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass) 2393*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.stride(), (2, 1)) 2394*da0073e9SAndroid Build Coastguard Worker 2395*da0073e9SAndroid Build Coastguard Worker def test_sizes_slow_path(self): 2396*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True, False]: 2397*da0073e9SAndroid Build Coastguard Worker data = torch.randn(6, 2) 2398*da0073e9SAndroid Build Coastguard Worker 2399*da0073e9SAndroid Build Coastguard Worker class SizesNotImplemented(torch.Tensor): 2400*da0073e9SAndroid Build Coastguard Worker @staticmethod 2401*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2402*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2403*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2404*da0073e9SAndroid Build Coastguard Worker ) 2405*da0073e9SAndroid Build Coastguard Worker 2406*da0073e9SAndroid Build Coastguard Worker @classmethod 2407*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2408*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.dim: 2409*da0073e9SAndroid Build Coastguard Worker return data.dim() 2410*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2411*da0073e9SAndroid Build Coastguard Worker 2412*da0073e9SAndroid Build Coastguard Worker class SizesCustomReturn(torch.Tensor): 2413*da0073e9SAndroid Build Coastguard Worker @staticmethod 2414*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2415*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2416*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2417*da0073e9SAndroid Build Coastguard Worker ) 2418*da0073e9SAndroid Build Coastguard Worker 2419*da0073e9SAndroid Build Coastguard Worker @classmethod 2420*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2421*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.dim: 2422*da0073e9SAndroid Build Coastguard Worker return data.dim() 2423*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.sym_size: 2424*da0073e9SAndroid Build Coastguard Worker return (5, 3) 2425*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2426*da0073e9SAndroid Build Coastguard Worker 2427*da0073e9SAndroid Build Coastguard Worker class SizesDefaultReturn(torch.Tensor): 2428*da0073e9SAndroid Build Coastguard Worker @staticmethod 2429*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2430*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2431*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2432*da0073e9SAndroid Build Coastguard Worker ) 2433*da0073e9SAndroid Build Coastguard Worker 2434*da0073e9SAndroid Build Coastguard Worker @classmethod 2435*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2436*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.dim: 2437*da0073e9SAndroid Build Coastguard Worker return data.dim() 2438*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.sym_size: 2439*da0073e9SAndroid Build Coastguard Worker return None 2440*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2441*da0073e9SAndroid Build Coastguard Worker 2442*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'" 2443*da0073e9SAndroid Build Coastguard Worker e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass) 2444*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2445*da0073e9SAndroid Build Coastguard Worker e.size() 2446*da0073e9SAndroid Build Coastguard Worker 2447*da0073e9SAndroid Build Coastguard Worker e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2448*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.size(), (5, 3)) 2449*da0073e9SAndroid Build Coastguard Worker 2450*da0073e9SAndroid Build Coastguard Worker e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) 2451*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.size(), (4, 2)) 2452*da0073e9SAndroid Build Coastguard Worker 2453*da0073e9SAndroid Build Coastguard Worker def test_custom_size_policy_dynamic_shapes(self): 2454*da0073e9SAndroid Build Coastguard Worker data = torch.randn(6, 2) 2455*da0073e9SAndroid Build Coastguard Worker 2456*da0073e9SAndroid Build Coastguard Worker class CustomSizeDynamicShapesTensor(torch.Tensor): 2457*da0073e9SAndroid Build Coastguard Worker @staticmethod 2458*da0073e9SAndroid Build Coastguard Worker def __new__(cls, inner): 2459*da0073e9SAndroid Build Coastguard Worker return torch.Tensor._make_wrapper_subclass( 2460*da0073e9SAndroid Build Coastguard Worker # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great. 2461*da0073e9SAndroid Build Coastguard Worker # Calling the overload that has kwargs causes us to go down the first overload path, 2462*da0073e9SAndroid Build Coastguard Worker # which will **always** specialize sizes. 2463*da0073e9SAndroid Build Coastguard Worker # We should probably eventually fix this so that the first overload can just handle dynamic shapes. 2464*da0073e9SAndroid Build Coastguard Worker cls, 2465*da0073e9SAndroid Build Coastguard Worker inner.size(), 2466*da0073e9SAndroid Build Coastguard Worker inner.stride(), 2467*da0073e9SAndroid Build Coastguard Worker None, 2468*da0073e9SAndroid Build Coastguard Worker None, 2469*da0073e9SAndroid Build Coastguard Worker inner.dtype, 2470*da0073e9SAndroid Build Coastguard Worker inner.layout, 2471*da0073e9SAndroid Build Coastguard Worker inner.device, 2472*da0073e9SAndroid Build Coastguard Worker False, 2473*da0073e9SAndroid Build Coastguard Worker inner.requires_grad, 2474*da0073e9SAndroid Build Coastguard Worker "sizes", 2475*da0073e9SAndroid Build Coastguard Worker ) 2476*da0073e9SAndroid Build Coastguard Worker 2477*da0073e9SAndroid Build Coastguard Worker def __init__(self, inner): 2478*da0073e9SAndroid Build Coastguard Worker self.inner = inner 2479*da0073e9SAndroid Build Coastguard Worker 2480*da0073e9SAndroid Build Coastguard Worker @classmethod 2481*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2482*da0073e9SAndroid Build Coastguard Worker if func == torch.ops.aten.sym_size.default: 2483*da0073e9SAndroid Build Coastguard Worker return args[0].inner.shape 2484*da0073e9SAndroid Build Coastguard Worker if func == torch.ops.aten.sym_stride.default: 2485*da0073e9SAndroid Build Coastguard Worker return args[0].inner.shape 2486*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2487*da0073e9SAndroid Build Coastguard Worker 2488*da0073e9SAndroid Build Coastguard Worker x = torch.ones(2, 2) 2489*da0073e9SAndroid Build Coastguard Worker 2490*da0073e9SAndroid Build Coastguard Worker def trace_fn(x): 2491*da0073e9SAndroid Build Coastguard Worker x_wrapper = CustomSizeDynamicShapesTensor(x) 2492*da0073e9SAndroid Build Coastguard Worker return x_wrapper.size(), x_wrapper.stride() 2493*da0073e9SAndroid Build Coastguard Worker 2494*da0073e9SAndroid Build Coastguard Worker fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x) 2495*da0073e9SAndroid Build Coastguard Worker self.assertExpectedInline( 2496*da0073e9SAndroid Build Coastguard Worker fx_g.code.strip(), 2497*da0073e9SAndroid Build Coastguard Worker """\ 2498*da0073e9SAndroid Build Coastguard Workerdef forward(self, x_1): 2499*da0073e9SAndroid Build Coastguard Worker sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) 2500*da0073e9SAndroid Build Coastguard Worker sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None 2501*da0073e9SAndroid Build Coastguard Worker return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""", 2502*da0073e9SAndroid Build Coastguard Worker ) 2503*da0073e9SAndroid Build Coastguard Worker 2504*da0073e9SAndroid Build Coastguard Worker def test_data_ptr_respects_numel_slow_path(self): 2505*da0073e9SAndroid Build Coastguard Worker data = torch.randn(6, 2) 2506*da0073e9SAndroid Build Coastguard Worker 2507*da0073e9SAndroid Build Coastguard Worker class NumelDefaultReturn(torch.Tensor): 2508*da0073e9SAndroid Build Coastguard Worker @staticmethod 2509*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2510*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2511*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_sizes_strides_policy="sizes" 2512*da0073e9SAndroid Build Coastguard Worker ) 2513*da0073e9SAndroid Build Coastguard Worker 2514*da0073e9SAndroid Build Coastguard Worker @classmethod 2515*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2516*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.dim: 2517*da0073e9SAndroid Build Coastguard Worker return data.dim() 2518*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.aten.numel: 2519*da0073e9SAndroid Build Coastguard Worker numel_called[0] = True 2520*da0073e9SAndroid Build Coastguard Worker return None 2521*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2522*da0073e9SAndroid Build Coastguard Worker 2523*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in (False, True): 2524*da0073e9SAndroid Build Coastguard Worker numel_called = [False] 2525*da0073e9SAndroid Build Coastguard Worker e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass) 2526*da0073e9SAndroid Build Coastguard Worker e.data_ptr() 2527*da0073e9SAndroid Build Coastguard Worker self.assertTrue(numel_called[0]) 2528*da0073e9SAndroid Build Coastguard Worker 2529*da0073e9SAndroid Build Coastguard Worker def test_layout_slow_path(self): 2530*da0073e9SAndroid Build Coastguard Worker for use_wrapper_subclass in [True, False]: 2531*da0073e9SAndroid Build Coastguard Worker data = torch.randn(6, 2) 2532*da0073e9SAndroid Build Coastguard Worker 2533*da0073e9SAndroid Build Coastguard Worker class LayoutNotImplemented(torch.Tensor): 2534*da0073e9SAndroid Build Coastguard Worker @staticmethod 2535*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2536*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2537*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_layout=True 2538*da0073e9SAndroid Build Coastguard Worker ) 2539*da0073e9SAndroid Build Coastguard Worker 2540*da0073e9SAndroid Build Coastguard Worker @classmethod 2541*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2542*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2543*da0073e9SAndroid Build Coastguard Worker 2544*da0073e9SAndroid Build Coastguard Worker class LayoutCustomReturn(torch.Tensor): 2545*da0073e9SAndroid Build Coastguard Worker @staticmethod 2546*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2547*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2548*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_layout=True 2549*da0073e9SAndroid Build Coastguard Worker ) 2550*da0073e9SAndroid Build Coastguard Worker 2551*da0073e9SAndroid Build Coastguard Worker @classmethod 2552*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2553*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.prim.layout: 2554*da0073e9SAndroid Build Coastguard Worker return torch.sparse_csr 2555*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2556*da0073e9SAndroid Build Coastguard Worker 2557*da0073e9SAndroid Build Coastguard Worker class LayoutDefaultReturn(torch.Tensor): 2558*da0073e9SAndroid Build Coastguard Worker @staticmethod 2559*da0073e9SAndroid Build Coastguard Worker def __new__(cls, data, wrapper): 2560*da0073e9SAndroid Build Coastguard Worker return TestPythonDispatch.subclass_helper( 2561*da0073e9SAndroid Build Coastguard Worker cls, data, wrapper, dispatch_layout=True 2562*da0073e9SAndroid Build Coastguard Worker ) 2563*da0073e9SAndroid Build Coastguard Worker 2564*da0073e9SAndroid Build Coastguard Worker @classmethod 2565*da0073e9SAndroid Build Coastguard Worker def __torch_dispatch__(cls, func, types, args, kwargs): 2566*da0073e9SAndroid Build Coastguard Worker if func.overloadpacket == torch.ops.prim.layout: 2567*da0073e9SAndroid Build Coastguard Worker return data.layout 2568*da0073e9SAndroid Build Coastguard Worker return NotImplemented 2569*da0073e9SAndroid Build Coastguard Worker 2570*da0073e9SAndroid Build Coastguard Worker err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'" 2571*da0073e9SAndroid Build Coastguard Worker e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass) 2572*da0073e9SAndroid Build Coastguard Worker with self.assertRaisesRegex(TypeError, err_msg): 2573*da0073e9SAndroid Build Coastguard Worker e.layout 2574*da0073e9SAndroid Build Coastguard Worker 2575*da0073e9SAndroid Build Coastguard Worker e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass) 2576*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.layout, torch.sparse_csr) 2577*da0073e9SAndroid Build Coastguard Worker 2578*da0073e9SAndroid Build Coastguard Worker e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass) 2579*da0073e9SAndroid Build Coastguard Worker self.assertEqual(e.layout, torch.strided) 2580*da0073e9SAndroid Build Coastguard Worker 2581*da0073e9SAndroid Build Coastguard Worker 2582*da0073e9SAndroid Build Coastguard Workerclass TestPythonDispatcher(TestCase): 2583*da0073e9SAndroid Build Coastguard Worker def test_basic(self): 2584*da0073e9SAndroid Build Coastguard Worker x = torch.randn(2, requires_grad=True) 2585*da0073e9SAndroid Build Coastguard Worker r = torch._C._EnablePythonDispatcher() 2586*da0073e9SAndroid Build Coastguard Worker torch.add(x, x) 2587*da0073e9SAndroid Build Coastguard Worker 2588*da0073e9SAndroid Build Coastguard Worker def test_lstsq(self): 2589*da0073e9SAndroid Build Coastguard Worker a = torch.randn(4, 3) 2590*da0073e9SAndroid Build Coastguard Worker b = torch.rand(4, 3) 2591*da0073e9SAndroid Build Coastguard Worker expected_shape = torch.linalg.lstsq(a, b).solution.shape 2592*da0073e9SAndroid Build Coastguard Worker r = torch._C._EnablePythonDispatcher() 2593*da0073e9SAndroid Build Coastguard Worker python_disp_shape = torch.linalg.lstsq(a, b).solution.shape 2594*da0073e9SAndroid Build Coastguard Worker self.assertEqual(expected_shape, python_disp_shape) 2595*da0073e9SAndroid Build Coastguard Worker 2596*da0073e9SAndroid Build Coastguard Worker 2597*da0073e9SAndroid Build Coastguard Workerclass TestWrapperSubclassAliasing(TestCase): 2598*da0073e9SAndroid Build Coastguard Worker def _test_wrapper_subclass_aliasing(self, op, args, kwargs): 2599*da0073e9SAndroid Build Coastguard Worker def to_subclass(t: torch.Tensor): 2600*da0073e9SAndroid Build Coastguard Worker return TwoTensor(t, t.clone()) 2601*da0073e9SAndroid Build Coastguard Worker 2602*da0073e9SAndroid Build Coastguard Worker result_ref = op(*args, **kwargs) 2603*da0073e9SAndroid Build Coastguard Worker 2604*da0073e9SAndroid Build Coastguard Worker args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args) 2605*da0073e9SAndroid Build Coastguard Worker kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs) 2606*da0073e9SAndroid Build Coastguard Worker 2607*da0073e9SAndroid Build Coastguard Worker result_test = op(*args_subclass, **kwargs_subclass) 2608*da0073e9SAndroid Build Coastguard Worker 2609*da0073e9SAndroid Build Coastguard Worker args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs) 2610*da0073e9SAndroid Build Coastguard Worker args_ref_flat_tensors = [ 2611*da0073e9SAndroid Build Coastguard Worker x for x in args_ref_flat if isinstance(x, torch.Tensor) 2612*da0073e9SAndroid Build Coastguard Worker ] 2613*da0073e9SAndroid Build Coastguard Worker 2614*da0073e9SAndroid Build Coastguard Worker args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass)) 2615*da0073e9SAndroid Build Coastguard Worker args_test_flat_tensors = [ 2616*da0073e9SAndroid Build Coastguard Worker x for x in args_test_flat if isinstance(x, torch.Tensor) 2617*da0073e9SAndroid Build Coastguard Worker ] 2618*da0073e9SAndroid Build Coastguard Worker 2619*da0073e9SAndroid Build Coastguard Worker result_ref_flat = pytree.tree_leaves(result_ref) 2620*da0073e9SAndroid Build Coastguard Worker result_ref_flat_tensors = [ 2621*da0073e9SAndroid Build Coastguard Worker x for x in result_ref_flat if isinstance(x, torch.Tensor) 2622*da0073e9SAndroid Build Coastguard Worker ] 2623*da0073e9SAndroid Build Coastguard Worker 2624*da0073e9SAndroid Build Coastguard Worker result_test_flat = pytree.tree_leaves(result_test) 2625*da0073e9SAndroid Build Coastguard Worker result_test_flat_tensors = [ 2626*da0073e9SAndroid Build Coastguard Worker x for x in result_test_flat if isinstance(x, torch.Tensor) 2627*da0073e9SAndroid Build Coastguard Worker ] 2628*da0073e9SAndroid Build Coastguard Worker 2629*da0073e9SAndroid Build Coastguard Worker for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors): 2630*da0073e9SAndroid Build Coastguard Worker for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors): 2631*da0073e9SAndroid Build Coastguard Worker out_is_inpt = o_ref is a_ref 2632*da0073e9SAndroid Build Coastguard Worker if out_is_inpt: 2633*da0073e9SAndroid Build Coastguard Worker self.assertTrue(o_test is a_test) 2634*da0073e9SAndroid Build Coastguard Worker 2635*da0073e9SAndroid Build Coastguard Worker out_aliases_inpt = StorageWeakRef( 2636*da0073e9SAndroid Build Coastguard Worker o_ref.untyped_storage() 2637*da0073e9SAndroid Build Coastguard Worker ) == StorageWeakRef(a_ref.untyped_storage()) 2638*da0073e9SAndroid Build Coastguard Worker if out_aliases_inpt: 2639*da0073e9SAndroid Build Coastguard Worker self.assertTrue( 2640*da0073e9SAndroid Build Coastguard Worker StorageWeakRef(o_test.untyped_storage()) 2641*da0073e9SAndroid Build Coastguard Worker == StorageWeakRef(a_test.untyped_storage()) 2642*da0073e9SAndroid Build Coastguard Worker ) 2643*da0073e9SAndroid Build Coastguard Worker else: 2644*da0073e9SAndroid Build Coastguard Worker self.assertFalse( 2645*da0073e9SAndroid Build Coastguard Worker StorageWeakRef(o_test.untyped_storage()) 2646*da0073e9SAndroid Build Coastguard Worker == StorageWeakRef(a_test.untyped_storage()) 2647*da0073e9SAndroid Build Coastguard Worker ) 2648*da0073e9SAndroid Build Coastguard Worker 2649*da0073e9SAndroid Build Coastguard Worker # This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`, 2650*da0073e9SAndroid Build Coastguard Worker # a util for wrapper subclasses to promise correct aliasing behavior. 2651*da0073e9SAndroid Build Coastguard Worker # It's probably overkill to test every OpInfo, 2652*da0073e9SAndroid Build Coastguard Worker # so I picked a sampling of ops with representative schemas. 2653*da0073e9SAndroid Build Coastguard Worker @ops( 2654*da0073e9SAndroid Build Coastguard Worker [ 2655*da0073e9SAndroid Build Coastguard Worker op 2656*da0073e9SAndroid Build Coastguard Worker for op in op_db 2657*da0073e9SAndroid Build Coastguard Worker if op.name 2658*da0073e9SAndroid Build Coastguard Worker in [ 2659*da0073e9SAndroid Build Coastguard Worker "mul", # out-of-place 2660*da0073e9SAndroid Build Coastguard Worker "cat", # out-of-place (TensorList input) 2661*da0073e9SAndroid Build Coastguard Worker "index", # out-of-place (Optional TensorList input) 2662*da0073e9SAndroid Build Coastguard Worker "mul_", # inplace 2663*da0073e9SAndroid Build Coastguard Worker "view", # view 2664*da0073e9SAndroid Build Coastguard Worker "t_", # inplace-view 2665*da0073e9SAndroid Build Coastguard Worker "split", # view (multi-return) 2666*da0073e9SAndroid Build Coastguard Worker "native_batch_norm", # mutable op (returns outputs and mutates some inputs) 2667*da0073e9SAndroid Build Coastguard Worker ] 2668*da0073e9SAndroid Build Coastguard Worker ], 2669*da0073e9SAndroid Build Coastguard Worker allowed_dtypes=(torch.float,), 2670*da0073e9SAndroid Build Coastguard Worker ) 2671*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_aliasing(self, device, dtype, op): 2672*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 2673*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, samples) 2674*da0073e9SAndroid Build Coastguard Worker args = (sample.input, *sample.args) 2675*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 2676*da0073e9SAndroid Build Coastguard Worker self._test_wrapper_subclass_aliasing(op, args, kwargs) 2677*da0073e9SAndroid Build Coastguard Worker 2678*da0073e9SAndroid Build Coastguard Worker @ops(custom_op_db, allowed_dtypes=(torch.float,)) 2679*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_aliasing_custom(self, device, dtype, op): 2680*da0073e9SAndroid Build Coastguard Worker samples = op.sample_inputs(device, dtype) 2681*da0073e9SAndroid Build Coastguard Worker sample = first_sample(self, samples) 2682*da0073e9SAndroid Build Coastguard Worker args = (sample.input, *sample.args) 2683*da0073e9SAndroid Build Coastguard Worker kwargs = sample.kwargs 2684*da0073e9SAndroid Build Coastguard Worker self._test_wrapper_subclass_aliasing(op, args, kwargs) 2685*da0073e9SAndroid Build Coastguard Worker 2686*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_aliasing_conv2d(self, device): 2687*da0073e9SAndroid Build Coastguard Worker args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4)) 2688*da0073e9SAndroid Build Coastguard Worker kwargs = {} 2689*da0073e9SAndroid Build Coastguard Worker # conv2d has a default arg 'int[2] strides=0', 2690*da0073e9SAndroid Build Coastguard Worker # which torchscript expands into 'int[2] strides=[0, 0]' 2691*da0073e9SAndroid Build Coastguard Worker # Make sure that _return_and_correct_aliasing can handle this case 2692*da0073e9SAndroid Build Coastguard Worker # (I'm using inference_mode to make sure conv2d doesn't decompose and goes to torch_dispatch) 2693*da0073e9SAndroid Build Coastguard Worker with torch.inference_mode(): 2694*da0073e9SAndroid Build Coastguard Worker self._test_wrapper_subclass_aliasing( 2695*da0073e9SAndroid Build Coastguard Worker torch.ops.aten.conv2d.default, args, kwargs 2696*da0073e9SAndroid Build Coastguard Worker ) 2697*da0073e9SAndroid Build Coastguard Worker 2698*da0073e9SAndroid Build Coastguard Worker def test_wrapper_subclass_aliasing_out_op(self, device): 2699*da0073e9SAndroid Build Coastguard Worker # Make sure that _return_and_correct_aliasing can handle kwargs w mutable tensors 2700*da0073e9SAndroid Build Coastguard Worker args = (torch.ones(4), torch.ones(4)) 2701*da0073e9SAndroid Build Coastguard Worker kwargs = {"out": torch.empty(4)} 2702*da0073e9SAndroid Build Coastguard Worker self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs) 2703*da0073e9SAndroid Build Coastguard Worker 2704*da0073e9SAndroid Build Coastguard Worker 2705*da0073e9SAndroid Build Coastguard Workerinstantiate_device_type_tests(TestWrapperSubclassAliasing, globals()) 2706*da0073e9SAndroid Build Coastguard Worker 2707*da0073e9SAndroid Build Coastguard Workerif __name__ == "__main__": 2708*da0073e9SAndroid Build Coastguard Worker run_tests() 2709