xref: /aosp_15_r20/external/pytorch/test/export/test_torchbind.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: export"]
2
3
4import unittest
5
6import torch
7import torch.utils._pytree as pytree
8from torch._dynamo.testing import EagerAndRecordGraphs
9from torch._functorch.aot_autograd import aot_export_module
10from torch._higher_order_ops.torchbind import enable_torchbind_tracing
11from torch._higher_order_ops.wrap import wrap
12from torch._library.fake_class_registry import FakeScriptObject
13from torch.export import export
14from torch.export._trace import _export
15from torch.fx.experimental.proxy_tensor import make_fx
16from torch.testing._internal.common_utils import (
17    instantiate_parametrized_tests,
18    parametrize,
19    run_tests,
20    skipIfTorchDynamo,
21    TestCase,
22)
23from torch.testing._internal.torchbind_impls import (
24    _empty_tensor_queue,
25    init_torchbind_implementations,
26)
27
28
29def _assertEqualSkipScriptObject(test_case, exp, actual):
30    flat_exp = pytree.tree_leaves(exp)
31    flat_actual = pytree.tree_leaves(actual)
32    test_case.assertEqual(len(flat_exp), len(flat_actual))
33    for a, b in zip(flat_exp, flat_actual):
34        if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject):
35            continue
36        test_case.assertEqual(a, b)
37
38
39def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject):
40    return test_case.assertEqual(
41        a._type().qualified_name(), b._type().qualified_name()
42    ) and test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__())
43
44
def _assertEqualScriptObject(
    test_case, exp, actual, check_obj_eq=_check_script_obj_equal
):
    """Assert two pytrees are equal, comparing ScriptObject leaves specially.

    Leaves are compared with ``test_case.assertEqual``, except when both
    sides of a pair are ``torch.ScriptObject`` instances, in which case
    ``check_obj_eq`` is used instead.
    """
    exp_leaves = pytree.tree_leaves(exp)
    actual_leaves = pytree.tree_leaves(actual)
    test_case.assertEqual(len(exp_leaves), len(actual_leaves))
    for lhs, rhs in zip(exp_leaves, actual_leaves):
        if isinstance(lhs, torch.ScriptObject) and isinstance(rhs, torch.ScriptObject):
            check_obj_eq(test_case, lhs, rhs)
        else:
            test_case.assertEqual(lhs, rhs)
56
57
58@skipIfTorchDynamo("torchbind not supported with dynamo yet")
59class TestExportTorchbind(TestCase):
    def setUp(self):
        """Install fake (Python) implementations for the torchbind test classes.

        Loads the real ``_TorchScriptTesting`` torchbind implementations, then
        registers fake classes for ``_Foo`` and ``_TensorQueue`` so they can be
        traced under FakeTensor. Per-test counters on the TestCase record how
        often the fake methods are invoked during tracing.
        """
        init_torchbind_implementations()

        # Alias ``self`` so the nested fake classes below can bump counters on
        # the enclosing TestCase instance from their methods.
        test = self
        test.tq_push_counter = 0
        test.tq_pop_counter = 0
        test.tq_size_counter = 0
        test.foo_add_tensor_counter = 0

        @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
        class FakeFoo:
            def __init__(self, x: int, y: int):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flattend_foo):
                # Rebuild the fake object from the flattened (name, value)
                # pairs produced by the real object's __obj_flatten__.
                return cls(**dict(flattend_foo))

            def add_tensor(self, z):
                test.foo_add_tensor_counter += 1
                return (self.x + self.y) * z

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                test.tq_push_counter += 1
                self.queue.append(x)

            def pop(self):
                test.tq_pop_counter += 1
                return self.queue.pop(0)

            def size(self):
                test.tq_size_counter += 1
                return len(self.queue)

            def is_empty(self):
                return len(self.queue) == 0

            def float_size(self):
                return float(len(self.queue))

        # Custom ops that take torchbind objects as arguments; kept as an
        # attribute for tests that want to iterate over them.
        self.torch_bind_ops = [
            torch.ops._TorchScriptTesting.takes_foo,
            torch.ops._TorchScriptTesting.takes_foo_python_meta,
            torch.ops._TorchScriptTesting.takes_foo_list_return,
            torch.ops._TorchScriptTesting.takes_foo_tuple_return,
            torch.ops._TorchScriptTesting.take_an_instance,
            torch.ops._TorchScriptTesting.take_an_instance_inferred,
            torch.ops._TorchScriptTesting.takes_foo_cia,
            torch.ops._TorchScriptTesting.queue_pop,
            torch.ops._TorchScriptTesting.queue_push,
            torch.ops._TorchScriptTesting.queue_size,
        ]
122
123    def tearDown(self):
124        torch._library.fake_class_registry.deregister_fake_class(
125            "_TorchScriptTesting::_Foo"
126        )
127        torch._library.fake_class_registry.deregister_fake_class(
128            "_TorchScriptTesting::_TensorQueue"
129        )
130
131    def _test_export_same_as_eager(
132        self, f, args, kwargs=None, strict=True, pre_dispatch=False
133    ):
134        kwargs = kwargs or {}
135
136        def export_wrapper(f, args, kwargs, strcit, pre_dispatch):
137            with enable_torchbind_tracing():
138                if pre_dispatch:
139                    exported_program = _export(
140                        f, args, kwargs, strict=strict, pre_dispatch=True
141                    )
142                else:
143                    exported_program = export(f, args, kwargs, strict=strict)
144            return exported_program
145
146        exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch)
147        reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)}
148        unlifted = exported_program.module()
149        exp = f(*args, **kwargs)
150        _assertEqualScriptObject(self, unlifted(*args, **kwargs), exp)
151        _assertEqualScriptObject(
152            self,
153            unlifted(*args, **reversed_kwargs),
154            exp,
155        )
156
157        # check re-tracing
158        retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch)
159        _assertEqualScriptObject(self, retraced_ep.module()(*args, **kwargs), exp)
160        return exported_program
161
    @parametrize("pre_dispatch", [True, False])
    def test_none(self, pre_dispatch):
        """Export a module whose forward takes a ``None`` input alongside a
        torchbind attribute, and check both the unlifted-module code and the
        effect-tokenized graph-module code against golden output."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x, n):
                # ``n`` is intentionally unused; it is the None input.
                return x + self.attr.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(),
            (torch.ones(2, 3), None),
            strict=False,
            pre_dispatch=pre_dispatch,
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, n):
    x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x);  attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x, n):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x);  token = obj_attr = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
198
    def test_method_schema(self):
        """Check the inferred schemas of a fakified TensorQueue's methods.

        Converts a real TensorQueue to a fake object under FakeTensorMode and
        verifies the ``push``/``pop`` method schemas match golden strings.
        """
        tq = _empty_tensor_queue()
        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
        fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, tq)
        self.assertExpectedInline(
            str(fake_obj.push.schema),
            """push(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, Tensor _1) -> NoneType _0""",
        )
        self.assertExpectedInline(
            str(fake_obj.pop.schema),
            """pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0) -> Tensor _0""",
        )
211
    @parametrize("pre_dispatch", [True, False])
    def test_attribute(self, pre_dispatch):
        """Export a module that calls a method on a torchbind attribute and
        check the unlifted and graph-module code against golden output."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + self.attr.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x);  attr = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x);  token = obj_attr = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
