# Owner(s): ["module: __torch_dispatch__"]

import logging
import sys
import tempfile
import unittest
from copy import deepcopy

import torch
import torch._dynamo
from torch import SymInt
from torch._C import DispatchKey, DispatchKeySet
from torch._custom_op.functional import register_functional_op
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.cuda.jiterator import _create_jit_fn
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.library import _scoped_library, fallthrough_kernel, impl, Library
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    first_sample,
    IS_WINDOWS,
    run_tests,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.logging_tensor import (
    capture_logs,
    capture_logs_with_logging_tensor_mode,
    log_input,
    LoggingTensor,
    LoggingTensorMode,
    LoggingTensorReentrant,
)
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils import _pytree as pytree
from torch.utils._mode_utils import all_same_mode, no_dispatch
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _get_current_dispatch_mode_stack,
    is_in_torch_dispatch_mode,
    TorchDispatchMode,
)
from torch.utils._pytree import tree_map, tree_map_only


# used as DataLoader collate_fn below; named here to avoid trying to pickle a lambda
def _identity(x):
    return x


class TestDispatcherPythonBindings(TestCase):
    def test_call_boxed(self) -> None:
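        # _dispatch_find_schema_or_throw resolves "aten::sin" (empty overload
        # name) to its operator handle; _dispatch_call_boxed then invokes that
        # handle through the dispatcher's boxed (type-erased) calling
        # convention rather than the usual unboxed path.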
        sin = torch._C._dispatch_find_schema_or_throw("aten::sin", "")
        x = torch.randn(3)
        y = torch._C._dispatch_call_boxed(sin, x)
        self.assertEqual(y, x.sin())


class TestPythonRegistration(TestCase):
    test_ns = "_test_python_registration"

    def tearDown(self):
        if hasattr(torch.ops, self.test_ns):
            del torch.ops._test_python_registration

    def test_fallback(self) -> None:
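        # TESTING_ONLY_GenericMode is a testing-only dispatch key with no
        # kernels of its own. Forcing it into the thread-local include set
        # routes every dispatcher call to the fallback we register for it
        # below; adding it back to the exclude set (inside the fallback)
        # temporarily restores normal dispatch.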
        test_key = "TESTING_ONLY_GenericMode"
        test_keyset = torch._C.DispatchKeySet(test_key)
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            expected_op = None
            expected_args = None
            expected_kwargs = None
            # Use this out shape to make sure the result from our fallback
            # is what is returned to the user
            out_shape = None

            def my_fallback(op, *args, **kwargs):
                # Disable our handler during checks and generating the output
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    self.assertIs(op, expected_op)
                    self.assertEqual(args, expected_args)
                    self.assertEqual(kwargs, expected_kwargs)
                    # Return something specific
                    return torch.empty(out_shape)

            my_lib.fallback(my_fallback, test_key)

            a, b = torch.rand(2), torch.rand(2)

            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                # Check a factory function
                expected_op = torch.ops.aten.empty.memory_format
                expected_args = ((2, 2),)
                # Extra kwargs to bypass issues with default args in factory functions
                expected_kwargs = {
                    "dtype": torch.float64,
                    "pin_memory": False,
                    "device": torch.device("cpu"),
                }
                out_shape = (3,)
                out = torch.empty(*expected_args, **expected_kwargs)
                self.assertEqual(out.size(), out_shape)

                # Check a regular function
                expected_op = torch.ops.aten.add.Tensor
                expected_args = (a, b)
                expected_kwargs = {}
                out_shape = (4,)
                out = a + b
                self.assertEqual(out.size(), out_shape)

    def test_fallback_keyset(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            first_called = False
            second_called = False

            def first_fallback(keyset, op, *args, **kwargs):
                nonlocal first_called
                if second_called:
                    # Recursive call
                    first_called = True
                    with torch._C._ForceDispatchKeyGuard(
                        include_to_set, exclude_to_set | test_keyset
                    ):
                        return op(*args, **kwargs)
                else:
                    # Redispatch down
                    keyset = keyset.remove(test_key_first)
                    return op.redispatch(keyset, *args, **kwargs)

            def second_fallback(op, *args, **kwargs):
                nonlocal second_called
                # Set to avoid infinite recursion
                second_called = True
                # New dispatcher call should hit the first callback again
                self.assertFalse(first_called)
                a, b = args
                # Do a subtraction here instead of an add!
                c = a - b
                self.assertTrue(first_called)
                return c

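            # with_keyset=True makes the dispatcher pass the currently active
            # DispatchKeySet as the first argument, which lets first_fallback
            # remove its own key and redispatch to the keys below it.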
            my_lib.fallback(first_fallback, test_key_first, with_keyset=True)
            my_lib.fallback(second_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a - b)
            self.assertTrue(first_called)
            self.assertTrue(second_called)

    def test_fallback_fallthrough(self) -> None:
        test_key_first = "TESTING_ONLY_GenericMode"
        test_key_second = "TESTING_ONLY_GenericWrapper"
        test_keyset = torch._C.DispatchKeySet(test_key_first) | torch._C.DispatchKeySet(
            test_key_second
        )
        include_to_set = torch._C._dispatch_tls_local_include_set() | test_keyset
        exclude_to_set = torch._C._dispatch_tls_local_exclude_set()

        with _scoped_library("_", "IMPL") as my_lib:
            is_called = False

            def my_fallback(op, *args, **kwargs):
                nonlocal is_called
                is_called = True
                with torch._C._ForceDispatchKeyGuard(
                    include_to_set, exclude_to_set | test_keyset
                ):
                    return op(*args, **kwargs)

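            # fallthrough_kernel registers a fallthrough rather than a kernel:
            # the dispatcher skips test_key_first entirely and proceeds to the
            # next key in the set, which here is test_key_second.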
            my_lib.fallback(torch.library.fallthrough_kernel, test_key_first)
            my_lib.fallback(my_fallback, test_key_second)

            a, b = torch.rand(2), torch.rand(2)
            with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
                c = a + b

            self.assertEqual(c, a + b)
            self.assertTrue(is_called)

    def test_override_aten_ops_with_multiple_libraries(self) -> None:
        x = torch.tensor([1, 2])
        with _scoped_library("aten", "IMPL") as my_lib2:
            with _scoped_library("aten", "IMPL") as my_lib1:
                # Example 1
                def my_neg(*args, **kwargs):
                    return args[0]._neg_view()

                # Now we are secretly making the operator a view op so autograd needs to know how
                # to handle it
                my_lib1.impl("neg", my_neg, "AutogradCPU")

                self.assertTrue(torch.neg(x).is_neg())

                # RuntimeError: impl("aten::neg", ...):
                # Explicitly provided namespace (aten) in operator name does not match ...
                with self.assertRaisesRegex(
                    RuntimeError, "operator name does not match namespace"
                ):
                    with _scoped_library("foo", "DEF") as my_lib3:
                        my_lib3.define("neg(Tensor self) -> Tensor")
                        my_lib3.impl(torch.ops.aten.neg.default, my_neg, "AutogradCPU")

                # Example 2
                def my_mul(*args, **kwargs):
                    return torch.zeros_like(args[0])

                # torch.ops.aten.mul.Tensor
                my_lib2.impl("aten::mul.Tensor", my_mul, "ZeroTensor")

                y = torch._efficientzerotensor(2)
                self.assertFalse(torch.mul(x, y)._is_zerotensor())

                # Assert that a user can't override the behavior of a
                # (ns, op, dispatch_key) combination if someone has already
                # overridden it before them
                with self.assertRaisesRegex(
                    RuntimeError, "already a kernel registered from python"
                ):
                    my_lib2.impl(torch.ops.aten.mul.Tensor, my_mul, "ZeroTensor")

            # Validate that lib2 is not affected by removing lib1
            self.assertFalse(torch.mul(x, y)._is_zerotensor())

        # Validate that the old behavior is restored for neg and mul
        self.assertFalse(torch.neg(x).is_neg())
        self.assertTrue(torch.mul(x, y)._is_zerotensor())

    def test_error_if_fn_not_callable(self):
        with self.assertRaisesRegex(
            TypeError, "Input function is required to be a callable"
        ):
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl(torch.ops.aten.neg.default, [], "AutogradCPU")

    def test_finalizer(self):
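        # Library installs a finalizer that deregisters its impls when the
        # object dies; the refcount accounting below checks that the finalizer
        # itself doesn't keep `lib` or its bookkeeping tables alive.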
        impls_refcnt = sys.getrefcount(torch.library._impls)
        lib = Library(self.test_ns, "FRAGMENT")  # noqa: TOR901
        lib.define("foo123(Tensor x) -> Tensor")

        # 1 for `lib`, 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib), 2)
        # We gained an additional reference that gets cleared when the finalizer runs
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt + 1)
        # 1 for `lib`
        # 1 for the finalizer
        # 1 for sys.getrefcount
        self.assertEqual(sys.getrefcount(lib._op_impls), 3)

        def foo123(x):
            pass

        lib.impl(f"{self.test_ns}::foo123", foo123, "CPU")
        key = f"{self.test_ns}/foo123/CPU"
        self.assertTrue(key in torch.library._impls)

        saved_op_impls = lib._op_impls

        # If the following assertion passes, `del lib` below is guaranteed to
        # drop the last user reference and run the finalizer
        self.assertEqual(sys.getrefcount(lib), 2)
        del lib

        # 1 for saved_op_impls
        # 1 for sys.getrefcount
        # This function should be the last user of lib._op_impls:
        # - lib should not have a reference anymore (it was del'ed)
        # - lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(saved_op_impls), 2)

        self.assertTrue(key not in torch.library._impls)

        # lib's finalizer should not have a reference anymore
        self.assertEqual(sys.getrefcount(torch.library._impls), impls_refcnt)

    def test_override_cpu_sum(self) -> None:
        # Example 1
        run = [False]

        def my_sum(*args, **kwargs):
            run[0] = True
            return args[0].clone()

        with _scoped_library("aten", "IMPL") as my_lib1:
            my_lib1.impl("aten::sum", my_sum, "CPU")
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)
            self.assertTrue(run[0])
        # Validate that the old behavior is restored for sum
        self.assertEqual(torch.sum(x), torch.tensor(3))

    def test_override_cuda_with_jiterator(self) -> None:
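        # torch.cuda.jiterator._create_jit_fn JIT-compiles an elementwise CUDA
        # kernel from a C++ code string; each helper below registers such a
        # kernel over an existing aten op and checks that deregistering the
        # library restores the stock behavior.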
        def override_where_cuda() -> None:
            # Example 1: Invert the behavior of where's condition input
            not_where_code_string = """
            template <typename T> T inverted_where(bool cond, T a, T b){
                return !cond ? a : b;
            }
            """
            jitted_where = _create_jit_fn(not_where_code_string)

            CALLED = [False]

            def inverted_where(*args, **kwargs):
                CALLED[0] = True
                return jitted_where(*args, **kwargs)

            # Override where's CUDA kernel with the Jiterator-generated kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::where.self", inverted_where, "CUDA")

                device = "cuda"
                cond = torch.tensor(
                    [True, True, False], device=device, dtype=torch.bool
                )
                x = torch.tensor([1, 2, 3], device=device)
                y = torch.tensor([-1, -2, -3], device=device)

                self.assertEqual(torch.where(cond, x, y), torch.tensor([-1, -2, 3]))
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(torch.where(cond, x, y), torch.tensor([1, 2, -3]))

        def override_gelu_cuda() -> None:
            # Example 2: Use relu to approximate gelu for faster compute
            fastest_gelu_code_string = """
            template <typename T> T fast_gelu(T a){
                return a > 0 ? a : 0;
            }
            """
            jitted_gelu = _create_jit_fn(fastest_gelu_code_string)

            CALLED = [False]

            def fast_gelu(*args, **kwargs):
                CALLED[0] = True
                return jitted_gelu(*args, **kwargs)

            # Override gelu's CUDA kernel with the Jiterator-generated relu kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::gelu", fast_gelu, "CUDA")

                x = torch.rand([3, 3], device="cuda", dtype=torch.float)
                self.assertEqual(
                    torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertNotEqual(
                torch.nn.functional.gelu(x), torch.nn.functional.relu(x)
            )

        def override_exp_cuda() -> None:
            # Example 3: Prevent exp from overflowing in float16
            clipped_exp_code_string = """
            template <typename T> T clipped_exp(T a){
                return a > T(10.0) ? T(22026.4657948) : exp(a);
            }
            """
            jitted_exp = _create_jit_fn(clipped_exp_code_string)

            CALLED = [False]

            def clipped_exp(*args, **kwargs):
                CALLED[0] = True
                return jitted_exp(*args, **kwargs)

            # Override exp's CUDA kernel with the clipped_exp kernel
            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::exp", clipped_exp, "CUDA")

                x = torch.tensor([0.0, 100.0], device="cuda", dtype=torch.float16)
                self.assertEqual(
                    torch.exp(x),
                    torch.tensor([1.0, 22026.4657948], dtype=torch.float16),
                )
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(
                torch.exp(x), torch.tensor([1.0, torch.inf], dtype=torch.float16)
            )

        def override_add_cuda() -> None:
            # Example 4: Simulate a hardware bug where the adder is always off by 1
            buggy_add_code_string = """
            template <typename T> T buggy_add(T a, T b){
                return a + b + T(1);
            }
            """
            jitted_add = _create_jit_fn(buggy_add_code_string)

            CALLED = [False]

            def buggy_add(*args, **kwargs):
                CALLED[0] = True
                return jitted_add(*args, **kwargs)

            with _scoped_library("aten", "IMPL") as my_lib:
                my_lib.impl("aten::add.Tensor", buggy_add, "CUDA")

                x_cpu = torch.rand([3, 3], device="cpu")
                y_cpu = torch.rand([3], device="cpu")

                x_cuda = x_cpu.cuda()
                y_cuda = y_cpu.cuda()

                self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu + 1)
                self.assertTrue(CALLED[0])

            # behavior restored after deregistration
            self.assertEqual(x_cuda + y_cuda, x_cpu + y_cpu)

        if torch.cuda.is_available() and not TEST_WITH_ROCM:
            override_where_cuda()
            override_gelu_cuda()
            override_exp_cuda()
            override_add_cuda()

    def test_extend_library_with_dispatch_key_arg(self):
        def my_sum(*args, **kwargs):
            return args[0].clone()

        with _scoped_library("aten", "IMPL", dispatch_key="CPU") as my_lib1:
            # RuntimeError: Explicitly provided dispatch key (Conjugate) is
            # inconsistent with the dispatch key of the enclosing TORCH_LIBRARY_IMPL block
            with self.assertRaisesRegex(
                RuntimeError, "inconsistent with the dispatch key"
            ):
                my_lib1.impl("sum", my_sum, "Conjugate")
            my_lib1.impl("aten::sum", my_sum)
            x = torch.tensor([1, 2])
            self.assertEqual(torch.sum(x), x)

    def test_create_new_library(self) -> None:
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            my_lib1.define("sum(Tensor self) -> Tensor")

            # Example 1
            @torch.library.impl(my_lib1, "sum", "CPU")
            def my_sum(*args, **kwargs):
                return args[0].clone()

            x = torch.tensor([1, 2])
            op = getattr(torch.ops, self.test_ns).sum
            self.assertEqual(op(x), x)

            with _scoped_library(self.test_ns, "IMPL") as my_lib2:
                # Example 2
                @torch.library.impl(my_lib2, op.default, "ZeroTensor")
                def my_sum_zt(*args, **kwargs):
                    if args[0]._is_zerotensor():
                        return torch._efficientzerotensor(args[0].shape)
                    else:
                        return args[0].clone()

                y = torch._efficientzerotensor(3)
                self.assertTrue(op(y)._is_zerotensor())
                self.assertEqual(op(x), x)

    def test_create_new_library_fragment_no_existing(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as my_lib:
            my_lib.define("sum2(Tensor self) -> Tensor")

            @torch.library.impl(my_lib, "sum2", "CPU")
            def my_sum(*args, **kwargs):
                return args[0]

            x = torch.tensor([1, 2])
            self.assertEqual(getattr(torch.ops, self.test_ns).sum2(x), x)

    def test_create_new_library_fragment_with_existing(self):
        with _scoped_library(self.test_ns, "DEF") as my_lib1:
            # Create a fragment
            with _scoped_library(self.test_ns, "FRAGMENT") as my_lib2:
                my_lib2.define("sum4(Tensor self) -> Tensor")

                @torch.library.impl(my_lib2, "sum4", "CPU")
                def my_sum4(*args, **kwargs):
                    return args[0]

                x = torch.tensor([1, 2])
                self.assertEqual(getattr(torch.ops, self.test_ns).sum4(x), x)

                # Create another fragment
                with _scoped_library(self.test_ns, "FRAGMENT") as my_lib3:
                    my_lib3.define("sum3(Tensor self) -> Tensor")

                    @torch.library.impl(my_lib3, "sum3", "CPU")
                    def my_sum3(*args, **kwargs):
                        return args[0]

                    x = torch.tensor([1, 2])
                    self.assertEqual(getattr(torch.ops, self.test_ns).sum3(x), x)

    @unittest.skipIf(IS_WINDOWS, "Skipped under Windows")
    def test_alias_analysis(self):
        def test_helper(alias_analysis=""):
            my_lib1 = Library(self.test_ns, "DEF")  # noqa: TOR901

            called = [0]

            @torch.library.define(
                my_lib1, "_op() -> None", alias_analysis=alias_analysis
            )
            def _op(*args, **kwargs):
                called[0] += 1

            @torch.jit.script
            def _test():
                torch.ops._test_python_registration._op()

            assert "_test_python_registration::_op" in str(_test.graph)

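        # With the default alias analysis (FROM_SCHEMA), an op with no inputs,
        # no outputs, and no declared side effects can be dead-code-eliminated
        # from the TorchScript graph, so the assertion in test_helper fails;
        # CONSERVATIVE alias analysis keeps the call in the graph.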
        with self.assertRaises(AssertionError):
            test_helper("")  # alias_analysis="FROM_SCHEMA"

        test_helper("CONSERVATIVE")

    def test_error_for_unsupported_ns_or_kind(self) -> None:
        with self.assertRaisesRegex(ValueError, "Unsupported kind"):
            my_lib1 = Library("myns", "BLA")  # noqa: TOR901

        for kind in ("DEF", "FRAGMENT"):
            with self.assertRaisesRegex(ValueError, "reserved namespace"):
                my_lib1 = Library("prim", kind)  # noqa: TOR901

    def test_returning_symint(self) -> None:
        shape_env = ShapeEnv()
        fake_tensor_mode = FakeTensorMode(shape_env=shape_env)

        ft = fake_tensor_mode.from_tensor(torch.rand(2, 3))

        s0, s1 = ft.shape
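        # s0 and s1 are SymInts backed by the fake tensor's sizes (2 and 3);
        # the custom op below should return a symbolic expression that
        # evaluates to 2*2 + 3*3 = 13.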

        with _scoped_library(self.test_ns, "DEF") as tlib:
            tlib.define("sqsum(SymInt a, SymInt b) -> SymInt")

            @impl(tlib, "sqsum", "CompositeExplicitAutograd")
            def sqsum(a: SymInt, b: SymInt):
                return a * a + b * b

            out = getattr(torch.ops, self.test_ns).sqsum.default(s0, s1)
            out_val = shape_env.evaluate_expr(out.node.expr)
        self.assertEqual(out_val, 13)

    def test_register_functional_op_error_cases(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            with self.assertRaisesRegex(TypeError, "instance of OpOverload"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs_.default)
            with self.assertRaisesRegex(RuntimeError, "Expected op to be mutable"):
                register_functional_op(lib, "abs", torch.ops.aten.abs.out)

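            # Schemas that register_functional_op does not yet support (hence
            # the NYI check below): mutable tensor lists and returns that
            # alias an input.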
            schemas = [
                "foo(Tensor x, Tensor(a!)[] y) -> ()",
                "foo(Tensor x, Tensor(a!) y, Tensor(b) z) -> Tensor(b)",
                "foo(Tensor x, Tensor(a!) y) -> (Tensor, Tensor(a))",
            ]

        for schema in schemas:
            with _scoped_library(self.test_ns, "FRAGMENT") as lib:
                lib.define(schema)
                with self.assertRaisesRegex(RuntimeError, "NYI"):
                    register_functional_op(
                        lib,
                        "foo_functional",
                        getattr(torch.ops, self.test_ns).foo.default,
                    )

    def _check_is_functional_variant(self, mutable_op, functional_op, args):
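        """Check that functional_op behaves as the functional variant of
        mutable_op: it must leave its inputs unmutated, return mutable_op's
        outputs followed by the would-be-mutated args, and be what
        torch.func.functionalize traces in place of mutable_op."""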
596*da0073e9SAndroid Build Coastguard Worker        # functional op should not mutate
597*da0073e9SAndroid Build Coastguard Worker        cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
598*da0073e9SAndroid Build Coastguard Worker        functional_result = functional_op(*cloned_args)
599*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(cloned_args, args)
600*da0073e9SAndroid Build Coastguard Worker
601*da0073e9SAndroid Build Coastguard Worker        # check functional_result includes mutable_result
602*da0073e9SAndroid Build Coastguard Worker        mutable_result = mutable_op(*cloned_args)
603*da0073e9SAndroid Build Coastguard Worker        if mutable_result is None:
604*da0073e9SAndroid Build Coastguard Worker            flat_mutable_result = []
605*da0073e9SAndroid Build Coastguard Worker        else:
606*da0073e9SAndroid Build Coastguard Worker            flat_mutable_result = pytree.tree_leaves(mutable_result)
607*da0073e9SAndroid Build Coastguard Worker        flat_functional_result = pytree.tree_leaves(functional_result)
608*da0073e9SAndroid Build Coastguard Worker        assert len(flat_functional_result) > len(flat_mutable_result)
609*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
610*da0073e9SAndroid Build Coastguard Worker            flat_functional_result[: len(flat_mutable_result)], flat_mutable_result
611*da0073e9SAndroid Build Coastguard Worker        )
612*da0073e9SAndroid Build Coastguard Worker
613*da0073e9SAndroid Build Coastguard Worker        # check rest of functional_result is the mutated args
614*da0073e9SAndroid Build Coastguard Worker        mutated_args = [
615*da0073e9SAndroid Build Coastguard Worker            maybe_mutated_arg
616*da0073e9SAndroid Build Coastguard Worker            for maybe_mutated_arg, arg in zip(cloned_args, args)
617*da0073e9SAndroid Build Coastguard Worker            if not (
618*da0073e9SAndroid Build Coastguard Worker                maybe_mutated_arg is not None
619*da0073e9SAndroid Build Coastguard Worker                and arg is not None
620*da0073e9SAndroid Build Coastguard Worker                and torch.allclose(maybe_mutated_arg, arg)
621*da0073e9SAndroid Build Coastguard Worker            )
622*da0073e9SAndroid Build Coastguard Worker        ]
623*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(
624*da0073e9SAndroid Build Coastguard Worker            flat_functional_result[len(flat_mutable_result) :], mutated_args
625*da0073e9SAndroid Build Coastguard Worker        )
626*da0073e9SAndroid Build Coastguard Worker
627*da0073e9SAndroid Build Coastguard Worker        # check that functionalization kernel was indeed registered
628*da0073e9SAndroid Build Coastguard Worker        def fn(*args):
629*da0073e9SAndroid Build Coastguard Worker            cloned_args = pytree.tree_map_only(torch.Tensor, torch.clone, args)
630*da0073e9SAndroid Build Coastguard Worker            mutable_op(*cloned_args)
631*da0073e9SAndroid Build Coastguard Worker            return cloned_args
632*da0073e9SAndroid Build Coastguard Worker
633*da0073e9SAndroid Build Coastguard Worker        gm = make_fx(torch.func.functionalize(fn))(*args)
634*da0073e9SAndroid Build Coastguard Worker        has_functional_op = False
635*da0073e9SAndroid Build Coastguard Worker        for node in gm.graph.nodes:
636*da0073e9SAndroid Build Coastguard Worker            self.assertFalse(node.target is mutable_op)
637*da0073e9SAndroid Build Coastguard Worker            if node.target is functional_op:
638*da0073e9SAndroid Build Coastguard Worker                has_functional_op = True
639*da0073e9SAndroid Build Coastguard Worker        self.assertTrue(has_functional_op)
640*da0073e9SAndroid Build Coastguard Worker
641*da0073e9SAndroid Build Coastguard Worker    def test_register_functional_op_no_returns(self):
642*da0073e9SAndroid Build Coastguard Worker        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
643*da0073e9SAndroid Build Coastguard Worker            lib.define("foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> ()")
644*da0073e9SAndroid Build Coastguard Worker
645*da0073e9SAndroid Build Coastguard Worker            def foo_impl(x, y, z, w):
646*da0073e9SAndroid Build Coastguard Worker                y.fill_(3.14)
647*da0073e9SAndroid Build Coastguard Worker                w.fill_(2.71)
648*da0073e9SAndroid Build Coastguard Worker
649*da0073e9SAndroid Build Coastguard Worker            lib.impl("foo", foo_impl, "CPU")
650*da0073e9SAndroid Build Coastguard Worker            register_functional_op(
651*da0073e9SAndroid Build Coastguard Worker                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
652*da0073e9SAndroid Build Coastguard Worker            )
653*da0073e9SAndroid Build Coastguard Worker            x = torch.randn([])
654*da0073e9SAndroid Build Coastguard Worker            y = torch.randn([])
655*da0073e9SAndroid Build Coastguard Worker            z = torch.randn([])
656*da0073e9SAndroid Build Coastguard Worker            w = torch.randn([])
657*da0073e9SAndroid Build Coastguard Worker            self._check_is_functional_variant(
658*da0073e9SAndroid Build Coastguard Worker                getattr(torch.ops, self.test_ns).foo.default,
659*da0073e9SAndroid Build Coastguard Worker                getattr(torch.ops, self.test_ns).foo_functional.default,
660*da0073e9SAndroid Build Coastguard Worker                (x, y, z, w),
661*da0073e9SAndroid Build Coastguard Worker            )
662*da0073e9SAndroid Build Coastguard Worker
    def test_register_functional_op_with_optional(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(b!) z, Tensor(c!)? w) -> ()"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                z.fill_(2.71)
                if w is not None:
                    w.fill_(1.618)

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, None),
            )

    def test_register_functional_op_one_return(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor(c!) z, Tensor(b!) w) -> Tensor"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                z.fill_(0.99)
                return x.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )
            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

    def test_register_functional_op_multiple_returns(self):
        with _scoped_library(self.test_ns, "FRAGMENT") as lib:
            lib.define(
                "foo(Tensor x, Tensor(a!) y, Tensor z, Tensor(b!) w) -> (Tensor, Tensor)"
            )

            def foo_impl(x, y, z, w):
                y.fill_(3.14)
                w.fill_(2.71)
                return x.clone(), z.clone()

            lib.impl("foo", foo_impl, "CPU")
            register_functional_op(
                lib, "foo_functional", getattr(torch.ops, self.test_ns).foo.default
            )

            x = torch.randn([])
            y = torch.randn([])
            z = torch.randn([])
            w = torch.randn([])
            self._check_is_functional_variant(
                getattr(torch.ops, self.test_ns).foo.default,
                getattr(torch.ops, self.test_ns).foo_functional.default,
                (x, y, z, w),
            )

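    # NB: for ops that both mutate and return (the two tests above), the
    # functional variant is expected to keep the original returns first and
    # append the updated values of the mutated args after them. That ordering
    # is an assumption on our part; _check_is_functional_variant is the
    # source of truth for the contract.
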
    def test_register_fallthrough(self):
        with _scoped_library("aten", "IMPL") as my_lib:
            my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")

            a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
            b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                # dtype for mm should be float32 since we registered a fallthrough
                self.assertEqual(torch.mm(a, b).dtype, torch.float32)
                # ops that don't have a fallthrough registered should not be affected
                self.assertEqual(torch.matmul(a, b).dtype, torch.bfloat16)

        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            # default behavior should have been restored
            self.assertEqual(torch.mm(a, b).dtype, torch.bfloat16)


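# A minimal sketch (not exercised by the test suite) of the same fallthrough
# pattern as test_register_fallthrough, packaged as a standalone helper. The
# function name is illustrative only, not part of torch.library's API. With
# the fallthrough registered, mm skips the AutocastCPU key entirely, so it
# runs at its inputs' original dtype even while autocast is active.
def _example_mm_dtype_with_autocast_fallthrough():
    with _scoped_library("aten", "IMPL") as my_lib:
        my_lib.impl("mm", fallthrough_kernel, "AutocastCPU")
        a = torch.randn(2, 3, device="cpu", dtype=torch.float32)
        b = torch.randn(3, 2, device="cpu", dtype=torch.float32)
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            return torch.mm(a, b).dtype  # torch.float32, not torch.bfloat16

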
class TestPythonDispatch(TestCase):
    def test_basic(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            y = x * x
            saved_x = y.grad_fn._saved_self
            grad_y = LoggingTensor(torch.tensor([1.0]))
            log_input("grad_y", grad_y)
            (g,) = torch.autograd.grad((y,), (x,), (grad_y,))

        self.assertEqual(g.elem, torch.tensor([6.0]))
        with torch.no_grad():
            self.assertEqual(saved_x, x)
            self.assertEqual(saved_x._version, x._version)
            x.add_(2)
            self.assertEqual(saved_x, x)
            # TODO: figure out why broken
            # self.assertEqual(saved_x._version, x._version)
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.mul.Tensor($0, $0)
$2: f32[1] = input('grad_y')
$3: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$4: f32[1] = torch._ops.aten.mul.Tensor($2, $0)
$5: f32[1] = torch._ops.aten.add.Tensor($4, $3)""",
        )

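    # NB (reading the logs above): each line has the shape
    #   $<id>: <dtype>[<sizes>] = <op>(<args>)
    # where $<id> names the result and arguments refer back to earlier $ids.
    # The two identical mul lines are the backward of x * x: grad_y is
    # multiplied by x once per factor, and the final add sums the two
    # contributions to give 2 * x * grad_y.
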
    def test_out(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.zeros(1))
            log_input("x", x)
            log_input("y", y)
            torch.abs(x, out=y)

        self.assertEqual(y.elem, torch.ones(1))
        # TODO: arguably this shouldn't pass and we should complain
        # that out isn't a kwarg
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('y')
$2: f32[1] = torch._ops.aten.abs.out($0, out=$1)""",
        )

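    # NB: passing out= selects the .out overload of the op (aten.abs.out
    # here), so a subclass sees the mutation target explicitly as the out
    # kwarg rather than as a freshly returned tensor.
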
    def test_kwarg_only(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            y = LoggingTensor(torch.ones(1, 1))
            z = LoggingTensor(torch.ones(1))
            log_input("x", x)
            log_input("y", y)
            log_input("z", z)
            torch.addmv(x, y, z)
            torch.addmv(x, y, z, beta=1)
            torch.addmv(x, y, z, beta=2)
            torch.addmv(x, y, z, alpha=2)
            torch.addmv(x, y, z, beta=2, alpha=2)

        # The expectation is that beta/alpha don't show up when they take
        # their default values, even if the user explicitly passed the
        # default (as in the beta=1 call above).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1, 1] = input('y')
$2: f32[1] = input('z')
$3: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$4: f32[1] = torch._ops.aten.addmv.default($0, $1, $2)
$5: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2)
$6: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, alpha=2)
$7: f32[1] = torch._ops.aten.addmv.default($0, $1, $2, beta=2, alpha=2)""",
        )

    def test_kwarg_only_and_positional_default(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1))
            log_input("x", x)
            torch.ops.aten._foobar(x)
            torch.ops.aten._foobar(x, False)
            torch.ops.aten._foobar(x, arg3=False)
            torch.ops.aten._foobar(x, False, arg3=False)

        # What we are testing here is that we omit arg2 if it is defaulted,
        # even if a kwarg is set.
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten._foobar.default($0)
$2: f32[1] = torch._ops.aten._foobar.default($0, False)
$3: f32[1] = torch._ops.aten._foobar.default($0, arg3=False)
$4: f32[1] = torch._ops.aten._foobar.default($0, False, arg3=False)""",
        )

    def test_produce_real_type(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(2, 2))
            log_input("x", x)
            x.to(dtype=torch.double)  # non-optional dtype
            torch.cumprod(x, 0, dtype=torch.double)  # optional dtype
            x[:, 1].contiguous(
                memory_format=torch.contiguous_format
            )  # optional memory format
            # There don't appear to be any layout signatures that are
            # triggerable using tensor subclasses (need to use a mode)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f64[2, 2] = torch._ops.aten._to_copy.default($0, dtype=torch.float64)
$2: f64[2, 2] = torch._ops.aten.cumprod.default($0, 0, dtype=torch.float64)
$3: f32[2, 2] = torch._ops.aten.slice.Tensor($0, 0, 0, 9223372036854775807)
$4: f32[2] = torch._ops.aten.select.int($3, 1, 1)
$5: f32[2] = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)""",
        )

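    # NB: "real type" here means the canonical value as the dispatcher sees
    # it: torch.double is rendered under its canonical name torch.float64,
    # and the Python-level x.to(...) shows up as the underlying
    # aten._to_copy call it actually dispatches to.
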
    def test_optional_tensor_list(self) -> None:
        def weird(xs):
            print("woof")
            return torch.empty(())

        with _scoped_library("my_lib", "DEF") as my_lib:
            my_lib.define("weird(Tensor?[] self) -> Tensor")
            my_lib.impl("weird", weird, "CPU")
            with capture_logs() as logs:
                x = LoggingTensor(torch.ones(2, 2))
                log_input("x", x)
                torch.ops.my_lib.weird.default([None, x])

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[2, 2] = input('x')
$1: f32[] = torch._ops.my_lib.weird.default(['None', '$0'])""",
        )

    def test_list_ret(self) -> None:
        # test that all sequence types are permissible as returns
        for list_type in (list, tuple):

            class A(torch.Tensor):
                @staticmethod
                def __new__(cls, elem):
                    return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

                @classmethod
                def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if func.overloadpacket == torch.ops.aten.split:
                        with no_dispatch():
                            return list_type(torch.split(*args))
                    else:
                        raise AssertionError(f"unrecognized func: {func}")

            self.assertEqual(
                torch.split(A(torch.tensor([0, 1])), 2),
                torch.split(torch.tensor([0, 1]), 2),
            )

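    # NB: the no_dispatch() above keeps the inner torch.split call from
    # re-entering __torch_dispatch__ on the subclass inputs, which would
    # otherwise recurse forever.
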
    def test_invalid_ret(self) -> None:
        # test that an invalid return gets a reasonable error message
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return "arf"

        # Wobbles depending on NDEBUG mode of pybind11
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).neg(),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Unable to cast",
            lambda: A(torch.zeros(1)).detach(),
        )

    def test_detach_appears_twice_when_called_once(self) -> None:
        with capture_logs() as logs:
            x = LoggingTensor(torch.tensor([3.0]), requires_grad=True)
            log_input("x", x)
            x.detach()
        # FIXME: We actually want this to emit a single detach. However,
        # it currently emits two, for reasons unclear to us. Leaving
        # this test here to make sure we don't regress even further (it
        # would be bad if calling .detach() once emits 3+ detaches).
        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = torch._ops.aten.detach.default($0)
$2: f32[1] = torch._ops.aten.detach.default($1)""",
        )

    def test_storage(self) -> None:
        # For now, just make sure it doesn't crash.  Ideally, we should
        # return some virtual storage that is safe to work with
        x = LoggingTensor(torch.ones(1))
        storage = x.untyped_storage()
        self.assertRaises(RuntimeError, lambda: storage.data_ptr())

    def test_make_wrapper_subclass_noalloc(self) -> None:
        # This is ludicrously big (1e12 elements, ~4TB at float32) and should
        # pass because wrapper subclasses don't allocate
        torch.Tensor._make_wrapper_subclass(LoggingTensor, (1000000000000,))

    def test_version(self) -> None:
        x = LoggingTensor(torch.ones(1))
        prev_vc = x._version
        x.detach().add_(2)
        cur_vc = x._version
        self.assertNotEqual(prev_vc, cur_vc)
        x.data.add_(2)
        self.assertEqual(cur_vc, x._version)

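    # NB: detach() shares the version counter with the original tensor, so
    # mutating the detached view bumps x._version; mutating through .data is
    # the historical escape hatch that skips the bump, which is exactly what
    # the two assertions above check.
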
    def test_subclass_priority(self) -> None:
        class ErrorA(RuntimeError):
            pass

        class ErrorB(RuntimeError):
            pass

        # The big tests for code coverage are test_precedence_semantics in
        # test_overrides.py; this is just to make sure it is wired up at all
        # correctly for __torch_dispatch__
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorA

        class B(A):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise ErrorB

        self.assertRaises(
            ErrorA, lambda: torch.add(A(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(A(torch.empty(1)), B(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), A(torch.empty(1)))
        )
        self.assertRaises(
            ErrorB, lambda: torch.add(B(torch.empty(1)), B(torch.empty(1)))
        )

    def test_format(self) -> None:
        x = LoggingTensor(torch.ones(1))
        s1 = str(x)
        s2 = repr(x)
        s3 = f"{x}"
        self.assertExpectedInline(s1, """LoggingTensor(tensor([1.]))""")
        self.assertEqual(s1, s2)
        self.assertEqual(s1, s3)

    def test_custom_autograd(self) -> None:
        escape = [None]

        class Square(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                y = x**2
                ctx.save_for_backward(x)
                return y

            @staticmethod
            def backward(ctx, grad_output):
                assert isinstance(grad_output, LoggingTensor)
                (x,) = ctx.saved_tensors
                assert isinstance(x, LoggingTensor)
                escape[0] = x
                return grad_output * 2 * x

        with capture_logs() as logs:
            x = LoggingTensor(torch.ones(1), requires_grad=True)
            log_input("x", x)
            x.grad = LoggingTensor(torch.zeros(1))
            log_input("x.grad", x.grad)
            y = Square.apply(x)
            grad_output = LoggingTensor(torch.ones(1))
            log_input("grad_output", grad_output)
            y.backward(grad_output)

        with torch.no_grad():
            self.assertEqual(escape[0], x)
            self.assertEqual(escape[0]._version, x._version)
            # TODO: figure out why x.requires_grad = False doesn't
            # trigger an error for LoggingTensor
            x.add_(2)
            self.assertEqual(escape[0], x)
            # TODO: figure out why this is broken
            # self.assertEqual(escape[0]._version, x._version)

        self.assertExpectedInline(
            "\n".join(logs),
            """\
$0: f32[1] = input('x')
$1: f32[1] = input('x.grad')
$2: f32[1] = torch._ops.aten.pow.Tensor_Scalar($0, 2)
$3: f32[1] = input('grad_output')
$4: f32[1] = torch._ops.aten.mul.Tensor($3, 2)
$5: f32[1] = torch._ops.aten.mul.Tensor($4, $0)
$6: f32[1] = torch._ops.aten.add_.Tensor($1, $5)""",
        )

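    # NB: the last three log lines are the custom backward playing out at
    # the aten level: grad_output * 2 ($4), then * x ($5), then in-place
    # accumulation into x.grad via add_ ($6).
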
    def test_subclass_creation(self):
        # Make sure these statements run without error.
        # In particular, check that when the internal detach returns
        # subclasses, these are cleanly overwritten.
        class Foo(torch.Tensor):
            pass

        err_msg = "subclass Foo but.*already associated to a python object of type LoggingTensor"
        with self.assertRaisesRegex(RuntimeError, err_msg):
            a = torch.Tensor._make_subclass(Foo, LoggingTensor(torch.rand(2)))
        with self.assertRaisesRegex(RuntimeError, err_msg):
            b = LoggingTensor(torch.rand(2)).as_subclass(Foo)
        with self.assertRaisesRegex(RuntimeError, err_msg):
            Foo(LoggingTensor(torch.rand(2)))

        with self.assertRaisesRegex(TypeError, "Foo must define __torch_dispatch__"):
            torch.Tensor._make_wrapper_subclass(Foo, (2, 2))

    def test_new_ones(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        self.assertEqual(type(MyTensor(2).new_ones(3)), MyTensor)

    def test_like(self) -> None:
        class MyTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return MyTensor(3)

        for f in ["empty", "ones", "rand", "randn", "zeros"]:
            f_name = f + "_like"
            self.assertEqual(type(getattr(torch, f_name)(MyTensor(2))), MyTensor)

        self.assertEqual(type(torch.full_like(MyTensor(2), 1.0)), MyTensor)
        self.assertEqual(type(torch.randint_like(MyTensor(2), high=3)), MyTensor)

    def test_make_fx_with_subclass(self) -> None:
        def f(x, y):
            # Returns (TwoTensor, Tensor)
            return x * y, y + y

        x_a = torch.zeros(4)
        x_b = torch.zeros(4)
        y = torch.ones(4)

        # make_fx() is not responsible for unwrapping tensor subclass inputs,
        # so we do it manually here.
        # Why? In general, make_fx(f)(*args) promises that the graph returned has the same calling
        # convention as f(*args). Unwrapping tensor subclass inputs can potentially change
        # the number of input args to the graph, breaking that assumption.
        def f_to_trace(x_a, x_b, y):
            x = TwoTensor(x_a, x_b)
            out1, out2 = f(x, y)
            out1_unwrapped_attrs, _ = out1.__tensor_flatten__()
            return (*[getattr(out1, attr) for attr in out1_unwrapped_attrs], out2)

        fx_g = make_fx(f_to_trace, tracing_mode="fake")(x_a, x_b, y)
        self.assertExpectedInline(
            fx_g.code,
            """\



def forward(self, x_a_1, x_b_1, y_1):
    mul = torch.ops.aten.mul.Tensor(x_a_1, y_1);  x_a_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(x_b_1, y_1);  x_b_1 = None
    add = torch.ops.aten.add.Tensor(y_1, y_1);  y_1 = None
    return (mul, mul_1, add)
    """,
        )

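    # NB: __tensor_flatten__ returns (inner tensor attr names, ctx); for
    # TwoTensor that is roughly (["a", "b"], ...), which is why out1 expands
    # into the two mul outputs in the traced graph above.
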
    # See https://github.com/pytorch/pytorch/issues/117794
    def test_return_and_correct_aliasing_gives_correct_stride(self):
        t = TwoTensor(torch.randn(2, 2), torch.randn(2, 2))
        x = torch.randn(2, 2)
        # slicing should result in the same stride for TwoTensor as a dense tensor would give
        self.assertEqual(t[:, 0].stride(), x[:, 0].stride())

    def test_make_wrapper_subclass_propagates_metadata(self) -> None:
        class WrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise RuntimeError("NYI")

        # non-contiguous strides, non-zero storage offset
        x = torch.randn(4, 6).t().diagonal(offset=2)
        y = WrapperTensor(x)
        self.assertEqual(y.size(), x.size())
        self.assertEqual(y.stride(), x.stride())
        self.assertEqual(y.storage_offset(), x.storage_offset())

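    # NB: concretely (an illustration, not asserted by the test), for this x
    # the transpose gives shape (6, 4) with strides (1, 6), and
    # diagonal(offset=2) then yields size (2,), stride (7,), and
    # storage_offset 12; the wrapper is expected to mirror whatever the
    # actual values are.
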
    def test_wrapper_subclass_serializes(self) -> None:
        with tempfile.TemporaryFile() as f:
            # purposefully use int64 to test non-default dtype
            x = LoggingTensor(torch.randperm(3))
            torch.save(x, f)
            f.seek(0)
            with torch.serialization.safe_globals([LoggingTensor]):
                x_loaded = torch.load(f)
            self.assertTrue(type(x_loaded) is type(x))
            self.assertEqual(x, x_loaded)
            self.assertEqual(x.elem, x_loaded.elem)
            self.assertFalse(x is x_loaded)

    def test_deepcopy_wrapper_subclass(self) -> None:
        # purposefully use int64 to test non-default dtype
        x = LoggingTensor(torch.randperm(3))
        x_copy = deepcopy(x)
        self.assertTrue(type(x_copy) is type(x))
        self.assertEqual(x, x_copy)
        self.assertEqual(x.elem, x_copy.elem)
        self.assertFalse(x is x_copy)

    def test_deepcopy_wrapper_subclass_with_clone_returning_different_type(
        self,
    ) -> None:
        class MyWrapperTensor(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                    strides=elem.stride(),
                    storage_offset=elem.storage_offset(),
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func.overloadpacket.__name__ == "clone":
                    # Return a plain tensor from clone().
                    return args[0].elem.clone()
                raise RuntimeError("NYI")

            # NB: The default Tensor.__torch_function__ implementation called for deepcopy
            # disables __torch_function__ by the time we get to clone(), so there is no need to
            # explicitly disable __torch_function__ for this subclass.

        x = MyWrapperTensor(torch.randn(3))
        with self.assertRaisesRegex(
            RuntimeError,
            "for which cloning returns another instance of the same subclass",
        ):
            x_copy = deepcopy(x)

    def test_deepcopy_non_wrapper_subclass(self) -> None:
        # Ensure correct error is thrown for common error cases.
        class SubTensorError1(torch.Tensor):
            # Default implementation of new_empty() returns a plain tensor.
            pass

        class SubTensorError2(torch.Tensor):
            # new_empty() incorrectly returns a different type (i.e. a plain tensor).
            def new_empty(self, shape):
                return torch.Tensor(shape)

        for error_cls in [SubTensorError1, SubTensorError2]:
            x = error_cls(3)
            with self.assertRaisesRegex(
                RuntimeError,
                "for which that function returns another instance of the same subclass",
            ):
                x_copy = deepcopy(x)

        # Ensure a correctly implemented new_empty() causes deepcopy() to work.
        class SubTensorSuccess(torch.Tensor):
            def new_empty(self, shape):
                return type(self)(shape)

        x = SubTensorSuccess(3)
        x_copy = deepcopy(x)
        self.assertIs(type(x_copy), type(x))

    def test_wrapper_subclass_extra_dispatch_keys(self) -> None:
        class ExtraKeysTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # NB: only the non-kwarg overload of _make_wrapper_subclass supports
                #     extra dispatch keys. We probably want to unify the two APIs
                #     in the future.
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls,
                    elem.size(),
                    elem.stride(),
                    elem.storage_offset(),
                    torch.contiguous_format,
                    elem.dtype,
                    elem.layout,
                    elem.device,
                    False,
                    False,
                    None,
                    False,
                    False,
                    DispatchKeySet(DispatchKey.NestedTensor),
                )
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                pass

        x = ExtraKeysTensor(torch.randn(3))
        self.assertTrue(torch._C._dispatch_keys(x).has(DispatchKey.NestedTensor))
        self.assertFalse(
            torch._C._dispatch_keys(x).has(DispatchKey.AutogradNestedTensor)
        )

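    # NB: by our reading of the positional overload (treat this as an
    # approximation, not a documented signature), the opaque trailing
    # arguments above are pin_memory=False, requires_grad=False,
    # dispatch_sizes_strides_policy=None, dispatch_device=False,
    # dispatch_layout=False, and finally the extra DispatchKeySet.
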
    def test_wrapper_subclass_multiprocessing_preserves_dtype(self):
        # a and b have dtype of int64, which is purposefully different from the default
        # assumed by _make_wrapper_subclass().
        a = torch.randperm(5)
        b = torch.randperm(5)
        data = TwoTensor(a, b)
        expected_dtype = data.dtype

        loader = torch.utils.data.DataLoader(
            [data, data],
            batch_size=2,
            num_workers=2,
            collate_fn=_identity,
        )
        for batch in loader:
            self.assertEqual(batch[0].dtype, expected_dtype)

    def test_index_put_where_only_index_is_subclass(self) -> None:
        called_funcs = []

        class MyTensor(torch.Tensor):
            elem: torch.Tensor
            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called_funcs.append(func)
                return MyTensor(torch.tensor(3))

        x = torch.randn(3, 3)
        idxs = (MyTensor(torch.tensor(0)),)
        v = torch.randn(1)
        res = x.index_put_(idxs, v)
        self.assertEqual(called_funcs, [torch.ops.aten.index_put_.default])

1375*da0073e9SAndroid Build Coastguard Worker    def test_torch_dispatch_mode_basic(self) -> None:
1376*da0073e9SAndroid Build Coastguard Worker        with capture_logs(is_mode=True) as logs:
1377*da0073e9SAndroid Build Coastguard Worker            with LoggingTensorMode():
1378*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1379*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1380*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs),
1381*da0073e9SAndroid Build Coastguard Worker            """\
1382*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
1383*da0073e9SAndroid Build Coastguard Worker        )
1384*da0073e9SAndroid Build Coastguard Worker
1385*da0073e9SAndroid Build Coastguard Worker    def test_torch_dispatch_mode_unrelated_tensors(self) -> None:
1386*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([])
1387*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([])
1388*da0073e9SAndroid Build Coastguard Worker        with capture_logs(is_mode=True) as logs:
1389*da0073e9SAndroid Build Coastguard Worker            with LoggingTensorMode():
1390*da0073e9SAndroid Build Coastguard Worker                x + y
1391*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1392*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs), """$2: f32[] = torch._ops.aten.add.Tensor($0, $1)"""
1393*da0073e9SAndroid Build Coastguard Worker        )
1394*da0073e9SAndroid Build Coastguard Worker
1395*da0073e9SAndroid Build Coastguard Worker    def test_nested_push_logging_tensor_mode(self):
1396*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([])
1397*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([])
1398*da0073e9SAndroid Build Coastguard Worker        with capture_logs(is_mode=True) as logs:
1399*da0073e9SAndroid Build Coastguard Worker            with LoggingTensorMode():
1400*da0073e9SAndroid Build Coastguard Worker                with LoggingTensorMode():
1401*da0073e9SAndroid Build Coastguard Worker                    torch.empty([])
1402*da0073e9SAndroid Build Coastguard Worker                    x + y
1403*da0073e9SAndroid Build Coastguard Worker
1404*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1405*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs),
1406*da0073e9SAndroid Build Coastguard Worker            """\
1407*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1408*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1409*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
1410*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1411*da0073e9SAndroid Build Coastguard Worker        )
1412*da0073e9SAndroid Build Coastguard Worker
1413*da0073e9SAndroid Build Coastguard Worker    def test_capture_logs_with_torch_dispatch_mode(self):
1414*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([])
1415*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([])
1416*da0073e9SAndroid Build Coastguard Worker        with capture_logs_with_logging_tensor_mode() as logs:
1417*da0073e9SAndroid Build Coastguard Worker            torch.empty([])
1418*da0073e9SAndroid Build Coastguard Worker            x + y
1419*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1420*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs),
1421*da0073e9SAndroid Build Coastguard Worker            """\
1422*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1423*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1424*da0073e9SAndroid Build Coastguard Worker        )
1425*da0073e9SAndroid Build Coastguard Worker
1426*da0073e9SAndroid Build Coastguard Worker        x = torch.randn([])
1427*da0073e9SAndroid Build Coastguard Worker        y = torch.randn([])
1428*da0073e9SAndroid Build Coastguard Worker
1429*da0073e9SAndroid Build Coastguard Worker        with capture_logs_with_logging_tensor_mode() as logs1:
1430*da0073e9SAndroid Build Coastguard Worker            with capture_logs_with_logging_tensor_mode() as logs2:
1431*da0073e9SAndroid Build Coastguard Worker                torch.empty([])
1432*da0073e9SAndroid Build Coastguard Worker                x + y
1433*da0073e9SAndroid Build Coastguard Worker
1434*da0073e9SAndroid Build Coastguard Worker        self.assertExpectedInline(
1435*da0073e9SAndroid Build Coastguard Worker            "\n".join(logs2),
1436*da0073e9SAndroid Build Coastguard Worker            """\
1437*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1438*da0073e9SAndroid Build Coastguard Worker$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
1439*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)
1440*da0073e9SAndroid Build Coastguard Worker$3: f32[] = torch._ops.aten.add.Tensor($1, $2)""",
1441*da0073e9SAndroid Build Coastguard Worker        )
1442*da0073e9SAndroid Build Coastguard Worker
1443*da0073e9SAndroid Build Coastguard Worker        self.assertEqual(logs1, logs2)
1444*da0073e9SAndroid Build Coastguard Worker
1445*da0073e9SAndroid Build Coastguard Worker    def test_torch_dispatch_mode_subclass_priority(self) -> None:
1446*da0073e9SAndroid Build Coastguard Worker        class ErrorA(RuntimeError):
1447*da0073e9SAndroid Build Coastguard Worker            pass
1448*da0073e9SAndroid Build Coastguard Worker
1449*da0073e9SAndroid Build Coastguard Worker        class ErrorB(RuntimeError):
1450*da0073e9SAndroid Build Coastguard Worker            pass
1451*da0073e9SAndroid Build Coastguard Worker
1452*da0073e9SAndroid Build Coastguard Worker        class A(torch.Tensor):
1453*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1454*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, elem):
1455*da0073e9SAndroid Build Coastguard Worker                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1456*da0073e9SAndroid Build Coastguard Worker
1457*da0073e9SAndroid Build Coastguard Worker            @classmethod
1458*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1459*da0073e9SAndroid Build Coastguard Worker                with AMode():
1460*da0073e9SAndroid Build Coastguard Worker                    raise ErrorA
1461*da0073e9SAndroid Build Coastguard Worker
1462*da0073e9SAndroid Build Coastguard Worker        class B(A):
1463*da0073e9SAndroid Build Coastguard Worker            @staticmethod
1464*da0073e9SAndroid Build Coastguard Worker            def __new__(cls, elem):
1465*da0073e9SAndroid Build Coastguard Worker                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
1466*da0073e9SAndroid Build Coastguard Worker
1467*da0073e9SAndroid Build Coastguard Worker            @classmethod
1468*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
1469*da0073e9SAndroid Build Coastguard Worker                with BMode():
1470*da0073e9SAndroid Build Coastguard Worker                    func(*args, **kwargs)
1471*da0073e9SAndroid Build Coastguard Worker
1472*da0073e9SAndroid Build Coastguard Worker        class AMode(TorchDispatchMode):
1473*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1474*da0073e9SAndroid Build Coastguard Worker                raise ErrorA
1475*da0073e9SAndroid Build Coastguard Worker
1476*da0073e9SAndroid Build Coastguard Worker        class BMode(TorchDispatchMode):
1477*da0073e9SAndroid Build Coastguard Worker            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
1478*da0073e9SAndroid Build Coastguard Worker                raise ErrorB
1479*da0073e9SAndroid Build Coastguard Worker
1480*da0073e9SAndroid Build Coastguard Worker        a = A(torch.empty(1))
1481*da0073e9SAndroid Build Coastguard Worker        b = B(torch.empty(1))
1482*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ErrorA):
1483*da0073e9SAndroid Build Coastguard Worker            a + a
1484*da0073e9SAndroid Build Coastguard Worker        with self.assertRaises(ErrorB):
1485*da0073e9SAndroid Build Coastguard Worker            a + b
1486*da0073e9SAndroid Build Coastguard Worker
        # B has precedence over A due to the subclass relationship, but modes
        # take precedence over subclass arguments (see the sketch after this
        # test)
        with self.assertRaises(ErrorA):
            with AMode():
                b + b
        with self.assertRaises(ErrorB):
            with BMode():
                a + a
        with self.assertRaises(ErrorB):
            with BMode():
                a + b

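    # A minimal illustrative sketch, added alongside the test above (not part
    # of the original suite): modes win over subclass arguments, but a mode
    # may return NotImplemented to hand the call back to the subclass. The
    # DeferringMode and Sub names are assumptions made up for this sketch;
    # the expected hit sequence mirrors test_notimplemented_mode below.
    def test_mode_defers_to_subclass_sketch(self) -> None:
        hits = []

        class DeferringMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                hits.append("mode")
                if any(t is not torch.Tensor for t in types):
                    # Hand the call back to the subclass's __torch_dispatch__
                    return NotImplemented
                return func(*args, **(kwargs or {}))

        class Sub(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                hits.append("subclass")

                def unwrap(t):
                    return t.elem if isinstance(t, Sub) else t

                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs or {}))

        s = Sub(torch.randn(2))
        with DeferringMode():
            s.abs()
        # The mode ran first, deferred to the subclass, then saw the
        # subclass's unwrapped re-dispatch (cf. test_notimplemented_mode).
        self.assertEqual(hits, ["mode", "subclass", "mode"])
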
    def test_mode_with_make_subclass(self):
        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        x = torch.randn(3)
        with BasicMode():
            y = SubTensor(x)
        self.assertIsInstance(y, SubTensor)

    def test_torch_dispatch_mode_respects_no_dispatch(self) -> None:
        with capture_logs(is_mode=True) as logs1:
            with LoggingTensorMode():
                torch.ones([2, 3])
                with no_dispatch():
                    torch.ones([2, 3])
        with capture_logs(is_mode=True) as logs2:
            with LoggingTensorMode():
                torch.ones([2, 3])
        self.assertEqual(logs1, logs2)

    def test_shallow_copy_and_detach(self) -> None:
        seen = set()
        test_case = self

        class TestMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                tree_map_only(
                    torch.Tensor, lambda t: test_case.assertIn(t, seen), (args, kwargs)
                )
                if kwargs is None:
                    kwargs = {}
                r = func(*args, **kwargs)
                tree_map_only(torch.Tensor, lambda t: seen.add(t), r)
                return r

        with TestMode():
            x = torch.randn(3, requires_grad=True)
            loss = (x * x).sum()
            loss.backward()

    def test_exception_handling(self):
        class A(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

        class AMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                if func.__name__ == "randn.default":
                    raise RuntimeError
                return A(torch.zeros(()))

        with AMode():
            try:
                torch.randn(())
            except RuntimeError:
                pass
            self.assertTrue(isinstance(torch.zeros(()), A))

    def test_with_mode_created_separately(self):
        class ErrorA(RuntimeError):
            pass

        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA

        x = A()
        with self.assertRaises(ErrorA):
            with x:
                torch.empty([])

    def test_with_nested_modes(self):
        class ErrorA(RuntimeError):
            def __init__(self, msg):
                super().__init__(msg)

        class A(TorchDispatchMode):
            def __init__(self, msg):
                self.msg = msg

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                raise ErrorA(self.msg)

        with self.assertRaisesRegex(ErrorA, "layer2"):
            with A("layer1"):
                with A("layer2"):
                    torch.empty([])

    def test_make_subclass_with_modes(self):
        class ModeTensor(torch.Tensor):
            def __new__(cls, elem, mode):
                r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
                r.elem = elem
                r.mode = mode
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                raise NotImplementedError("Shouldn't be here")

        class Mode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                def unwrap(e):
                    if isinstance(e, ModeTensor):
                        return e.elem
                    else:
                        return e

                def wrap(t):
                    if isinstance(t, torch.Tensor):
                        return ModeTensor(t, self)
                    else:
                        return t

                return wrap(func(*tuple(unwrap(a) for a in args), **kwargs))

        class BasicMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                return func(*args, **kwargs)

        x = torch.tensor(4.0)
        with Mode():
            y = x + x
            z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        with Mode():
            with BasicMode():  # we can't nest two modes that call make_subclass because it only accepts vanilla tensors
                y = x + x
                z = y + y
        self.assertIsInstance(y, ModeTensor)
        self.assertIsInstance(z, ModeTensor)

        # NB: assertRaisesRegex is not entered as a context manager here, so
        # this assert only checks that the returned context object is truthy.
        assert self.assertRaisesRegex(
            RuntimeError,
            "subclass Mode but.* associated to a python object of type Mode",
        )

    def test_notimplemented_mode(self):
        sub_count = 0

        class PoliteMode(TorchDispatchMode):
            def __init__(self) -> None:
                self.pre_count = 0
                self.post_count = 0

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.pre_count += 1
                if any(t is not torch.Tensor for t in types):
                    return NotImplemented
                self.post_count += 1
                return func(*args, **kwargs)

        class SubTensor(torch.Tensor):
            def __new__(cls, elem):
                r = torch.Tensor._make_wrapper_subclass(cls, elem.shape)
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal sub_count
                sub_count += 1

                def unwrap(t):
                    if isinstance(t, SubTensor):
                        return t.elem
                    else:
                        return t

                return func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))

        a = SubTensor(torch.randn(2))
        with PoliteMode() as mode:
            a.abs()

        # pre_count is 2: the mode saw abs() once with the SubTensor argument
        # (and returned NotImplemented), then once more when the subclass
        # re-invoked func on the unwrapped plain Tensor.
        self.assertEqual(mode.pre_count, 2)
        self.assertEqual(mode.post_count, 1)
        self.assertEqual(sub_count, 1)

        # make sure this doesn't error
        with PoliteMode():
            with PoliteMode():
                a.abs()

    def test_nesting_same_mode(self):
        # Pushing an already active mode is allowed as long as it is the same
        # instance as the current mode.

        with capture_logs(is_mode=True) as logs:
            with LoggingTensorMode() as reenabled:
                with reenabled:
                    torch.empty([])
            self.assertExpectedInline(
                "\n".join(logs),
                """\
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)
$0: f32[] = torch._ops.aten.empty.memory_format([], device=device(type='cpu'), pin_memory=False)""",
            )

    def test_error_using_class_method_on_mode(self):
        class A(TorchDispatchMode):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                return func(args, kwargs)

        x = torch.tensor(5.0)
        with self.assertRaisesRegex(
            RuntimeError, "classmethod is not supported, please make it a plain method"
        ):
            with A():
                x + x

    def test_get_cur_mode(self):
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_dispatch_mode(), None)

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode(), mode1)

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_dispatch_mode(), mode2)

    def test_get_mode_stack(self):
        class A(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                pass

        self.assertEqual(_get_current_dispatch_mode_stack(), [])

        with A() as mode1:
            self.assertEqual(_get_current_dispatch_mode_stack(), [mode1])

        with mode1:
            with A() as mode2:
                self.assertEqual(_get_current_dispatch_mode_stack(), [mode1, mode2])

    def test_all_same_mode(self):
        x = LoggingTensorMode()
        y = LoggingTensorMode()
        self.assertTrue(all_same_mode([x, x, x]))
        self.assertFalse(all_same_mode([x, None]))
        self.assertFalse(all_same_mode([x, y]))

    def test_mode_detection(self):
        class InfraMode(TorchDispatchMode):
            @classmethod
            def is_infra_mode(cls):
                return True

        class NonInfraMode(TorchDispatchMode):
            pass

        with InfraMode():
            self.assertTrue(is_in_torch_dispatch_mode())
            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))
            with NonInfraMode():
                self.assertTrue(is_in_torch_dispatch_mode())
                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
                with InfraMode():
                    self.assertTrue(is_in_torch_dispatch_mode())
                    self.assertTrue(
                        is_in_torch_dispatch_mode(include_infra_modes=False)
                    )

                self.assertTrue(is_in_torch_dispatch_mode())
                self.assertTrue(is_in_torch_dispatch_mode(include_infra_modes=False))
            self.assertTrue(is_in_torch_dispatch_mode())
            self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))

        self.assertFalse(is_in_torch_dispatch_mode())
        self.assertFalse(is_in_torch_dispatch_mode(include_infra_modes=False))

    def test_tolist_numpy_with_torch_dispatch_mode(self) -> None:
        x = LoggingTensor(torch.tensor([2.0, 3.0]))
        with self.assertRaisesRegex(
            RuntimeError, "is not supported for tensor subclasses."
        ):
            x.tolist()
        with self.assertRaisesRegex(
            RuntimeError, "is not supported for tensor subclasses."
        ):
            x.numpy()
        with self.assertRaises(AssertionError):
            self.assertEqual(x, None)

    def test_record_stream(self) -> None:
        class TestMode(TorchDispatchMode):
            def __init__(self, testcase):
                self.testcase = testcase

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                self.testcase.assertEqual(func.name(), "aten::record_stream")
                self.testcase.assertIsInstance(args[0], torch.Tensor)
                self.testcase.assertIsInstance(args[1], torch.Stream)
                self.testcase.assertEqual(args[1].stream_id, 1)
                self.testcase.assertEqual(args[1].device_index, 2)
                self.testcase.assertEqual(args[1].device_type, 3)

        t = torch.tensor(5.0)
        s = torch.Stream(stream_id=1, device_index=2, device_type=3)
        with TestMode(self):
            t.record_stream(s)

    def test_return_stream(self) -> None:
        with _scoped_library("test_return_stream", "DEF") as l_def:
            l_def.define("return_stream(Tensor self) -> Stream")
            with _scoped_library("test_return_stream", "IMPL", "CPU") as l_impl:
                l_impl.impl(
                    "return_stream",
                    lambda _: torch.Stream(stream_id=0, device_index=1, device_type=2),
                )

                class TestMode(TorchDispatchMode):
                    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                        return torch.Stream(stream_id=1, device_index=2, device_type=3)

                t = torch.tensor(5.0)
                s = torch.ops.test_return_stream.return_stream(t)
                self.assertIsInstance(s, torch.Stream)
                self.assertEqual(s.stream_id, 0)
                self.assertEqual(s.device_index, 1)
                self.assertEqual(s.device_type, 2)

                with TestMode():
                    s = torch.ops.test_return_stream.return_stream(t)
                self.assertIsInstance(s, torch.Stream)
                self.assertEqual(s.stream_id, 1)
                self.assertEqual(s.device_index, 2)
                self.assertEqual(s.device_type, 3)

    def test_subclass_autograd_device_check(self) -> None:
        class NonWrapperSubclass(torch.Tensor):
            elem: torch.Tensor

            __slots__ = ["elem"]

            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                # Wrong device here!
                r = torch.Tensor._make_subclass(
                    cls, elem.to("meta"), elem.requires_grad
                )
                # ...the real tensor is held as an element on the tensor.
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, NonWrapperSubclass) else e

                def wrap(e):
                    return NonWrapperSubclass(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(
                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                )
                logging.getLogger("NonWrapperSubclass").info(
                    f"{func.__module__}.{func.__name__}",  # noqa: G004
                    args,
                    kwargs,
                    rs,
                )
                return rs

        x = NonWrapperSubclass(torch.tensor([3.0, 4.0], requires_grad=True))
        y = torch.randn(2, requires_grad=True)
        z = x * y
        self.assertIsInstance(z, NonWrapperSubclass)
        z.sum().backward(torch.tensor(1))
        self.assertEqual(x.grad, y)
        self.assertEqual(y.grad, x)

    def test_none_wrapping(self):
        # A Tensor subclass that returns None when doing add.
        # See LoggingTensor (imported above) for more details on the subclass.
        class SubclassWithNone(torch.Tensor):
            @staticmethod
            def __new__(cls, elem, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(
                    cls,
                    elem.size(),
                    dtype=elem.dtype,
                    layout=elem.layout,
                    device=elem.device,
                    requires_grad=elem.requires_grad,
                )
                r.elem = elem
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                def unwrap(e):
                    return e.elem if isinstance(e, SubclassWithNone) else e

                def wrap(e):
                    return SubclassWithNone(e) if isinstance(e, torch.Tensor) else e

                rs = tree_map(
                    wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
                )
                if func.overloadpacket.__name__ == "add":
                    return None
                else:
                    return rs

        x = SubclassWithNone(torch.rand(2))
        # Make sure both run without error
        self.assertIsInstance(x * 2, SubclassWithNone)
        self.assertIsNone(x + 2)

        x.requires_grad_()
        out = x.acos().sum()

        # The backward of acos does an add and then an rsqrt, so here we make
        # sure that the undefined Tensor generated by the user code above is
        # handled gracefully. If the acos backward formula changes in the
        # future, this can be replaced by any other composite function whose
        # backward does an add followed by some other op.
        with self.assertRaisesRegex(RuntimeError, "but got None"):
            out.backward()

    def test_storage_can_be_converted_to_python_object(self):
        s = torch.Storage()
        z = LoggingTensor(torch.empty([]))
        z.set_(s)

    def test_autograd_in_attr(self):
        # We want the wrapped Tensor to require gradients!
        true_t = torch.rand(2, requires_grad=True)
        t = LoggingTensorReentrant(true_t)

        out = t + 2

        self.assertFalse(out.requires_grad)
        self.assertIsNone(out.grad_fn)

        self.assertTrue(out.elem.requires_grad)
        self.assertIsNotNone(out.elem.grad_fn)

        with self.assertRaisesRegex(RuntimeError, "does not require grad"):
            out.sum().backward()

        out.elem.sum().backward()

        self.assertIsNone(t.grad)
        self.assertIsNotNone(t.elem.grad)

    def test_dispatch_super_call(self):
        called = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                return super().__torch_dispatch__(func, types, args, kwargs)

        x = torch.randn(2)
        y = torch.randn(2)
        self.assertEqual(SubTensor(x) + SubTensor(y), x + y)
        self.assertEqual(called, [torch.ops.aten.add.Tensor])

    def test_dispatch_super_call_list_arg(self):
        called = []

        class SubTensorWithListArg(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                return super().__torch_dispatch__(func, types, list(args), kwargs)

        x = torch.randn(2)
        self.assertEqual(SubTensorWithListArg(x).neg(), x.neg())
        self.assertEqual(called, [torch.ops.aten.neg.default])

    def test_dispatch_super_dont_autograd(self):
        called = []

        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                called.append(func)
                # This argument still requires grad because it was passed
                # through directly...
                self.assertTrue(args[0].requires_grad)
                r = super().__torch_dispatch__(func, types, args, kwargs)
                # But the output better not require grad, because that means
                # you did autograd again in torch dispatch (oops)
                self.assertFalse(r.requires_grad)
                return r

        x = SubTensor(torch.randn(2, requires_grad=True))
        x.neg()
        self.assertEqual(called, [torch.ops.aten.neg.default])

    def test_set_data(self):
        called = 0

        class SubTensor(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                nonlocal called
                called += 1
                return super().__torch_dispatch__(func, types, args, kwargs)

        x = SubTensor(torch.empty(2))
        x.data
        self.assertEqual(called, 1)
        # Assigning to .data swaps the underlying tensor without going
        # through __torch_dispatch__...
        x.data = torch.empty(2)
        self.assertEqual(called, 1)
        x.data
        self.assertEqual(called, 2)
        self.assertIs(type(x), SubTensor)
        # ...while set_() is a regular op and does dispatch.
        x.set_(torch.empty(2))
        self.assertEqual(called, 3)
        x.data
        self.assertEqual(called, 4)
        self.assertIs(type(x), SubTensor)

    def test_construct_int_tensor(self):
        class SubTensor(torch.Tensor):
            pass

        # should not fail
        SubTensor(torch.zeros(2, dtype=torch.int))

    def test_multiple_ops_subclass(self):
        # This is a direct subclass; don't do that!
        class MySubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                r = torch.Tensor._make_subclass(cls, elem)
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with no_dispatch():
                    return func(*args, **kwargs)

        x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
        y = x.conj()
        # Details of the bug that this tests for:
        # Here, y's dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
        # There are a few calls to the dispatcher that are going to happen here:
        #  - call_exp: the user calls exp on y
        #    - PythonTLSSnapshot: records the TLS on entry and redispatches
        #    - AutogradCPU: no input requires grad, so it does nothing and redispatches
        #    - Conjugate: no special implementation for exp: uses the fallback that
        #                 first clones the Tensor (to materialize the conj) and then redispatches
        #      - call_clone: the conjugate fallback calls clone on y
        #        - PythonTLSSnapshot: records the TLS on entry and redispatches
        #        - (AutogradCPU: skipped as autograd added itself to the exclude set above)
        #        - Conjugate: special implementation for clone: just skips this key
        #        - Python: resets the TLS based on the snapshot above and calls the user
        #                  implementation (this actually calls into the dispatcher again, but
        #                  since we disabled both our keys before, it is not detailed here)
        #        - exit Python: restores the TLS and exits
        #        - exit Conjugate: nothing was in-place, so it just exits
        #        - exit PythonTLSSnapshot: done with this call, resets the saved TLS to empty
        #    - Python: resets the TLS again based on the snapshot. <- this used to fail
        #    - More steps...
2083*da0073e9SAndroid Build Coastguard Worker        y.exp()
2084*da0073e9SAndroid Build Coastguard Worker
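    # A minimal sketch, not part of the regression test above, showing one way
    # to observe the key set described in the walkthrough. It assumes the
    # private helper torch._C._dispatch_keys is available in this build.
    def _sketch_inspect_dispatch_keys(self):
        x = torch.rand(2, 2, dtype=torch.complex64)
        y = x.conj()
        # For a plain conjugated tensor this set includes Conjugate; for the
        # subclass above it would additionally include Python.
        return torch._C._dispatch_keys(y)
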
    @staticmethod
    def subclass_helper(cls, data, use_wrapper_subclass, **kwargs):
        if use_wrapper_subclass:
            kwargs["device"] = data.device
            kwargs["dtype"] = data.dtype
            kwargs["layout"] = data.layout
            # Note: wrapper subclasses built by this helper always advertise
            # requires_grad=True.
            kwargs["requires_grad"] = True
            return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
        else:
            return torch.Tensor._make_subclass(cls, data, True, **kwargs)

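    # A quick sketch (Plain is a hypothetical demo class) of the difference
    # subclass_helper switches on: _make_subclass shares the storage of the
    # tensor it wraps, while _make_wrapper_subclass builds a metadata-only
    # shell with no real data of its own.
    @staticmethod
    def _sketch_plain_subclass_shares_storage():
        class Plain(torch.Tensor):
            pass

        base = torch.randn(2, 3)
        plain = torch.Tensor._make_subclass(Plain, base)
        # The plain subclass aliases base's storage:
        return plain.untyped_storage().data_ptr() == base.untyped_storage().data_ptr()
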
    def test_is_contiguous_slow_path(self):
        data = torch.randn(3, 3)
        contiguous_data = data.clone()
        not_contiguous_data = torch.as_strided(data.clone(), (2, 2), (1, 2))

        for use_wrapper_subclass in [True, False]:

            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.is_contiguous:
                        return contiguous_data.is_contiguous()
                    return NotImplemented

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.is_contiguous:
                        return not_contiguous_data.is_contiguous()
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.is_contiguous'"
            e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.is_contiguous()
            with self.assertRaisesRegex(TypeError, err_msg):
                e.contiguous()

            e = ExampleTensor2(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.is_contiguous(), True)
            e.contiguous()  # this will just return the original TensorImpl since is_contiguous = True

            err_msg = "Multiple dispatch failed for"
            e = ExampleTensor3(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.is_contiguous(), False)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.contiguous()

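    # A condensed sketch (PolicyDemo is a hypothetical class) of the contract
    # the surrounding slow-path tests verify: under dispatch_sizes_strides_policy,
    # a metadata query answered with a concrete value is used as-is, None means
    # "fall back to the default implementation", and NotImplemented surfaces as
    # the multiple-dispatch TypeError seen above.
    @staticmethod
    def _sketch_metadata_query_returns():
        class PolicyDemo(torch.Tensor):
            @staticmethod
            def __new__(cls, data):
                return TestPythonDispatch.subclass_helper(
                    cls, data, True, dispatch_sizes_strides_policy="strides"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.is_contiguous:
                    return True  # concrete answer, used as-is
                if func.overloadpacket == torch.ops.aten.sym_stride:
                    return None  # defer to the default computation
                return NotImplemented  # anything else fails multiple dispatch

        t = PolicyDemo(torch.randn(3, 3))
        assert t.is_contiguous()
        assert t.stride() == (3, 1)
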
    def test_fancy_strides(self):
        calls = []

        class ExampleTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, data):
                return TestPythonDispatch.subclass_helper(
                    cls, data, False, dispatch_sizes_strides_policy="strides"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func in [
                    torch.ops.aten.is_contiguous.default,
                    torch.ops.aten.is_contiguous.memory_format,
                    torch.ops.aten.is_strides_like_format.default,
                    torch.ops.aten.is_non_overlapping_and_dense.default,
                    torch.ops.aten.stride.default,
                ]:
                    calls.append((func, list(args)[1:]))
                    return None
                with no_dispatch():
                    return func(*args, **kwargs)

        e = ExampleTensor(torch.randn(2, 2))
        self.assertFalse(e.is_contiguous(memory_format=torch.channels_last))
        self.assertEqual(
            calls, [(torch.ops.aten.is_contiguous.memory_format, [torch.channels_last])]
        )
        calls.clear()
        self.assertFalse(
            torch.ops.aten.is_strides_like_format.default(e, torch.channels_last)
        )
        self.assertEqual(
            calls,
            [(torch.ops.aten.is_strides_like_format.default, [torch.channels_last])],
        )
        calls.clear()
        self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(e))
        self.assertEqual(
            calls, [(torch.ops.aten.is_non_overlapping_and_dense.default, [])]
        )

    def test_device_slowpath(self):
        for use_wrapper_subclass in [True]:

            class ExampleTensor1(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class ExampleTensor2(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.device:
                        return torch.device("meta")
                    return NotImplemented

            class ExampleTensor3(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_device=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.device:
                        return torch.device("meta")
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.prim.device'"
            with self.assertRaisesRegex(TypeError, err_msg):
                e = ExampleTensor1(torch.randn(3, 3), use_wrapper_subclass)
                e.device()

            ten = torch.rand([1])
            e = ExampleTensor2(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
            self.assertEqual(e.device.type, "meta")
            self.assertEqual(ten.type_as(e).device.type, "meta")

            e = ExampleTensor3(torch.randn(3, 3, device="cpu"), use_wrapper_subclass)
            self.assertEqual(e.device.type, "meta")
            self.assertEqual(ten.type_as(e).device.type, "meta")

    def test_dim_slowpath(self):
        data = torch.randn(3, 3)

        for use_wrapper_subclass in [True, False]:

            class DimNotImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class DimImplementedTensor(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.dim'"
            e = DimNotImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.dim()

            t = DimImplementedTensor(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(t.dim(), 2)

    def test_maybe_tuple_bug(self):
        class T(torch.Tensor):
            @classmethod
            def __torch_function__(cls, *args, **kwargs):
                pass

        a = torch.rand(3)

        # Indexing with a list of subclass instances should not raise
        a[[T(), T()]]

    def test_standard_is_not_subclass(self):
        # https://github.com/pytorch/pytorch/issues/79079
        self.assertFalse(torch._C._dispatch_isTensorSubclassLike(torch.empty(0)))

    def test_sym_sizes_strides_slow_path(self):
        class TestTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, *args, **kwargs):
                r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
                    cls, (0,), dispatch_sizes_strides_policy="sizes"
                )
                return r

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if func in (
                    torch.ops.aten.sym_size.default,
                    torch.ops.aten.sym_stride.default,
                ):
                    from torch._dynamo.source import ConstantSource
                    from torch.fx.experimental.symbolic_shapes import (
                        DimDynamic,
                        ShapeEnv,
                    )

                    shape_env = ShapeEnv()
                    si = shape_env.create_symintnode(
                        shape_env.create_symbol(
                            123,
                            source=ConstantSource("abc"),
                            dynamic_dim=DimDynamic.DUCK,
                            constraint_dim=None,
                        ),
                        hint=123,
                    )
                    return (si,)

        t = TestTensor()
        si = t.size()[0]
        self.assertIsInstance(si, torch.SymInt)
        si = t.stride()[0]
        self.assertIsInstance(si, torch.SymInt)

    def test_strides_slow_path(self):
        for use_wrapper_subclass in [True, False]:

            class StridesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class StridesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func == torch.ops.aten.sym_stride.default:
                        return (4, 2)
                    return NotImplemented

            class StridesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="strides"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func == torch.ops.aten.sym_stride.default:
                        return None
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_stride'"
            e = StridesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.stride()

            e = StridesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.stride(), (4, 2))

            e = StridesDefaultReturn(torch.randn(6, 2), use_wrapper_subclass)
            self.assertEqual(e.stride(), (2, 1))

    def test_sizes_slow_path(self):
        for use_wrapper_subclass in [True, False]:
            data = torch.randn(6, 2)

            class SizesNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    return NotImplemented

            class SizesCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    if func.overloadpacket == torch.ops.aten.sym_size:
                        return (5, 3)
                    return NotImplemented

            class SizesDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.aten.dim:
                        return data.dim()
                    if func.overloadpacket == torch.ops.aten.sym_size:
                        return None
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.aten.sym_size'"
            e = SizesNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.size()

            e = SizesCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.size(), (5, 3))

            e = SizesDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
            self.assertEqual(e.size(), (4, 2))

    def test_custom_size_policy_dynamic_shapes(self):
        data = torch.randn(6, 2)

        class CustomSizeDynamicShapesTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle
                    # dynamic shapes. (A sketch of the two call styles follows this test.)
                    cls,
                    inner.size(),
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                    "sizes",
                )

            def __init__(self, inner):
                self.inner = inner

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func == torch.ops.aten.sym_size.default:
                    return args[0].inner.shape
                if func == torch.ops.aten.sym_stride.default:
                    return args[0].inner.shape
                return NotImplemented

        x = torch.ones(2, 2)

        def trace_fn(x):
            x_wrapper = CustomSizeDynamicShapesTensor(x)
            return x_wrapper.size(), x_wrapper.stride()

        fx_g = make_fx(trace_fn, tracing_mode="symbolic")(x)
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1);  x_1 = None
    return ((sym_size_int, sym_size_int_1), (sym_size_int, sym_size_int_1))""",
        )

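    # The sketch promised in the TODO above (Shell is a hypothetical class, and
    # the overload-resolution behavior described is an assumption taken from
    # that comment): the keyword form routes to the size-specializing overload,
    # while the fully positional form threads explicit strides through.
    @staticmethod
    def _sketch_make_wrapper_subclass_overloads(inner):
        class Shell(torch.Tensor):
            pass

        # Keyword form: hits the first overload, which always specializes sizes.
        kw = torch.Tensor._make_wrapper_subclass(
            Shell, inner.size(), dtype=inner.dtype, device=inner.device
        )
        # Positional form, mirroring CustomSizeDynamicShapesTensor above.
        pos = torch.Tensor._make_wrapper_subclass(
            Shell,
            inner.size(),
            inner.stride(),
            None,  # storage_offset
            None,  # memory_format
            inner.dtype,
            inner.layout,
            inner.device,
            False,  # pin_memory
            inner.requires_grad,
        )
        return kw, pos
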
    def test_data_ptr_respects_numel_slow_path(self):
        data = torch.randn(6, 2)

        class NumelDefaultReturn(torch.Tensor):
            @staticmethod
            def __new__(cls, data, wrapper):
                return TestPythonDispatch.subclass_helper(
                    cls, data, wrapper, dispatch_sizes_strides_policy="sizes"
                )

            @classmethod
            def __torch_dispatch__(cls, func, types, args, kwargs):
                if func.overloadpacket == torch.ops.aten.dim:
                    return data.dim()
                if func.overloadpacket == torch.ops.aten.numel:
                    numel_called[0] = True
                    return None
                return NotImplemented

        for use_wrapper_subclass in (False, True):
            numel_called = [False]
            e = NumelDefaultReturn(torch.randn(2, 2), use_wrapper_subclass)
            e.data_ptr()
            self.assertTrue(numel_called[0])

    def test_layout_slow_path(self):
        for use_wrapper_subclass in [True, False]:
            data = torch.randn(6, 2)

            class LayoutNotImplemented(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    return NotImplemented

            class LayoutCustomReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.layout:
                        return torch.sparse_csr
                    return NotImplemented

            class LayoutDefaultReturn(torch.Tensor):
                @staticmethod
                def __new__(cls, data, wrapper):
                    return TestPythonDispatch.subclass_helper(
                        cls, data, wrapper, dispatch_layout=True
                    )

                @classmethod
                def __torch_dispatch__(cls, func, types, args, kwargs):
                    if func.overloadpacket == torch.ops.prim.layout:
                        return data.layout
                    return NotImplemented

            err_msg = "Multiple dispatch failed for 'torch.ops.prim.layout'"
            e = LayoutNotImplemented(torch.randn(3, 3), use_wrapper_subclass)
            with self.assertRaisesRegex(TypeError, err_msg):
                e.layout

            e = LayoutCustomReturn(torch.randn(3, 3), use_wrapper_subclass)
            self.assertEqual(e.layout, torch.sparse_csr)

            e = LayoutDefaultReturn(torch.randn(4, 2), use_wrapper_subclass)
            self.assertEqual(e.layout, torch.strided)


class TestPythonDispatcher(TestCase):
    def test_basic(self):
        x = torch.randn(2, requires_grad=True)
        # Keep the returned RAII guard alive so the Python dispatcher stays
        # enabled for the call below.
        r = torch._C._EnablePythonDispatcher()
        torch.add(x, x)

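    # A sketch of the scoped alternative to the raw RAII guard used above
    # (assumes torch._dispatch.python.enable_python_dispatcher, the context
    # manager wrapping the same guard):
    def _sketch_scoped_python_dispatcher(self):
        from torch._dispatch.python import enable_python_dispatcher

        x = torch.randn(2, requires_grad=True)
        with enable_python_dispatcher():
            return torch.add(x, x)
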
    def test_lstsq(self):
        a = torch.randn(4, 3)
        b = torch.rand(4, 3)
        expected_shape = torch.linalg.lstsq(a, b).solution.shape
        # As above, the guard keeps the Python dispatcher enabled while the
        # solution is recomputed.
        r = torch._C._EnablePythonDispatcher()
        python_disp_shape = torch.linalg.lstsq(a, b).solution.shape
        self.assertEqual(expected_shape, python_disp_shape)


class TestWrapperSubclassAliasing(TestCase):
    def _test_wrapper_subclass_aliasing(self, op, args, kwargs):
        def to_subclass(t: torch.Tensor):
            return TwoTensor(t, t.clone())

        result_ref = op(*args, **kwargs)

        args_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, args)
        kwargs_subclass = pytree.tree_map_only(torch.Tensor, to_subclass, kwargs)

        result_test = op(*args_subclass, **kwargs_subclass)

        args_ref_flat = pytree.arg_tree_leaves(*args, **kwargs)
        args_ref_flat_tensors = [
            x for x in args_ref_flat if isinstance(x, torch.Tensor)
        ]

        args_test_flat = pytree.tree_leaves((args_subclass, kwargs_subclass))
        args_test_flat_tensors = [
            x for x in args_test_flat if isinstance(x, torch.Tensor)
        ]

        result_ref_flat = pytree.tree_leaves(result_ref)
        result_ref_flat_tensors = [
            x for x in result_ref_flat if isinstance(x, torch.Tensor)
        ]

        result_test_flat = pytree.tree_leaves(result_test)
        result_test_flat_tensors = [
            x for x in result_test_flat if isinstance(x, torch.Tensor)
        ]

        for o_ref, o_test in zip(result_ref_flat_tensors, result_test_flat_tensors):
            for a_ref, a_test in zip(args_ref_flat_tensors, args_test_flat_tensors):
                out_is_inpt = o_ref is a_ref
                if out_is_inpt:
                    self.assertTrue(o_test is a_test)

                out_aliases_inpt = StorageWeakRef(
                    o_ref.untyped_storage()
                ) == StorageWeakRef(a_ref.untyped_storage())
                if out_aliases_inpt:
                    self.assertTrue(
                        StorageWeakRef(o_test.untyped_storage())
                        == StorageWeakRef(a_test.untyped_storage())
                    )
                else:
                    self.assertFalse(
                        StorageWeakRef(o_test.untyped_storage())
                        == StorageWeakRef(a_test.untyped_storage())
                    )

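    # StorageWeakRef equality is what the helper above uses as its "these two
    # tensors alias the same storage" check; a minimal sketch:
    @staticmethod
    def _sketch_storage_aliasing_check():
        a = torch.ones(4)
        b = a.view(2, 2)  # aliases a's storage
        c = a.clone()  # fresh storage
        same = StorageWeakRef(a.untyped_storage()) == StorageWeakRef(
            b.untyped_storage()
        )
        diff = StorageWeakRef(a.untyped_storage()) == StorageWeakRef(
            c.untyped_storage()
        )
        return same, diff  # (True, False)
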
    # This tests the correctness of `torch.utils._python_dispatch.return_and_correct_aliasing`,
    # a util for wrapper subclasses to promise correct aliasing behavior.
    # It's probably overkill to test every OpInfo, so I picked a sampling of
    # ops with representative schemas. (A usage sketch for the util follows
    # the test below.)
    @ops(
        [
            op
            for op in op_db
            if op.name
            in [
                "mul",  # out-of-place
                "cat",  # out-of-place (TensorList input)
                "index",  # out-of-place (Optional TensorList input)
                "mul_",  # inplace
                "view",  # view
                "t_",  # inplace-view
                "split",  # view (multi-return)
                "native_batch_norm",  # mutable op (returns outputs and mutates some inputs)
            ]
        ],
        allowed_dtypes=(torch.float,),
    )
    def test_wrapper_subclass_aliasing(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        sample = first_sample(self, samples)
        args = (sample.input, *sample.args)
        kwargs = sample.kwargs
        self._test_wrapper_subclass_aliasing(op, args, kwargs)

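    # A minimal sketch of the contract being tested (InnerWrapper is a
    # hypothetical subclass, not used by the tests here): the subclass funnels
    # each op to its inner tensor and lets return_and_correct_aliasing patch
    # the output so it aliases the way the op's schema promises.
    @staticmethod
    def _sketch_return_and_correct_aliasing():
        from torch.utils._python_dispatch import return_and_correct_aliasing

        class InnerWrapper(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                return torch.Tensor._make_wrapper_subclass(
                    cls, inner.size(), dtype=inner.dtype, device=inner.device
                )

            def __init__(self, inner):
                self.inner = inner

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                kwargs = kwargs or {}
                unwrapped_args = pytree.tree_map_only(
                    InnerWrapper, lambda t: t.inner, args
                )
                unwrapped_kwargs = pytree.tree_map_only(
                    InnerWrapper, lambda t: t.inner, kwargs
                )
                out = func(*unwrapped_args, **unwrapped_kwargs)
                wrapped_out = pytree.tree_map_only(torch.Tensor, InnerWrapper, out)
                return return_and_correct_aliasing(func, args, kwargs, wrapped_out)

        return InnerWrapper(torch.ones(4)).view(2, 2)
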
    @ops(custom_op_db, allowed_dtypes=(torch.float,))
    def test_wrapper_subclass_aliasing_custom(self, device, dtype, op):
        samples = op.sample_inputs(device, dtype)
        sample = first_sample(self, samples)
        args = (sample.input, *sample.args)
        kwargs = sample.kwargs
        self._test_wrapper_subclass_aliasing(op, args, kwargs)

    def test_wrapper_subclass_aliasing_conv2d(self, device):
        args = (torch.randn(4, 4, 4, 4), torch.randn(4, 4, 4, 4))
        kwargs = {}
        # conv2d has scalar defaults for its 'int[2]' args (e.g. 'int[2] padding=0'),
        # which torchscript expands into lists ('int[2] padding=[0, 0]').
        # Make sure that _return_and_correct_aliasing can handle this case
        # (I'm using inference_mode to make sure conv2d doesn't decompose and
        # goes straight to torch_dispatch)
        with torch.inference_mode():
            self._test_wrapper_subclass_aliasing(
                torch.ops.aten.conv2d.default, args, kwargs
            )

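    # A short sketch of how to inspect the expanded defaults mentioned in the
    # comment above (assumes the private OpOverload._schema binding and its
    # Argument.default_value attribute):
    @staticmethod
    def _sketch_conv2d_schema_defaults():
        schema = torch.ops.aten.conv2d.default._schema
        # Each argument carries its already-expanded default, so the int[2]
        # arguments report two-element lists rather than scalars.
        return [(arg.name, arg.default_value) for arg in schema.arguments]
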
    def test_wrapper_subclass_aliasing_out_op(self, device):
        # Make sure that _return_and_correct_aliasing can handle kwargs with
        # mutable tensors (here, the out= tensor)
        args = (torch.ones(4), torch.ones(4))
        kwargs = {"out": torch.empty(4)}
        self._test_wrapper_subclass_aliasing(torch.ops.aten.add.out, args, kwargs)


instantiate_device_type_tests(TestWrapperSubclassAliasing, globals())

if __name__ == "__main__":
    run_tests()