245
    @parametrize("pre_dispatch", [True, False])
    def test_attribute_as_custom_op_argument(self, pre_dispatch):
        """Export a module that passes its torchbind attribute to a custom op
        (``takes_foo``) and check the generated code against golden output."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x);  attr = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x);  token = obj_attr = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
279
    @parametrize("pre_dispatch", [True, False])
    def test_input(self, pre_dispatch):
        """Export a module that receives a torchbind object as a *runtime
        input* (not an attribute), check the golden code, and verify how many
        times the fake method was invoked during tracing."""
        cc = torch.classes._TorchScriptTesting._Foo(10, 20)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, cc):
                return x + cc.add_tensor(x)

        ep = self._test_export_same_as_eager(
            MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x);  cc = None
    add = torch.ops.aten.add.Tensor(x, call_torchbind);  x = call_torchbind = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, x, cc):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, cc, 'add_tensor', x);  token = cc = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
        # aot_export_function runs the program twice
        # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function
        # We also have a re-tracing test, which doubles the count.
        self.assertEqual(self.foo_add_tensor_counter, 4)
317
    @parametrize("pre_dispatch", [True, False])
    def test_input_as_custom_op_argument(self, pre_dispatch):
        """Passing a torchbind input to a custom op requires a Python (Meta)
        implementation: removing it must raise, and re-registering one must
        make export succeed with the expected golden code."""
        cc = torch.classes._TorchScriptTesting._Foo(10, 20)

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, cc):
                return x + torch.ops._TorchScriptTesting.takes_foo(cc, x)

        # Temporarily remove the Meta-key python kernel and clear the dispatch
        # cache so the removal actually takes effect.
        del torch.ops._TorchScriptTesting.takes_foo.default.py_kernels[
            torch._C.DispatchKey.Meta
        ]
        torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear()
        # Even though a C++ implementation for takes_foo.default is registered,
        # we still need the python implementation for takes_foo.default to trace with FakeFoo.
        with self.assertRaisesRegex(RuntimeError, "no python implementation is found"):
            self._test_export_same_as_eager(
                MyModule(),
                (torch.ones(2, 3), cc),
                strict=False,
                pre_dispatch=pre_dispatch,
            )

        # Restore a python Meta implementation; export should now succeed.
        torch.ops._TorchScriptTesting.takes_foo.default.py_impl(
            torch._C.DispatchKey.Meta
        )(lambda cc, x: cc.add_tensor(x))
        ep = self._test_export_same_as_eager(
            MyModule(),
            (torch.ones(2, 3), cc),
            strict=False,
            pre_dispatch=pre_dispatch,
        )

        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x, cc):
    x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x);  cc = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, x, cc):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x);  token = cc = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add = torch.ops.aten.add.Tensor(x, getitem_1);  x = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
372
    @parametrize("pre_dispatch", [True, False])
    def test_torchbind_alias(self, pre_dispatch):
        """Export should handle the same torchbind object aliased under
        multiple attribute names (and shared with a submodule)."""

        class F2(torch.nn.Module):
            def __init__(self, foo):
                super().__init__()
                self.foo = foo

            def forward(self, x):
                return x + torch.ops._TorchScriptTesting.takes_foo(self.foo, x)

        class F1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                # alpha, beta, and gamma all alias the same _Foo instance;
                # the submodule holds yet another reference to it.
                self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20)
                self.beta = self.alpha
                self.gamma = self.alpha
                self.foo = F2(self.gamma)

            def forward(self, x):
                return (
                    x
                    + torch.ops._TorchScriptTesting.takes_foo(self.gamma, x)
                    + self.foo(x)
                )

        self._test_export_same_as_eager(
            F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch
        )
401
    # TODO(pianpwk): look into this
    @unittest.expectedFailure
    @parametrize("pre_dispatch", [True, False])
    def test_torchbind_input_and_alias(self, pre_dispatch):
        """Known failure: a torchbind input that is also stashed as a module
        attribute inside ``forward`` is not yet supported by export."""

        # alias as model attribute
        class F3(torch.nn.Module):
            def forward(self, x, foo):
                self.foo = foo
                return x + self.foo.add_tensor(x)

        foo = torch.classes._TorchScriptTesting._Foo(10, 20)
        self._test_export_same_as_eager(
            F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch
        )
416
    @parametrize("pre_dispatch", [True, False])
    def test_unlift_custom_obj(self, pre_dispatch):
        """Two chained custom-op calls on the same torchbind attribute should
        share one lifted object and thread the effect token between calls."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x)
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1);  attr = takes_foo_default_1 = None
    add = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add,), self._out_spec)""",  # noqa: B950
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1);  getitem = obj_attr = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(x, getitem_3);  x = getitem_3 = None
    return (getitem_2, add)""",  # noqa: B950
        )
457
    @parametrize("pre_dispatch", [True, False])
    def test_custom_obj_list_out(self, pre_dispatch):
        """A custom op returning a *list* of tensors should export with the
        list indexed via getitem nodes in both code forms."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x)
    getitem_2 = takes_foo_list_return_default[0]
    getitem_3 = takes_foo_list_return_default[1]
    getitem_4 = takes_foo_list_return_default[2];  takes_foo_list_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4);  add = getitem_4 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1);  attr = add_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add_2,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    getitem_2 = getitem_1[0]
    getitem_3 = getitem_1[1]
    getitem_4 = getitem_1[2];  getitem_1 = None
    add = torch.ops.aten.add.Tensor(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
    add_1 = torch.ops.aten.add.Tensor(add, getitem_4);  add = getitem_4 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1);  getitem = obj_attr = add_1 = None
    getitem_5 = with_effects_1[0]
    getitem_6 = with_effects_1[1];  with_effects_1 = None
    add_2 = torch.ops.aten.add.Tensor(x, getitem_6);  x = getitem_6 = None
    return (getitem_5, add_2)""",  # noqa: B950
        )
509
    @parametrize("pre_dispatch", [True, False])
    def test_custom_obj_tuple_out(self, pre_dispatch):
        """A custom op returning a *tuple* of tensors should export with the
        tuple elements unpacked via getitem nodes in both code forms."""

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        input = torch.ones(2, 3)
        ep = self._test_export_same_as_eager(
            MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x)
    getitem_1 = takes_foo_tuple_return_default[0]
    getitem_2 = takes_foo_tuple_return_default[1];  takes_foo_tuple_return_default = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = getitem_2 = None
    takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add);  attr = add = None
    add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default);  x = takes_foo_default = None
    return pytree.tree_unflatten((add_1,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj_attr, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x);  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]
    getitem_2 = with_effects[2];  with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, getitem_2);  getitem_1 = getitem_2 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add);  getitem = obj_attr = add = None
    getitem_3 = with_effects_1[0]
    getitem_4 = with_effects_1[1];  with_effects_1 = None
    add_1 = torch.ops.aten.add.Tensor(x, getitem_4);  x = getitem_4 = None
    return (getitem_3, add_1)""",  # noqa: B950
        )
556
    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
    def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode):
        """make_fx-trace a module that mutates a TensorQueue via its methods.

        Verifies that tracing uses the fake TensorQueue (counters increment),
        that the real queue passed in is left unchanged (size 0 afterwards),
        that the traced code matches the golden output, and that running the
        traced graph matches eager on a fresh queue.
        """
        test = self

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 2)
                self.check_tq_is_fake = True

            def forward(self, tq, x):
                if self.check_tq_is_fake:
                    test.assertTrue(isinstance(tq, FakeScriptObject))
                tq.push(x.cos())
                tq.push(x.sin())
                x_cos = tq.pop() + tq.size()
                x_sin = tq.pop() - tq.size()
                return x_sin, x_cos, tq

        mod = Model()
        tq = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        x = torch.ones(2, 3)
        gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
        # Each fake method should have been hit exactly as often as forward
        # calls it during a single trace.
        self.assertEqual(self.tq_push_counter, 2)
        self.assertEqual(self.tq_pop_counter, 2)
        self.assertEqual(self.tq_size_counter, 2)
        # The real queue must not have been mutated by tracing.
        self.assertEqual(tq.size(), 0)
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg1_1)
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos);  cos = call_torchbind = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin);  sin = call_torchbind_1 = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size');  call_torchbind_3 = None
    add = torch.ops.aten.add.Tensor(call_torchbind_2, 1);  call_torchbind_2 = None
    call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size');  call_torchbind_5 = None
    sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0);  call_torchbind_4 = None
    return (sub, add, arg0_1)
    """,
        )
        # Disable the fake-object type check so the graph can run eagerly on
        # real TensorQueues.
        mod.check_tq_is_fake = False
        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
612
    @parametrize("make_fx_tracing_mode", ["fake", "symbolic"])
    def test_make_fx_tensor_queue_methods_fakify_internal_states(
        self, make_fx_tracing_mode
    ):
        """make_fx-trace a module that pops from a *pre-populated* queue.

        The queue's internal state (two pushed tensors) must be fakified for
        tracing: tracing pops from the fake copy, so the real queue's size is
        unchanged afterwards, and running the traced graph on the real queues
        then drains them.
        """
        test = self

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(3, 2)
                self.check_tq_is_fake = True
                self.current_test = test

            def forward(self, tq, x):
                if self.check_tq_is_fake:
                    self.current_test.assertTrue(isinstance(tq, FakeScriptObject))
                x_cos = tq.pop() + tq.size() + x
                x_sin = tq.pop() - tq.size() + x
                return x_sin, x_cos, tq

        mod = Model()
        tq = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        # Pre-populate both queues identically so eager and traced runs see
        # the same starting state.
        for _ in range(2):
            tq.push(torch.ones(2, 3))
            tq1.push(torch.ones(2, 3))
        x = torch.ones(2, 3)
        prev_size = tq.size()
        gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x)
        self.assertEqual(self.tq_push_counter, 0)
        self.assertEqual(self.tq_pop_counter, 2)
        self.assertEqual(self.tq_size_counter, 2)
        # Tracing popped only from the fake copy; the real queue is intact.
        self.assertEqual(tq.size(), prev_size)
        self.assertExpectedInline(
            gm.code.strip("\n"),
            """\
def forward(self, arg0_1, arg1_1):
    call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size');  call_torchbind_1 = None
    add = torch.ops.aten.add.Tensor(call_torchbind, 1);  call_torchbind = None
    add_1 = torch.ops.aten.add.Tensor(add, arg1_1);  add = None
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop')
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size');  call_torchbind_3 = None
    sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0);  call_torchbind_2 = None
    add_2 = torch.ops.aten.add.Tensor(sub, arg1_1);  sub = arg1_1 = None
    return (add_2, add_1, arg0_1)
    """,
        )
        # turn off tq type checking in eager execution
        mod.check_tq_is_fake = False
        _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x))
        # Running the traced/eager programs drains both real queues.
        self.assertEqual(tq.size(), 0)
        self.assertEqual(tq1.size(), 0)
674
    def test_non_strict_export_methods(self):
        """Non-strict export of a module with data-dependent control flow on a
        TensorQueue: the branch taken during tracing is baked into the graph,
        and the real queue passed to export is left untouched (same size, same
        tensor identities)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, tq, x):
                x_cos = tq.pop() + tq.float_size() + self.linear(x)
                # Data-dependent branch: the queue has 2 elements at trace
                # time, so the else-branch is what gets traced.
                if tq.is_empty():
                    x_sin = self.linear(tq.pop()) - tq.size() + x
                else:
                    x_sin = tq.pop() + tq.size() + x
                return x_sin, x_cos, tq

        mod = Model()
        tq = _empty_tensor_queue()
        a = torch.randn(2, 2)
        b = torch.randn(2, 2)
        tq.push(a)
        tq.push(b)
        ep = torch.export.export(mod, (tq, torch.randn(2, 2)), strict=False)
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, p_linear_weight, p_linear_bias, tq, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, tq, 'pop');  token = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, tq, 'float_size');  getitem = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, 1.0);  getitem_1 = None
    linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias);  p_linear_weight = p_linear_bias = None
    add_1 = torch.ops.aten.add.Tensor(add, linear);  add = linear = None
    with_effects_2 = torch.ops.higher_order.with_effects(getitem_2, torch.ops.higher_order.call_torchbind, tq, 'is_empty');  getitem_2 = None
    getitem_4 = with_effects_2[0];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops.higher_order.call_torchbind, tq, 'pop');  getitem_4 = None
    getitem_6 = with_effects_3[0]
    getitem_7 = with_effects_3[1];  with_effects_3 = None
    with_effects_4 = torch.ops.higher_order.with_effects(getitem_6, torch.ops.higher_order.call_torchbind, tq, 'size');  getitem_6 = None
    getitem_8 = with_effects_4[0];  with_effects_4 = None
    add_2 = torch.ops.aten.add.Tensor(getitem_7, 0);  getitem_7 = None
    add_3 = torch.ops.aten.add.Tensor(add_2, x);  add_2 = x = None
    return (getitem_8, add_3, add_1, tq)""",  # noqa: B950
        )
        # Export must not have mutated the real queue: both original tensors
        # are still there, in order, and identical (not copies).
        self.assertEqual(tq.size(), 2)
        self.assertTrue(tq.pop() is a)
        self.assertTrue(tq.pop() is b)
722
    def test_safe_to_trace_with_real(self):
        """A script object marked safe to trace with the real object
        (_ConstantTensorContainer): dynamo captures its method calls as
        ``call_torchbind`` nodes, and non-strict export (under
        enable_torchbind_tracing) lowers them to ``with_effects`` nodes."""
        x = torch.randn(3, 3)
        safe_obj = torch.classes._TorchScriptTesting._ConstantTensorContainer(x)

        class Mod(torch.nn.Module):
            def forward(self, safe_obj: torch.ScriptObject) -> torch.Tensor:
                return safe_obj.get().sin()

        mod = Mod()
        backend = EagerAndRecordGraphs()
        out = torch.compile(mod, backend=backend, fullgraph=True)(safe_obj)
        self.assertEqual(out, mod(safe_obj))
        self.assertExpectedInline(
            backend.graphs[0].code.strip(),
            """\
def forward(self, L_safe_obj_ : torch.ScriptObject):
    l_safe_obj_ = L_safe_obj_
    call_torchbind = torch.ops.higher_order.call_torchbind(l_safe_obj_, 'get');  l_safe_obj_ = None
    sin = call_torchbind.sin();  call_torchbind = None
    return (sin,)""",
        )

        with enable_torchbind_tracing():
            ep = torch.export.export(mod, (safe_obj,), strict=False)
            self.assertExpectedInline(
                ep.graph_module.code.strip(),
                """\
def forward(self, token, safe_obj):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, safe_obj, 'get');  token = safe_obj = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    sin = torch.ops.aten.sin.default(getitem_1);  getitem_1 = None
    return (getitem, sin)""",  # noqa: B950
            )
757
758    def test_identifying_torchbind_ops(self):
759        for op in self.torch_bind_ops:
760            self.assertTrue(op._has_torchbind_op_overload)
761
762        for op in [
763            torch.ops.aten.add,
764            torch.ops.aten.cos,
765        ]:
766            self.assertFalse(op._has_torchbind_op_overload)
767
768    def test_torchbind_op_register_fallthrough(self):
769        TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU
770        TEST_DISPATCH_KEY_STR = "AutocastCPU"
771
772        for op_packet in self.torch_bind_ops:
773            op = op_packet.default
774            ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
775            with torch.library._scoped_library(ns, "FRAGMENT") as lib:
776                lib.impl(
777                    op.name(), torch.library.fallthrough_kernel, TEST_DISPATCH_KEY_STR
778                )
779                self.assertTrue(
780                    torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough(
781                        op.name(), TEST_DISPATCH_KEY
782                    )
783                )
784
    def test_torchbind_op_fallthrough_keys_respects_lib_impl(self):
        """``_fallthrough_keys()`` must track current registrations: a real
        kernel registered on a key removes it from the fallthrough set, and
        registering ``fallthrough_kernel`` puts it back."""
        TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU
        TEST_DISPATCH_KEY_STR = "AutogradCPU"

        # Only ops with no pre-existing kernel (C++ or Python) on the key are
        # meaningful here; count them so we know at least one was exercised.
        tested = 0
        for op_packet in self.torch_bind_ops:
            op = op_packet.default
            ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name)
            if (
                not torch._C._dispatch_has_kernel_for_dispatch_key(
                    op.name(), TEST_DISPATCH_KEY
                )
                and TEST_DISPATCH_KEY not in op.py_kernels
            ):
                tested += 1
                # A real (non-fallthrough) kernel: key must leave the
                # fallthrough set.
                with torch.library._scoped_library(ns, "FRAGMENT") as lib:
                    lib.impl(
                        op.name(), lambda *args, **kwargs: args, TEST_DISPATCH_KEY_STR
                    )
                    self.assertTrue(TEST_DISPATCH_KEY not in op._fallthrough_keys())

                # The explicit fallthrough kernel: key must rejoin the
                # fallthrough set.
                with torch.library._scoped_library(ns, "FRAGMENT") as lib:
                    lib.impl(
                        op.name(),
                        torch.library.fallthrough_kernel,
                        TEST_DISPATCH_KEY_STR,
                    )
                    self.assertTrue(TEST_DISPATCH_KEY in op._fallthrough_keys())
        self.assertTrue(tested > 0)
814
    def test_make_fx_schema_checking_script_object(self):
        """make_fx in fake tracing mode must reject a *real* (non-fakified)
        script object passed to a torchbind op -- whether passed positionally
        or by keyword -- once the dispatch keys that would otherwise
        intercept the call are registered as fallthroughs."""

        class Model(torch.nn.Module):
            def forward(self, tq, x, foo):
                # `foo` (a _Foo, not a queue) is deliberately passed in the
                # script-object slot; schema checking should still fire.
                torch.ops._TorchScriptTesting.queue_push(foo, x.cos())
                return tq

        class ModelCallByKW(torch.nn.Module):
            def forward(self, tq, x, foo):
                # Same call but with keyword arguments.
                torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo)
                return tq

        mod = Model()
        modkw = ModelCallByKW()

        foo = torch.classes._TorchScriptTesting._Foo(10, 20)
        x = torch.ones(3, 3)
        tq = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        ns = "_TorchScriptTesting"
        with torch.library._scoped_library(ns, "FRAGMENT") as lib:
            op = torch.ops._TorchScriptTesting.queue_push
            # Fall through the keys that would intercept the call before the
            # fake-mode schema check can run.
            lib.impl(op.__name__, torch.library.fallthrough_kernel, "AutogradCPU")
            lib.impl(op.__name__, torch.library.fallthrough_kernel, "ADInplaceOrView")
            lib.impl(
                op.__name__,
                torch.library.fallthrough_kernel,
                "PythonTLSSnapshot",
            )

            with self.assertRaisesRegex(
                RuntimeError, "is expected to be a FakeScriptObject"
            ):
                _ = make_fx(mod, tracing_mode="fake")(tq, x, foo)

            with self.assertRaisesRegex(
                RuntimeError, "is expected to be a FakeScriptObject"
            ):
                _ = make_fx(modkw, tracing_mode="fake")(tq, x, foo)
856
    @parametrize("fallthrough_via", ["lib_impl", "py_impl"])
    def test_make_fx_tensor_queue_operators(self, fallthrough_via):
        """make_fx (fake mode) over custom queue ops under autocast: after
        registering AutocastCUDA fallthroughs -- either via ``lib.impl`` or
        via ``py_impl`` -- the ops trace straight through to their
        ``.default`` overloads, and replaying the graph matches eager."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, tq, x):
                with torch.autocast("cuda", dtype=torch.bfloat16):
                    torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                    torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                    x_sin = torch.ops._TorchScriptTesting.queue_pop(
                        tq
                    ) - torch.ops._TorchScriptTesting.queue_size(tq)
                    x_cos = torch.ops._TorchScriptTesting.queue_pop(
                        tq
                    ) + torch.ops._TorchScriptTesting.queue_size(tq)
                    return x_sin, x_cos, tq

        mod = Model()

        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        tq2 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        x = torch.ones(2, 3)

        # Warm-up eager call (pushes and pops leave tq1 empty again).
        mod(tq1, x)

        ops = [
            torch.ops._TorchScriptTesting.queue_push,
            torch.ops._TorchScriptTesting.queue_pop,
            torch.ops._TorchScriptTesting.queue_size,
        ]
        if fallthrough_via == "lib_impl":
            ns = "_TorchScriptTesting"
            with torch.library._scoped_library(ns, "FRAGMENT") as lib:
                for op in ops:
                    lib.impl(
                        op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA"
                    )

                gm = make_fx(mod, tracing_mode="fake")(tq1, x)
        else:
            for op in ops:
                op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)(
                    torch.library.fallthrough_kernel
                )
            gm = make_fx(mod, tracing_mode="fake")(tq1, x)
            # py_impl registrations are not scoped; undo them so other tests
            # are unaffected.
            for op in ops:
                op.default._dispatch_cache.clear()
                del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA]

        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    cos = torch.ops.aten.cos.default(arg1_1)
    queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos);  cos = queue_push = None
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
    queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin);  sin = queue_push_1 = None
    queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
    queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1);  queue_size = None
    sub = torch.ops.aten.sub.Tensor(queue_pop, 1);  queue_pop = None
    queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1)
    queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1);  queue_size_1 = None
    add = torch.ops.aten.add.Tensor(queue_pop_1, 0);  queue_pop_1 = None
    return (sub, add, arg0_1)""",
        )
        _assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x))
932
    def test_aot_export_tensor_queue_operators(self):
        """aot_export_module over a fakified TensorQueue: every queue op is
        wrapped in a ``with_effects`` node threaded by an effect token;
        graph inputs are (token, tq, x) and outputs (token, x_sin, x_cos,
        tq)."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, tq, x):
                torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
                torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
                x_sin = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) - torch.ops._TorchScriptTesting.queue_size(tq)
                x_cos = torch.ops._TorchScriptTesting.queue_pop(
                    tq
                ) + torch.ops._TorchScriptTesting.queue_size(tq)
                return x_sin, x_cos, tq

        mod = Model()

        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
            torch.empty(
                0,
            ).fill_(-1)
        )
        x = torch.ones(2, 3)

        # Fakify both the queue and the tensor before aot_export.
        fake_mode = torch._subclasses.fake_tensor.FakeTensorMode()
        fake_tq1 = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, tq1)
        fake_x = fake_mode.from_tensor(x)
        gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0]

        # inputs: token, tq, x
        # return: token, x_sin, x_cos, tq
        self.assertExpectedInline(
            gm.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    cos = torch.ops.aten.cos.default(arg2_1)
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos);  arg0_1 = cos = None
    getitem = with_effects[0];  with_effects = None
    sin = torch.ops.aten.sin.default(arg2_1);  arg2_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin);  getitem = sin = None
    getitem_2 = with_effects_1[0];  with_effects_1 = None
    with_effects_2 = torch.ops.higher_order.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1);  getitem_2 = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1];  with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1);  getitem_4 = None
    getitem_6 = with_effects_3[0];  with_effects_3 = None
    sub = torch.ops.aten.sub.Tensor(getitem_5, 1);  getitem_5 = None
    with_effects_4 = torch.ops.higher_order.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1);  getitem_6 = None
    getitem_8 = with_effects_4[0]
    getitem_9 = with_effects_4[1];  with_effects_4 = None
    with_effects_5 = torch.ops.higher_order.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1);  getitem_8 = None
    getitem_10 = with_effects_5[0];  with_effects_5 = None
    add = torch.ops.aten.add.Tensor(getitem_9, 0);  getitem_9 = None
    return (getitem_10, sub, add, arg1_1)""",  # noqa: B950
        )
989
    def test_export_inplace_custom_op(self):
        """Export of a module calling a mutating custom op (queue_push) with
        pre_dispatch: the unlifted module keeps the raw
        ``queue_push.default`` call, while the exported graph wraps it in a
        token-threaded ``with_effects`` node."""

        class Model(torch.nn.Module):
            def forward(self, tq: torch.ScriptObject, x: torch.Tensor) -> torch.ScriptObject:
                torch.ops._TorchScriptTesting.queue_push(tq, x)
                return tq

        mod = Model()
        ep = self._test_export_same_as_eager(
            mod,
            (_empty_tensor_queue(), torch.randn(3, 3)),
            strict=False,
            pre_dispatch=True,
        )
        self.assertExpectedInline(
            ep.module().code.strip(),
            """\
def forward(self, tq, x):
    tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec)
    queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x);  x = queue_push_default = None
    return pytree.tree_unflatten((tq,), self._out_spec)""",
        )
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, tq, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.queue_push.default, tq, x);  token = x = None
    getitem = with_effects[0];  with_effects = None
    return (getitem, tq)""",  # noqa: B950
        )
        self.assertExpectedInline(
            str(ep.graph_module.graph).strip(),
            """\
graph():
    %tq : [num_users=2] = placeholder[target=tq]
    %x : [num_users=1] = placeholder[target=x]
    %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {})
    return (tq,)""",  # noqa: B950
        )
1028
1029
1030class TestCompileTorchbind(TestCase):
    def setUp(self):
        """Load the real torchbind test implementations, register the fake
        classes the compile tests rely on, and reset dynamo state."""
        init_torchbind_implementations()

        @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue")
        class FakeTensorQueue:
            def __init__(self, queue):
                # Python-list stand-in for the C++ tensor queue.
                self.queue = queue

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

            def push(self, x):
                self.queue.append(x)

            def pop(self):
                # FIFO semantics, matching the real _TensorQueue.
                return self.queue.pop(0)

            def size(self):
                return len(self.queue)

        @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp")
        class FakeFlatten:
            def __init__(self, t):
                self.t = t

            def get(self):
                return self.t

            @classmethod
            def __obj_unflatten__(cls, flattened_ctx):
                return cls(**dict(flattened_ctx))

        torch._dynamo.reset()
1065
    def tearDown(self):
        """Reset dynamo so compile caches don't leak across tests."""
        torch._dynamo.reset()
1068
1069    @parametrize("backend", ["eager", "aot_eager"])
1070    def test_compile_script_object_input(self, backend):
1071        if backend == "eager":
1072            backend = EagerAndRecordGraphs()
1073
1074        class Model(torch.nn.Module):
1075            def __init__(self) -> None:
1076                super().__init__()
1077                self.check_tq_is_fake = True
1078
1079            def forward(self, tq, x):
1080                tq.push(x.cos())
1081                tq.push(x.sin())
1082                x_sin = tq.pop() - tq.size()
1083                return x_sin, tq
1084
1085        mod = Model()
1086        tq1 = torch.classes._TorchScriptTesting._TensorQueue(
1087            torch.empty(
1088                0,
1089            ).fill_(-1)
1090        )
1091        tq2 = torch.classes._TorchScriptTesting._TensorQueue(
1092            torch.empty(
1093                0,
1094            ).fill_(-1)
1095        )
1096        tq3 = torch.classes._TorchScriptTesting._TensorQueue(
1097            torch.empty(
1098                0,
1099            ).fill_(-1)
1100        )
1101        tq4 = torch.classes._TorchScriptTesting._TensorQueue(
1102            torch.empty(
1103                0,
1104            ).fill_(-1)
1105        )
1106        x = torch.randn(2, 3)
1107        ret = torch.compile(mod, backend=backend)(tq1, x)
1108        eager_ret = mod(tq2, x)
1109        _assertEqualSkipScriptObject(self, ret, eager_ret)
1110        self.assertEqual(ret[1].size(), eager_ret[1].size())
1111        self.assertEqual(ret[1].pop(), eager_ret[1].pop())
1112        # Note that dynamo captured graph
1113        # does not return L_tq_ as output. This is because it's able
1114        # to detect that L_tq_ is an input therefore don't return
1115        # it as graph output. Related logic is in dynamo/codegen.py
1116        if backend == "eager":
1117            self.assertExpectedInline(
1118                backend.graphs[0].code.strip(),
1119                """\
1120    def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor):
1121        l_tq_ = L_tq_
1122        l_x_ = L_x_
1123        cos = l_x_.cos()
1124        call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos);  cos = None
1125        sin = l_x_.sin();  l_x_ = None
1126        call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin);  sin = None
1127        call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop')
1128        call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size');  l_tq_ = None
1129        x_sin = call_torchbind_2 - 1;  call_torchbind_2 = None
1130        return (x_sin,)""",
1131            )
1132
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_script_object_input_guards(self, backend):
        """Dynamo guards on script-object inputs: recompilation is triggered
        by changes in queue length and in each queued tensor's shape,
        dispatch keys (requires_grad), and dtype -- and skipped when those
        all match a previously seen queue."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.check_tq_is_fake = True

            def forward(self, tq, x):
                tq.push(x.cos())
                tq.push(x.sin())
                x_sin = tq.pop() - tq.size()
                return x_sin, tq

        mod = Model()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
        x = torch.randn(2, 3)

        tq1 = _empty_tensor_queue()
        torch.compile(mod, backend=cnt)(tq1, x)
        self.assertEqual(cnt.frame_count, 1)

        tq2 = _empty_tensor_queue()
        for _ in range(10):
            tq2.push(torch.randn(4, 5, requires_grad=False))
        torch.compile(mod, backend=cnt)(tq2, x)
        # Queue length change causes re-compile
        self.assertEqual(cnt.frame_count, 2)

        tq3 = _empty_tensor_queue()
        tq3.push(torch.randn(2, 3, requires_grad=False))
        torch.compile(mod, backend=cnt)(tq3, x)
        # Tensor in queue changes shape causes re-compile
        self.assertEqual(cnt.frame_count, 3)

        tq4 = _empty_tensor_queue()
        tq4.push(torch.randn(2, 3, requires_grad=False))
        torch.compile(mod, backend=cnt)(tq4, x)
        # No recompile
        self.assertEqual(cnt.frame_count, 3)

        tq5 = _empty_tensor_queue()
        tq5.push(torch.randn(2, 3, requires_grad=True))
        torch.compile(mod, backend=cnt)(tq5, x)
        # Tensor in queue changes dispatch key causes re-compile
        self.assertEqual(cnt.frame_count, 4)

        tq6 = _empty_tensor_queue()
        tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64))
        torch.compile(mod, backend=cnt)(tq6, x)
        # Tensor in queue changes dtype causes re-compile
        self.assertEqual(cnt.frame_count, 5)
1184
    def test_compile_script_object_input_automatic_dynamic_shape(self):
        """Automatic dynamic shapes apply to tensors inside a script object:
        after one recompile for a changed dim, a further change of that dim
        does not recompile again."""

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.check_tq_is_fake = True

            def forward(self, tq, x):
                tq.push(x.cos())
                tq.push(x.sin())
                x_sin = tq.pop() - tq.size()
                return x_sin, tq

        mod = Model()
        cnt = torch._dynamo.testing.CompileCounter()
        x = torch.randn(2, 3)

        tq1 = _empty_tensor_queue()
        tq1.push(torch.randn(2, 3, requires_grad=False))
        torch.compile(mod, backend=cnt)(tq1, x)
        self.assertEqual(cnt.frame_count, 1)

        tq2 = _empty_tensor_queue()
        # make the first tensor's second dim dynamic
        tq2.push(torch.randn(2, 4, requires_grad=False))
        torch.compile(mod, backend=cnt)(tq2, x)
        self.assertEqual(cnt.frame_count, 2)

        tq3 = _empty_tensor_queue()
        tq3.push(torch.randn(2, 5, requires_grad=False))
        # should have no recompilation: dim 1 is dynamic now
        torch.compile(mod, backend=cnt)(tq3, x)
        self.assertEqual(cnt.frame_count, 2)
1217
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_error_on_input_aliasing_contents(self, backend):
        """Compiling with a script object whose contents alias another graph
        input must raise a RuntimeError."""
        if backend == "eager":
            backend = EagerAndRecordGraphs()

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.check_tq_is_fake = True

            def forward(self, tq, x):
                return x.sin(), tq.pop().cos()

        x = torch.randn(2, 3)
        mod = Model()

        tq1 = _empty_tensor_queue()
        # tq1's only element *is* the input tensor x -> input aliasing.
        tq1.push(x)
        # NOTE(review): "is alising" is presumably the verbatim (misspelled)
        # error text raised at runtime -- confirm against the source of the
        # message before "correcting" the spelling here.
        with self.assertRaisesRegex(RuntimeError, "is alising"):
            torch.compile(mod, backend=backend)(tq1, x)
1238
1239    @parametrize("backend", ["eager", "aot_eager"])
1240    def test_compile_error_on_script_obj_setattr(self, backend):
1241        if backend == "eager":
1242            backend = EagerAndRecordGraphs()
1243
1244        def setattr_f(tq):
1245            tq.a = 1
1246            return tq
1247
1248        with self.assertRaisesRegex(
1249            RuntimeError, "call method __setattr__ on script object is not safe"
1250        ):
1251            torch.compile(setattr_f, backend=backend)(_empty_tensor_queue())
1252
1253    @parametrize("backend", ["eager", "aot_eager"])
1254    def test_compile_error_on_script_obj_missing_attr(self, backend):
1255        if backend == "eager":
1256            backend = EagerAndRecordGraphs()
1257
1258        def setattr_f(tq):
1259            return tq._not_defined_attr
1260
1261        with self.assertRaisesRegex(
1262            RuntimeError, "doesn't define method _not_defined_attr"
1263        ):
1264            torch.compile(setattr_f, backend=backend)(_empty_tensor_queue())
1265
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_body_aliasing_contents(self, backend):
        """Views of the same tensor created *inside* the compiled region may
        be pushed into the queue; eager and compiled outputs must agree."""
        if backend == "eager":
            backend = EagerAndRecordGraphs()

        def f(tq, x):
            x1 = x.view(-1)
            x2 = x.permute(1, 0)
            tq.push(x1)
            tq.push(x2)
            return x1 - tq.size(), x2 + tq.size(), tq

        x = torch.randn(2, 3)
        _assertEqualScriptObject(
            self,
            f(_empty_tensor_queue(), x),
            torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
        )
        # NOTE(review): `backend` was rebound above to an EagerAndRecordGraphs
        # instance, so `backend == "eager"` is always False and this graph
        # check never runs; `isinstance(backend, EagerAndRecordGraphs)` would
        # enable it -- confirm the expected graph before changing.
        if not torch._dynamo.is_compiling() and backend == "eager":
            self.assertExpectedInline(
                backend.graphs[0].code.strip(),
                """\
def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject):
    l_x_ = L_x_
    l_tq_ = L_tq_
    x1 = l_x_.view(-1)
    x2 = l_x_.permute(1, 0);  l_x_ = None
    call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1)
    call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2)
    call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size')
    sub = x1 - 2;  x1 = None
    call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size');  l_tq_ = None
    add = x2 + 2;  x2 = None
    return (sub, add)""",
            )
1301
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_tensor_op_in_tensor_flatten(self, backend):
        """A script object whose flatten path involves a tensor op
        (_FlattenWithTensorOp) still compiles and exports; the exported
        graph wraps the `get` call in a token-threaded with_effects node."""
        test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp(
            torch.randn(3, 2)
        )

        class TestMod(torch.nn.Module):
            def forward(self, obj, x):
                return obj.get() + x

        mod = TestMod()

        torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1))
        ep = torch.export.export(mod, (test_obj, torch.randn(3, 1)), strict=False)
        self.assertExpectedInline(
            ep.graph_module.code.strip(),
            """\
def forward(self, token, obj, x):
    with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get');  token = obj = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1];  with_effects = None
    add_3 = torch.ops.aten.add.Tensor(getitem_1, x);  getitem_1 = x = None
    return (getitem, add_3)""",  # noqa: B950
        )
1326
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_error_on_non_fakified_method(self, backend):
        """Calling a method that exists on the real script object but is not
        implemented on its fake class raises at compile time."""
        if backend == "eager":
            backend = EagerAndRecordGraphs()

        def f(tq, x):
            x1 = x.view(-1)
            x2 = x.permute(1, 0)
            tq.push(x1)
            tq.push(x2)
            # though real tensor queue implemented a method clone_queue,
            # The fakified version doesn't.
            flat_obj = tq.clone_queue()
            return flat_obj

        x = torch.randn(2, 3)
        with self.assertRaisesRegex(
            RuntimeError, "FakeScriptObject doesn't define method"
        ):
            torch.compile(f, backend=backend)(_empty_tensor_queue(), x)
1347
1348    @parametrize("backend", ["eager", "aot_eager"])
1349    def test_compile_obj_as_hop_input(self, backend):
1350        def f(tq, x):
1351            def fn(tq, x):
1352                tq.push(x)
1353                return x.sin()
1354
1355            return wrap(fn, tq, x)
1356
1357        x = torch.randn(2, 3)
1358        _assertEqualScriptObject(
1359            self,
1360            f(_empty_tensor_queue(), x),
1361            torch.compile(f, backend=backend)(_empty_tensor_queue(), x),
1362        )
1363
1364    @parametrize("backend", ["eager", "aot_eager"])
1365    def test_compile_obj_closure(self, backend):
1366        def f(x):
1367            def inner_f(x):
1368                tq.push(x.sin())
1369
1370            inner_f(x)
1371            return tq.pop(), tq
1372
1373        opt_f = torch.compile(f, backend="eager")
1374
1375        tq = _empty_tensor_queue()
1376        x = torch.randn(3, 2)
1377        _assertEqualScriptObject(self, f(x), opt_f(x))
1378
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_global_obj(self, backend):
        """A script object referenced through a module-level global works
        under torch.compile; compiled output matches eager."""
        global _TENSOR_QUEUE_GLOBAL_TEST
        _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue()

        def f(x):
            _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin())
            return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST

        opt_f = torch.compile(f, backend=backend)
        x = torch.randn(3, 2)
        # Eager runs first; each call pushes then pops, leaving the shared
        # global queue in the same state for the compiled call.
        eager_ret = f(x)
        opt_ret = opt_f(x)
        _assertEqualScriptObject(self, eager_ret, opt_ret)
1393
    def test_compile_obj_graph_breaks(self):
        """Queue method calls interleaved with explicit graph breaks still
        match eager; the three graph_break() calls split tracing into 4
        compiled frames."""
        cnt = torch._dynamo.testing.CompileCounter()

        def f(tq, x):
            tq.push(x.sin())
            tq.push(x.sin())
            torch._dynamo.graph_break()
            tq.pop()
            torch._dynamo.graph_break()
            tq.push(x.cos() + tq.size())
            torch._dynamo.graph_break()
            tq.push(x.cos() - tq.size())
            return x, tq.pop(), tq

        opt_f = torch.compile(f, backend=cnt)
        x = torch.randn(3, 2)
        _assertEqualScriptObject(
            self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
        )
        self.assertEqual(cnt.frame_count, 4)
1414
    @parametrize("backend", ["eager", "aot_eager"])
    def test_compile_obj_attributes(self, backend):
        """A script object stored as a module attribute is lifted as a graph
        input by dynamo; compiled output matches eager."""
        if backend == "eager":
            backend = EagerAndRecordGraphs()

        class Model(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.tq = _empty_tensor_queue()

            def forward(self, x):
                self.tq.push(x)
                return self.tq.pop()

        x = torch.randn(2, 3)
        opt_f = torch.compile(Model(), backend=backend)
        _assertEqualScriptObject(self, Model()(x), opt_f(x))
        # NOTE(review): `backend` was rebound above to an EagerAndRecordGraphs
        # instance, so this string comparison is always False and the graph
        # check below never runs; `isinstance(backend, EagerAndRecordGraphs)`
        # would enable it -- confirm the expected graph before changing.
        if backend == "eager":
            self.assertEqual(len(backend.graphs), 1)
            # lifted as input. In the future, we would want to consolidate this
            # with non-strict behavior, where they're set as attributes.
            self.assertExpectedInline(
                backend.graphs[0].code.strip(),
                """\
    def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor):
        l_self_tq = L_self_tq
        l_x_ = L_x_
        call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_);  l_x_ = None
        call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop');  l_self_tq = None
        return (call_torchbind_1,)""",
            )
1446
1447    @parametrize("backend", ["eager", "aot_eager"])
1448    def test_compile_obj_torchbind_op(self, backend):
1449        def f(tq, x):
1450            torch.ops._TorchScriptTesting.queue_push(tq, x.cos())
1451            torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1)
1452            torch.ops._TorchScriptTesting.queue_pop(tq)
1453            torch.ops._TorchScriptTesting.queue_push(tq, x.sin())
1454            return tq.pop(), tq.pop() + tq.size(), tq
1455
1456        opt_f = torch.compile(f, backend=backend)
1457        x = torch.randn(2)
1458        _assertEqualScriptObject(
1459            self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x)
1460        )
1461
1462
@skipIfTorchDynamo("torchbind not supported with dynamo yet")
class TestRegisterFakeClass(TestCase):
    """Validation and success paths of ``torch._library.register_fake_class``."""

    def setUp(self):
        # Load the _TorchScriptTesting torchbind class implementations.
        init_torchbind_implementations()

    def tearDown(self):
        # Drop every fake-class registration so tests stay independent.
        torch._library.fake_class_registry.global_fake_class_registry.clear()

    def test_register_fake_class_no_torch_bind_class(self):
        # Registering against a torchbind class name that does not exist
        # must raise at decoration time.
        with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"):

            @torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME")
            class BogusFake:
                pass

    def test_register_fake_class_no_from_real(self):
        # A fake class that never defines __obj_unflatten__ is rejected.
        with self.assertRaisesRegex(
            RuntimeError, "define a classmethod __obj_unflatten__"
        ):

            @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
            class MissingUnflattenFoo:
                def __init__(self) -> None:
                    pass

    def test_register_fake_class_from_real_not_classmethod(self):
        # __obj_unflatten__ exists here but is a plain method, which the
        # registry must reject.
        with self.assertRaisesRegex(RuntimeError, "is not a classmethod"):

            @torch._library.register_fake_class("_TorchScriptTesting::_Foo")
            class PlainMethodFakeFoo:
                def __init__(self, x, y):
                    self.x = x
                    self.y = y

                def __obj_unflatten__(cls, flattened):  # noqa: B902
                    return cls(**dict(flattened))

    def test_register_fake_class_valid(self):
        # A well-formed fake class registers without raising.
        class GoodFakeFoo:
            def __init__(self, x, y):
                self.x = x
                self.y = y

            @classmethod
            def __obj_unflatten__(cls, flattened):
                return cls(**dict(flattened))

        torch._library.register_fake_class("_TorchScriptTesting::_Foo", GoodFakeFoo)
1511
1512
# Expand @parametrize-decorated tests in these classes into concrete test
# methods (one per parameter value) so the unittest runner can discover them.
instantiate_parametrized_tests(TestExportTorchbind)
instantiate_parametrized_tests(TestCompileTorchbind)

if __name__ == "__main__":
    run_tests()
1518