1# Owner(s): ["oncall: export"] 2 3 4import unittest 5 6import torch 7import torch.utils._pytree as pytree 8from torch._dynamo.testing import EagerAndRecordGraphs 9from torch._functorch.aot_autograd import aot_export_module 10from torch._higher_order_ops.torchbind import enable_torchbind_tracing 11from torch._higher_order_ops.wrap import wrap 12from torch._library.fake_class_registry import FakeScriptObject 13from torch.export import export 14from torch.export._trace import _export 15from torch.fx.experimental.proxy_tensor import make_fx 16from torch.testing._internal.common_utils import ( 17 instantiate_parametrized_tests, 18 parametrize, 19 run_tests, 20 skipIfTorchDynamo, 21 TestCase, 22) 23from torch.testing._internal.torchbind_impls import ( 24 _empty_tensor_queue, 25 init_torchbind_implementations, 26) 27 28 29def _assertEqualSkipScriptObject(test_case, exp, actual): 30 flat_exp = pytree.tree_leaves(exp) 31 flat_actual = pytree.tree_leaves(actual) 32 test_case.assertEqual(len(flat_exp), len(flat_actual)) 33 for a, b in zip(flat_exp, flat_actual): 34 if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject): 35 continue 36 test_case.assertEqual(a, b) 37 38 39def _check_script_obj_equal(test_case, a: torch.ScriptObject, b: torch.ScriptObject): 40 return test_case.assertEqual( 41 a._type().qualified_name(), b._type().qualified_name() 42 ) and test_case.assertEqual(a.__obj_flatten__(), b.__obj_flatten__()) 43 44 45def _assertEqualScriptObject( 46 test_case, exp, actual, check_obj_eq=_check_script_obj_equal 47): 48 flat_exp = pytree.tree_leaves(exp) 49 flat_actual = pytree.tree_leaves(actual) 50 test_case.assertEqual(len(flat_exp), len(flat_actual)) 51 for a, b in zip(flat_exp, flat_actual): 52 if isinstance(a, torch.ScriptObject) and isinstance(b, torch.ScriptObject): 53 check_obj_eq(test_case, a, b) 54 else: 55 test_case.assertEqual(a, b) 56 57 58@skipIfTorchDynamo("torchbind not supported with dynamo yet") 59class TestExportTorchbind(TestCase): 60 def setUp(self): 61 init_torchbind_implementations() 62 63 test = self 64 test.tq_push_counter = 0 65 test.tq_pop_counter = 0 66 test.tq_size_counter = 0 67 test.foo_add_tensor_counter = 0 68 69 @torch._library.register_fake_class("_TorchScriptTesting::_Foo") 70 class FakeFoo: 71 def __init__(self, x: int, y: int): 72 self.x = x 73 self.y = y 74 75 @classmethod 76 def __obj_unflatten__(cls, flattend_foo): 77 return cls(**dict(flattend_foo)) 78 79 def add_tensor(self, z): 80 test.foo_add_tensor_counter += 1 81 return (self.x + self.y) * z 82 83 @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") 84 class FakeTensorQueue: 85 def __init__(self, queue): 86 self.queue = queue 87 88 @classmethod 89 def __obj_unflatten__(cls, flattened_ctx): 90 return cls(**dict(flattened_ctx)) 91 92 def push(self, x): 93 test.tq_push_counter += 1 94 self.queue.append(x) 95 96 def pop(self): 97 test.tq_pop_counter += 1 98 return self.queue.pop(0) 99 100 def size(self): 101 test.tq_size_counter += 1 102 return len(self.queue) 103 104 def is_empty(self): 105 return len(self.queue) == 0 106 107 def float_size(self): 108 return float(len(self.queue)) 109 110 self.torch_bind_ops = [ 111 torch.ops._TorchScriptTesting.takes_foo, 112 torch.ops._TorchScriptTesting.takes_foo_python_meta, 113 torch.ops._TorchScriptTesting.takes_foo_list_return, 114 torch.ops._TorchScriptTesting.takes_foo_tuple_return, 115 torch.ops._TorchScriptTesting.take_an_instance, 116 torch.ops._TorchScriptTesting.take_an_instance_inferred, 117 torch.ops._TorchScriptTesting.takes_foo_cia, 118 torch.ops._TorchScriptTesting.queue_pop, 119 torch.ops._TorchScriptTesting.queue_push, 120 torch.ops._TorchScriptTesting.queue_size, 121 ] 122 123 def tearDown(self): 124 torch._library.fake_class_registry.deregister_fake_class( 125 "_TorchScriptTesting::_Foo" 126 ) 127 torch._library.fake_class_registry.deregister_fake_class( 128 "_TorchScriptTesting::_TensorQueue" 129 ) 130 131 def _test_export_same_as_eager( 132 self, f, args, kwargs=None, strict=True, pre_dispatch=False 133 ): 134 kwargs = kwargs or {} 135 136 def export_wrapper(f, args, kwargs, strcit, pre_dispatch): 137 with enable_torchbind_tracing(): 138 if pre_dispatch: 139 exported_program = _export( 140 f, args, kwargs, strict=strict, pre_dispatch=True 141 ) 142 else: 143 exported_program = export(f, args, kwargs, strict=strict) 144 return exported_program 145 146 exported_program = export_wrapper(f, args, kwargs, strict, pre_dispatch) 147 reversed_kwargs = {key: kwargs[key] for key in reversed(kwargs)} 148 unlifted = exported_program.module() 149 exp = f(*args, **kwargs) 150 _assertEqualScriptObject(self, unlifted(*args, **kwargs), exp) 151 _assertEqualScriptObject( 152 self, 153 unlifted(*args, **reversed_kwargs), 154 exp, 155 ) 156 157 # check re-tracing 158 retraced_ep = export_wrapper(unlifted, args, kwargs, strict, pre_dispatch) 159 _assertEqualScriptObject(self, retraced_ep.module()(*args, **kwargs), exp) 160 return exported_program 161 162 @parametrize("pre_dispatch", [True, False]) 163 def test_none(self, pre_dispatch): 164 class MyModule(torch.nn.Module): 165 def __init__(self) -> None: 166 super().__init__() 167 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 168 169 def forward(self, x, n): 170 return x + self.attr.add_tensor(x) 171 172 ep = self._test_export_same_as_eager( 173 MyModule(), 174 (torch.ones(2, 3), None), 175 strict=False, 176 pre_dispatch=pre_dispatch, 177 ) 178 self.assertExpectedInline( 179 ep.module().code.strip(), 180 """\ 181def forward(self, x, n): 182 x, n, = fx_pytree.tree_flatten_spec(([x, n], {}), self._in_spec) 183 attr = self.attr 184 call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None 185 add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None 186 return pytree.tree_unflatten((add,), self._out_spec)""", 187 ) 188 self.assertExpectedInline( 189 ep.graph_module.code.strip(), 190 """\ 191def forward(self, token, obj_attr, x, n): 192 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x); token = obj_attr = None 193 getitem = with_effects[0] 194 getitem_1 = with_effects[1]; with_effects = None 195 add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None 196 return (getitem, add)""", # noqa: B950 197 ) 198 199 def test_method_schema(self): 200 tq = _empty_tensor_queue() 201 fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() 202 fake_obj = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, tq) 203 self.assertExpectedInline( 204 str(fake_obj.push.schema), 205 """push(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0, Tensor _1) -> NoneType _0""", 206 ) 207 self.assertExpectedInline( 208 str(fake_obj.pop.schema), 209 """pop(__torch__.torch.classes._TorchScriptTesting._TensorQueue _0) -> Tensor _0""", 210 ) 211 212 @parametrize("pre_dispatch", [True, False]) 213 def test_attribute(self, pre_dispatch): 214 class MyModule(torch.nn.Module): 215 def __init__(self) -> None: 216 super().__init__() 217 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 218 219 def forward(self, x): 220 return x + self.attr.add_tensor(x) 221 222 ep = self._test_export_same_as_eager( 223 MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch 224 ) 225 self.assertExpectedInline( 226 ep.module().code.strip(), 227 """\ 228def forward(self, x): 229 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 230 attr = self.attr 231 call_torchbind = torch.ops.higher_order.call_torchbind(attr, 'add_tensor', x); attr = None 232 add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None 233 return pytree.tree_unflatten((add,), self._out_spec)""", 234 ) 235 self.assertExpectedInline( 236 ep.graph_module.code.strip(), 237 """\ 238def forward(self, token, obj_attr, x): 239 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj_attr, 'add_tensor', x); token = obj_attr = None 240 getitem = with_effects[0] 241 getitem_1 = with_effects[1]; with_effects = None 242 add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None 243 return (getitem, add)""", # noqa: B950 244 ) 245 246 @parametrize("pre_dispatch", [True, False]) 247 def test_attribute_as_custom_op_argument(self, pre_dispatch): 248 class MyModule(torch.nn.Module): 249 def __init__(self) -> None: 250 super().__init__() 251 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 252 253 def forward(self, x): 254 return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x) 255 256 ep = self._test_export_same_as_eager( 257 MyModule(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch 258 ) 259 self.assertExpectedInline( 260 ep.module().code.strip(), 261 """\ 262def forward(self, x): 263 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 264 attr = self.attr 265 takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None 266 add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None 267 return pytree.tree_unflatten((add,), self._out_spec)""", 268 ) 269 self.assertExpectedInline( 270 ep.graph_module.code.strip(), 271 """\ 272def forward(self, token, obj_attr, x): 273 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = obj_attr = None 274 getitem = with_effects[0] 275 getitem_1 = with_effects[1]; with_effects = None 276 add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None 277 return (getitem, add)""", # noqa: B950 278 ) 279 280 @parametrize("pre_dispatch", [True, False]) 281 def test_input(self, pre_dispatch): 282 cc = torch.classes._TorchScriptTesting._Foo(10, 20) 283 284 class MyModule(torch.nn.Module): 285 def __init__(self) -> None: 286 super().__init__() 287 288 def forward(self, x, cc): 289 return x + cc.add_tensor(x) 290 291 ep = self._test_export_same_as_eager( 292 MyModule(), (torch.ones(2, 3), cc), strict=False, pre_dispatch=pre_dispatch 293 ) 294 self.assertExpectedInline( 295 ep.module().code.strip(), 296 """\ 297def forward(self, x, cc): 298 x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec) 299 call_torchbind = torch.ops.higher_order.call_torchbind(cc, 'add_tensor', x); cc = None 300 add = torch.ops.aten.add.Tensor(x, call_torchbind); x = call_torchbind = None 301 return pytree.tree_unflatten((add,), self._out_spec)""", 302 ) 303 self.assertExpectedInline( 304 ep.graph_module.code.strip(), 305 """\ 306def forward(self, token, x, cc): 307 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, cc, 'add_tensor', x); token = cc = None 308 getitem = with_effects[0] 309 getitem_1 = with_effects[1]; with_effects = None 310 add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None 311 return (getitem, add)""", # noqa: B950 312 ) 313 # aot_export_function runs the program twice 314 # in run_functionalized_fw_and_collect_metadata and create_aot_dispatcher_function 315 # We also have a re-tracing test, which doubles the count. 316 self.assertEqual(self.foo_add_tensor_counter, 4) 317 318 @parametrize("pre_dispatch", [True, False]) 319 def test_input_as_custom_op_argument(self, pre_dispatch): 320 cc = torch.classes._TorchScriptTesting._Foo(10, 20) 321 322 class MyModule(torch.nn.Module): 323 def __init__(self) -> None: 324 super().__init__() 325 326 def forward(self, x, cc): 327 return x + torch.ops._TorchScriptTesting.takes_foo(cc, x) 328 329 del torch.ops._TorchScriptTesting.takes_foo.default.py_kernels[ 330 torch._C.DispatchKey.Meta 331 ] 332 torch.ops._TorchScriptTesting.takes_foo.default._dispatch_cache.clear() 333 # Even though a C++ implementation for takes_foo.default is registered, 334 # we still need the python implementation for takes_foo.default to trace with FakeFoo. 335 with self.assertRaisesRegex(RuntimeError, "no python implementation is found"): 336 self._test_export_same_as_eager( 337 MyModule(), 338 (torch.ones(2, 3), cc), 339 strict=False, 340 pre_dispatch=pre_dispatch, 341 ) 342 343 torch.ops._TorchScriptTesting.takes_foo.default.py_impl( 344 torch._C.DispatchKey.Meta 345 )(lambda cc, x: cc.add_tensor(x)) 346 ep = self._test_export_same_as_eager( 347 MyModule(), 348 (torch.ones(2, 3), cc), 349 strict=False, 350 pre_dispatch=pre_dispatch, 351 ) 352 353 self.assertExpectedInline( 354 ep.module().code.strip(), 355 """\ 356def forward(self, x, cc): 357 x, cc, = fx_pytree.tree_flatten_spec(([x, cc], {}), self._in_spec) 358 takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(cc, x); cc = None 359 add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None 360 return pytree.tree_unflatten((add,), self._out_spec)""", 361 ) 362 self.assertExpectedInline( 363 ep.graph_module.code.strip(), 364 """\ 365def forward(self, token, x, cc): 366 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, cc, x); token = cc = None 367 getitem = with_effects[0] 368 getitem_1 = with_effects[1]; with_effects = None 369 add = torch.ops.aten.add.Tensor(x, getitem_1); x = getitem_1 = None 370 return (getitem, add)""", # noqa: B950 371 ) 372 373 @parametrize("pre_dispatch", [True, False]) 374 def test_torchbind_alias(self, pre_dispatch): 375 class F2(torch.nn.Module): 376 def __init__(self, foo): 377 super().__init__() 378 self.foo = foo 379 380 def forward(self, x): 381 return x + torch.ops._TorchScriptTesting.takes_foo(self.foo, x) 382 383 class F1(torch.nn.Module): 384 def __init__(self) -> None: 385 super().__init__() 386 self.alpha = torch.classes._TorchScriptTesting._Foo(10, 20) 387 self.beta = self.alpha 388 self.gamma = self.alpha 389 self.foo = F2(self.gamma) 390 391 def forward(self, x): 392 return ( 393 x 394 + torch.ops._TorchScriptTesting.takes_foo(self.gamma, x) 395 + self.foo(x) 396 ) 397 398 self._test_export_same_as_eager( 399 F1(), (torch.ones(2, 3),), strict=False, pre_dispatch=pre_dispatch 400 ) 401 402 # TODO(pianpwk): look into this 403 @unittest.expectedFailure 404 @parametrize("pre_dispatch", [True, False]) 405 def test_torchbind_input_and_alias(self, pre_dispatch): 406 # alias as model attribute 407 class F3(torch.nn.Module): 408 def forward(self, x, foo): 409 self.foo = foo 410 return x + self.foo.add_tensor(x) 411 412 foo = torch.classes._TorchScriptTesting._Foo(10, 20) 413 self._test_export_same_as_eager( 414 F3(), (torch.ones(2, 3), foo), strict=False, pre_dispatch=pre_dispatch 415 ) 416 417 @parametrize("pre_dispatch", [True, False]) 418 def test_unlift_custom_obj(self, pre_dispatch): 419 class MyModule(torch.nn.Module): 420 def __init__(self) -> None: 421 super().__init__() 422 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 423 424 def forward(self, x): 425 a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x) 426 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a) 427 return x + b 428 429 input = torch.ones(2, 3) 430 ep = self._test_export_same_as_eager( 431 MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch 432 ) 433 self.assertExpectedInline( 434 ep.module().code.strip(), 435 """\ 436def forward(self, x): 437 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 438 attr = self.attr 439 takes_foo_default_1 = torch.ops._TorchScriptTesting.takes_foo.default(attr, x) 440 takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, takes_foo_default_1); attr = takes_foo_default_1 = None 441 add = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None 442 return pytree.tree_unflatten((add,), self._out_spec)""", # noqa: B950 443 ) 444 self.assertExpectedInline( 445 ep.graph_module.code.strip(), 446 """\ 447def forward(self, token, obj_attr, x): 448 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, x); token = None 449 getitem = with_effects[0] 450 getitem_1 = with_effects[1]; with_effects = None 451 with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, getitem_1); getitem = obj_attr = getitem_1 = None 452 getitem_2 = with_effects_1[0] 453 getitem_3 = with_effects_1[1]; with_effects_1 = None 454 add = torch.ops.aten.add.Tensor(x, getitem_3); x = getitem_3 = None 455 return (getitem_2, add)""", # noqa: B950 456 ) 457 458 @parametrize("pre_dispatch", [True, False]) 459 def test_custom_obj_list_out(self, pre_dispatch): 460 class MyModule(torch.nn.Module): 461 def __init__(self) -> None: 462 super().__init__() 463 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 464 465 def forward(self, x): 466 a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x) 467 y = a[0] + a[1] + a[2] 468 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y) 469 return x + b 470 471 input = torch.ones(2, 3) 472 ep = self._test_export_same_as_eager( 473 MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch 474 ) 475 self.assertExpectedInline( 476 ep.module().code.strip(), 477 """\ 478def forward(self, x): 479 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 480 attr = self.attr 481 takes_foo_list_return_default = torch.ops._TorchScriptTesting.takes_foo_list_return.default(attr, x) 482 getitem_2 = takes_foo_list_return_default[0] 483 getitem_3 = takes_foo_list_return_default[1] 484 getitem_4 = takes_foo_list_return_default[2]; takes_foo_list_return_default = None 485 add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None 486 add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None 487 takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add_1); attr = add_1 = None 488 add_2 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None 489 return pytree.tree_unflatten((add_2,), self._out_spec)""", 490 ) 491 self.assertExpectedInline( 492 ep.graph_module.code.strip(), 493 """\ 494def forward(self, token, obj_attr, x): 495 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_list_return.default, obj_attr, x); token = None 496 getitem = with_effects[0] 497 getitem_1 = with_effects[1]; with_effects = None 498 getitem_2 = getitem_1[0] 499 getitem_3 = getitem_1[1] 500 getitem_4 = getitem_1[2]; getitem_1 = None 501 add = torch.ops.aten.add.Tensor(getitem_2, getitem_3); getitem_2 = getitem_3 = None 502 add_1 = torch.ops.aten.add.Tensor(add, getitem_4); add = getitem_4 = None 503 with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add_1); getitem = obj_attr = add_1 = None 504 getitem_5 = with_effects_1[0] 505 getitem_6 = with_effects_1[1]; with_effects_1 = None 506 add_2 = torch.ops.aten.add.Tensor(x, getitem_6); x = getitem_6 = None 507 return (getitem_5, add_2)""", # noqa: B950 508 ) 509 510 @parametrize("pre_dispatch", [True, False]) 511 def test_custom_obj_tuple_out(self, pre_dispatch): 512 class MyModule(torch.nn.Module): 513 def __init__(self) -> None: 514 super().__init__() 515 self.attr = torch.classes._TorchScriptTesting._Foo(10, 20) 516 517 def forward(self, x): 518 a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x) 519 y = a[0] + a[1] 520 b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y) 521 return x + b 522 523 input = torch.ones(2, 3) 524 ep = self._test_export_same_as_eager( 525 MyModule(), (input,), strict=False, pre_dispatch=pre_dispatch 526 ) 527 self.assertExpectedInline( 528 ep.module().code.strip(), 529 """\ 530def forward(self, x): 531 x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec) 532 attr = self.attr 533 takes_foo_tuple_return_default = torch.ops._TorchScriptTesting.takes_foo_tuple_return.default(attr, x) 534 getitem_1 = takes_foo_tuple_return_default[0] 535 getitem_2 = takes_foo_tuple_return_default[1]; takes_foo_tuple_return_default = None 536 add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None 537 takes_foo_default = torch.ops._TorchScriptTesting.takes_foo.default(attr, add); attr = add = None 538 add_1 = torch.ops.aten.add.Tensor(x, takes_foo_default); x = takes_foo_default = None 539 return pytree.tree_unflatten((add_1,), self._out_spec)""", 540 ) 541 self.assertExpectedInline( 542 ep.graph_module.code.strip(), 543 """\ 544def forward(self, token, obj_attr, x): 545 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.takes_foo_tuple_return.default, obj_attr, x); token = None 546 getitem = with_effects[0] 547 getitem_1 = with_effects[1] 548 getitem_2 = with_effects[2]; with_effects = None 549 add = torch.ops.aten.add.Tensor(getitem_1, getitem_2); getitem_1 = getitem_2 = None 550 with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.takes_foo.default, obj_attr, add); getitem = obj_attr = add = None 551 getitem_3 = with_effects_1[0] 552 getitem_4 = with_effects_1[1]; with_effects_1 = None 553 add_1 = torch.ops.aten.add.Tensor(x, getitem_4); x = getitem_4 = None 554 return (getitem_3, add_1)""", # noqa: B950 555 ) 556 557 @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) 558 def test_make_fx_tensor_queue_methods(self, make_fx_tracing_mode): 559 test = self 560 561 class Model(torch.nn.Module): 562 def __init__(self) -> None: 563 super().__init__() 564 self.linear = torch.nn.Linear(3, 2) 565 self.check_tq_is_fake = True 566 567 def forward(self, tq, x): 568 if self.check_tq_is_fake: 569 test.assertTrue(isinstance(tq, FakeScriptObject)) 570 tq.push(x.cos()) 571 tq.push(x.sin()) 572 x_cos = tq.pop() + tq.size() 573 x_sin = tq.pop() - tq.size() 574 return x_sin, x_cos, tq 575 576 mod = Model() 577 tq = torch.classes._TorchScriptTesting._TensorQueue( 578 torch.empty( 579 0, 580 ).fill_(-1) 581 ) 582 tq1 = torch.classes._TorchScriptTesting._TensorQueue( 583 torch.empty( 584 0, 585 ).fill_(-1) 586 ) 587 x = torch.ones(2, 3) 588 gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x) 589 self.assertEqual(self.tq_push_counter, 2) 590 self.assertEqual(self.tq_pop_counter, 2) 591 self.assertEqual(self.tq_size_counter, 2) 592 self.assertEqual(tq.size(), 0) 593 self.assertExpectedInline( 594 gm.code.strip("\n"), 595 """\ 596def forward(self, arg0_1, arg1_1): 597 cos = torch.ops.aten.cos.default(arg1_1) 598 call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'push', cos); cos = call_torchbind = None 599 sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None 600 call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'push', sin); sin = call_torchbind_1 = None 601 call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') 602 call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_3 = None 603 add = torch.ops.aten.add.Tensor(call_torchbind_2, 1); call_torchbind_2 = None 604 call_torchbind_4 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') 605 call_torchbind_5 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_5 = None 606 sub = torch.ops.aten.sub.Tensor(call_torchbind_4, 0); call_torchbind_4 = None 607 return (sub, add, arg0_1) 608 """, 609 ) 610 mod.check_tq_is_fake = False 611 _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x)) 612 613 @parametrize("make_fx_tracing_mode", ["fake", "symbolic"]) 614 def test_make_fx_tensor_queue_methods_fakify_internal_states( 615 self, make_fx_tracing_mode 616 ): 617 test = self 618 619 class Model(torch.nn.Module): 620 def __init__(self) -> None: 621 super().__init__() 622 self.linear = torch.nn.Linear(3, 2) 623 self.check_tq_is_fake = True 624 self.current_test = test 625 626 def forward(self, tq, x): 627 if self.check_tq_is_fake: 628 self.current_test.assertTrue(isinstance(tq, FakeScriptObject)) 629 x_cos = tq.pop() + tq.size() + x 630 x_sin = tq.pop() - tq.size() + x 631 return x_sin, x_cos, tq 632 633 mod = Model() 634 tq = torch.classes._TorchScriptTesting._TensorQueue( 635 torch.empty( 636 0, 637 ).fill_(-1) 638 ) 639 tq1 = torch.classes._TorchScriptTesting._TensorQueue( 640 torch.empty( 641 0, 642 ).fill_(-1) 643 ) 644 for _ in range(2): 645 tq.push(torch.ones(2, 3)) 646 tq1.push(torch.ones(2, 3)) 647 x = torch.ones(2, 3) 648 prev_size = tq.size() 649 gm = make_fx(mod, tracing_mode=make_fx_tracing_mode)(tq, x) 650 self.assertEqual(self.tq_push_counter, 0) 651 self.assertEqual(self.tq_pop_counter, 2) 652 self.assertEqual(self.tq_size_counter, 2) 653 self.assertEqual(tq.size(), prev_size) 654 self.assertExpectedInline( 655 gm.code.strip("\n"), 656 """\ 657def forward(self, arg0_1, arg1_1): 658 call_torchbind = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') 659 call_torchbind_1 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_1 = None 660 add = torch.ops.aten.add.Tensor(call_torchbind, 1); call_torchbind = None 661 add_1 = torch.ops.aten.add.Tensor(add, arg1_1); add = None 662 call_torchbind_2 = torch.ops.higher_order.call_torchbind(arg0_1, 'pop') 663 call_torchbind_3 = torch.ops.higher_order.call_torchbind(arg0_1, 'size'); call_torchbind_3 = None 664 sub = torch.ops.aten.sub.Tensor(call_torchbind_2, 0); call_torchbind_2 = None 665 add_2 = torch.ops.aten.add.Tensor(sub, arg1_1); sub = arg1_1 = None 666 return (add_2, add_1, arg0_1) 667 """, 668 ) 669 # turn off tq type checking in eager execution 670 mod.check_tq_is_fake = False 671 _assertEqualSkipScriptObject(self, gm(tq, x), mod(tq1, x)) 672 self.assertEqual(tq.size(), 0) 673 self.assertEqual(tq1.size(), 0) 674 675 def test_non_strict_export_methods(self): 676 class Model(torch.nn.Module): 677 def __init__(self) -> None: 678 super().__init__() 679 self.linear = torch.nn.Linear(2, 2) 680 681 def forward(self, tq, x): 682 x_cos = tq.pop() + tq.float_size() + self.linear(x) 683 if tq.is_empty(): 684 x_sin = self.linear(tq.pop()) - tq.size() + x 685 else: 686 x_sin = tq.pop() + tq.size() + x 687 return x_sin, x_cos, tq 688 689 mod = Model() 690 tq = _empty_tensor_queue() 691 a = torch.randn(2, 2) 692 b = torch.randn(2, 2) 693 tq.push(a) 694 tq.push(b) 695 ep = torch.export.export(mod, (tq, torch.randn(2, 2)), strict=False) 696 self.assertExpectedInline( 697 ep.graph_module.code.strip(), 698 """\ 699def forward(self, token, p_linear_weight, p_linear_bias, tq, x): 700 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, tq, 'pop'); token = None 701 getitem = with_effects[0] 702 getitem_1 = with_effects[1]; with_effects = None 703 with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.higher_order.call_torchbind, tq, 'float_size'); getitem = None 704 getitem_2 = with_effects_1[0]; with_effects_1 = None 705 add = torch.ops.aten.add.Tensor(getitem_1, 1.0); getitem_1 = None 706 linear = torch.ops.aten.linear.default(x, p_linear_weight, p_linear_bias); p_linear_weight = p_linear_bias = None 707 add_1 = torch.ops.aten.add.Tensor(add, linear); add = linear = None 708 with_effects_2 = torch.ops.higher_order.with_effects(getitem_2, torch.ops.higher_order.call_torchbind, tq, 'is_empty'); getitem_2 = None 709 getitem_4 = with_effects_2[0]; with_effects_2 = None 710 with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops.higher_order.call_torchbind, tq, 'pop'); getitem_4 = None 711 getitem_6 = with_effects_3[0] 712 getitem_7 = with_effects_3[1]; with_effects_3 = None 713 with_effects_4 = torch.ops.higher_order.with_effects(getitem_6, torch.ops.higher_order.call_torchbind, tq, 'size'); getitem_6 = None 714 getitem_8 = with_effects_4[0]; with_effects_4 = None 715 add_2 = torch.ops.aten.add.Tensor(getitem_7, 0); getitem_7 = None 716 add_3 = torch.ops.aten.add.Tensor(add_2, x); add_2 = x = None 717 return (getitem_8, add_3, add_1, tq)""", # noqa: B950 718 ) 719 self.assertEqual(tq.size(), 2) 720 self.assertTrue(tq.pop() is a) 721 self.assertTrue(tq.pop() is b) 722 723 def test_safe_to_trace_with_real(self): 724 x = torch.randn(3, 3) 725 safe_obj = torch.classes._TorchScriptTesting._ConstantTensorContainer(x) 726 727 class Mod(torch.nn.Module): 728 def forward(self, safe_obj: torch.ScriptObject) -> None: 729 return safe_obj.get().sin() 730 731 mod = Mod() 732 backend = EagerAndRecordGraphs() 733 out = torch.compile(mod, backend=backend, fullgraph=True)(safe_obj) 734 self.assertEqual(out, mod(safe_obj)) 735 self.assertExpectedInline( 736 backend.graphs[0].code.strip(), 737 """\ 738def forward(self, L_safe_obj_ : torch.ScriptObject): 739 l_safe_obj_ = L_safe_obj_ 740 call_torchbind = torch.ops.higher_order.call_torchbind(l_safe_obj_, 'get'); l_safe_obj_ = None 741 sin = call_torchbind.sin(); call_torchbind = None 742 return (sin,)""", 743 ) 744 745 with enable_torchbind_tracing(): 746 ep = torch.export.export(mod, (safe_obj,), strict=False) 747 self.assertExpectedInline( 748 ep.graph_module.code.strip(), 749 """\ 750def forward(self, token, safe_obj): 751 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, safe_obj, 'get'); token = safe_obj = None 752 getitem = with_effects[0] 753 getitem_1 = with_effects[1]; with_effects = None 754 sin = torch.ops.aten.sin.default(getitem_1); getitem_1 = None 755 return (getitem, sin)""", # noqa: B950 756 ) 757 758 def test_identifying_torchbind_ops(self): 759 for op in self.torch_bind_ops: 760 self.assertTrue(op._has_torchbind_op_overload) 761 762 for op in [ 763 torch.ops.aten.add, 764 torch.ops.aten.cos, 765 ]: 766 self.assertFalse(op._has_torchbind_op_overload) 767 768 def test_torchbind_op_register_fallthrough(self): 769 TEST_DISPATCH_KEY = torch._C.DispatchKey.AutocastCPU 770 TEST_DISPATCH_KEY_STR = "AutocastCPU" 771 772 for op_packet in self.torch_bind_ops: 773 op = op_packet.default 774 ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name) 775 with torch.library._scoped_library(ns, "FRAGMENT") as lib: 776 lib.impl( 777 op.name(), torch.library.fallthrough_kernel, TEST_DISPATCH_KEY_STR 778 ) 779 self.assertTrue( 780 torch._C._dispatch_kernel_for_dispatch_key_is_fallthrough( 781 op.name(), TEST_DISPATCH_KEY 782 ) 783 ) 784 785 def test_torchbind_op_fallthrough_keys_respects_lib_impl(self): 786 TEST_DISPATCH_KEY = torch._C.DispatchKey.AutogradCPU 787 TEST_DISPATCH_KEY_STR = "AutogradCPU" 788 789 tested = 0 790 for op_packet in self.torch_bind_ops: 791 op = op_packet.default 792 ns, _ = torch._library.utils.parse_namespace(op_packet._qualified_op_name) 793 if ( 794 not torch._C._dispatch_has_kernel_for_dispatch_key( 795 op.name(), TEST_DISPATCH_KEY 796 ) 797 and TEST_DISPATCH_KEY not in op.py_kernels 798 ): 799 tested += 1 800 with torch.library._scoped_library(ns, "FRAGMENT") as lib: 801 lib.impl( 802 op.name(), lambda *args, **kwargs: args, TEST_DISPATCH_KEY_STR 803 ) 804 self.assertTrue(TEST_DISPATCH_KEY not in op._fallthrough_keys()) 805 806 with torch.library._scoped_library(ns, "FRAGMENT") as lib: 807 lib.impl( 808 op.name(), 809 torch.library.fallthrough_kernel, 810 TEST_DISPATCH_KEY_STR, 811 ) 812 self.assertTrue(TEST_DISPATCH_KEY in op._fallthrough_keys()) 813 self.assertTrue(tested > 0) 814 815 def test_make_fx_schema_checking_script_object(self): 816 class Model(torch.nn.Module): 817 def forward(self, tq, x, foo): 818 torch.ops._TorchScriptTesting.queue_push(foo, x.cos()) 819 return tq 820 821 class ModelCallByKW(torch.nn.Module): 822 def forward(self, tq, x, foo): 823 torch.ops._TorchScriptTesting.queue_push(x=x.cos(), foo=foo) 824 return tq 825 826 mod = Model() 827 modkw = ModelCallByKW() 828 829 foo = torch.classes._TorchScriptTesting._Foo(10, 20) 830 x = torch.ones(3, 3) 831 tq = torch.classes._TorchScriptTesting._TensorQueue( 832 torch.empty( 833 0, 834 ).fill_(-1) 835 ) 836 ns = "_TorchScriptTesting" 837 with torch.library._scoped_library(ns, "FRAGMENT") as lib: 838 op = torch.ops._TorchScriptTesting.queue_push 839 lib.impl(op.__name__, torch.library.fallthrough_kernel, "AutogradCPU") 840 lib.impl(op.__name__, torch.library.fallthrough_kernel, "ADInplaceOrView") 841 lib.impl( 842 op.__name__, 843 torch.library.fallthrough_kernel, 844 "PythonTLSSnapshot", 845 ) 846 847 with self.assertRaisesRegex( 848 RuntimeError, "is expected to be a FakeScriptObject" 849 ): 850 _ = make_fx(mod, tracing_mode="fake")(tq, x, foo) 851 852 with self.assertRaisesRegex( 853 RuntimeError, "is expected to be a FakeScriptObject" 854 ): 855 _ = make_fx(modkw, tracing_mode="fake")(tq, x, foo) 856 857 @parametrize("fallthrough_via", ["lib_impl", "py_impl"]) 858 def test_make_fx_tensor_queue_operators(self, fallthrough_via): 859 class Model(torch.nn.Module): 860 def __init__(self) -> None: 861 super().__init__() 862 863 def forward(self, tq, x): 864 with torch.autocast("cuda", dtype=torch.bfloat16): 865 torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) 866 torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) 867 x_sin = torch.ops._TorchScriptTesting.queue_pop( 868 tq 869 ) - torch.ops._TorchScriptTesting.queue_size(tq) 870 x_cos = torch.ops._TorchScriptTesting.queue_pop( 871 tq 872 ) + torch.ops._TorchScriptTesting.queue_size(tq) 873 return x_sin, x_cos, tq 874 875 mod = Model() 876 877 tq1 = torch.classes._TorchScriptTesting._TensorQueue( 878 torch.empty( 879 0, 880 ).fill_(-1) 881 ) 882 tq2 = torch.classes._TorchScriptTesting._TensorQueue( 883 torch.empty( 884 0, 885 ).fill_(-1) 886 ) 887 x = torch.ones(2, 3) 888 889 mod(tq1, x) 890 891 ops = [ 892 torch.ops._TorchScriptTesting.queue_push, 893 torch.ops._TorchScriptTesting.queue_pop, 894 torch.ops._TorchScriptTesting.queue_size, 895 ] 896 if fallthrough_via == "lib_impl": 897 ns = "_TorchScriptTesting" 898 with torch.library._scoped_library(ns, "FRAGMENT") as lib: 899 for op in ops: 900 lib.impl( 901 op.__name__, torch.library.fallthrough_kernel, "AutocastCUDA" 902 ) 903 904 gm = make_fx(mod, tracing_mode="fake")(tq1, x) 905 else: 906 for op in ops: 907 op.default.py_impl(torch._C.DispatchKey.AutocastCUDA)( 908 torch.library.fallthrough_kernel 909 ) 910 gm = make_fx(mod, tracing_mode="fake")(tq1, x) 911 for op in ops: 912 op.default._dispatch_cache.clear() 913 del op.default.py_kernels[torch._C.DispatchKey.AutocastCUDA] 914 915 self.assertExpectedInline( 916 gm.code.strip(), 917 """\ 918def forward(self, arg0_1, arg1_1): 919 cos = torch.ops.aten.cos.default(arg1_1) 920 queue_push = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, cos); cos = queue_push = None 921 sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None 922 queue_push_1 = torch.ops._TorchScriptTesting.queue_push.default(arg0_1, sin); sin = queue_push_1 = None 923 queue_pop = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1) 924 queue_size = torch.ops._TorchScriptTesting.queue_size.default(arg0_1); queue_size = None 925 sub = torch.ops.aten.sub.Tensor(queue_pop, 1); queue_pop = None 926 queue_pop_1 = torch.ops._TorchScriptTesting.queue_pop.default(arg0_1) 927 queue_size_1 = torch.ops._TorchScriptTesting.queue_size.default(arg0_1); queue_size_1 = None 928 add = torch.ops.aten.add.Tensor(queue_pop_1, 0); queue_pop_1 = None 929 return (sub, add, arg0_1)""", 930 ) 931 _assertEqualSkipScriptObject(self, gm(tq1, x), mod(tq2, x)) 932 933 def test_aot_export_tensor_queue_operators(self): 934 class Model(torch.nn.Module): 935 def __init__(self) -> None: 936 super().__init__() 937 938 def forward(self, tq, x): 939 torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) 940 torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) 941 x_sin = torch.ops._TorchScriptTesting.queue_pop( 942 tq 943 ) - torch.ops._TorchScriptTesting.queue_size(tq) 944 x_cos = torch.ops._TorchScriptTesting.queue_pop( 945 tq 946 ) + torch.ops._TorchScriptTesting.queue_size(tq) 947 return x_sin, x_cos, tq 948 949 mod = Model() 950 951 tq1 = torch.classes._TorchScriptTesting._TensorQueue( 952 torch.empty( 953 0, 954 ).fill_(-1) 955 ) 956 x = torch.ones(2, 3) 957 958 fake_mode = torch._subclasses.fake_tensor.FakeTensorMode() 959 fake_tq1 = torch._library.fake_class_registry.maybe_to_fake_obj(fake_mode, tq1) 960 fake_x = fake_mode.from_tensor(x) 961 gm = aot_export_module(mod, (fake_tq1, fake_x), trace_joint=False)[0] 962 963 # inputs: token, tq, x 964 # return: token, x_sin, x_cos, tq 965 self.assertExpectedInline( 966 gm.code.strip(), 967 """\ 968def forward(self, arg0_1, arg1_1, arg2_1): 969 cos = torch.ops.aten.cos.default(arg2_1) 970 with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, cos); arg0_1 = cos = None 971 getitem = with_effects[0]; with_effects = None 972 sin = torch.ops.aten.sin.default(arg2_1); arg2_1 = None 973 with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._TorchScriptTesting.queue_push.default, arg1_1, sin); getitem = sin = None 974 getitem_2 = with_effects_1[0]; with_effects_1 = None 975 with_effects_2 = torch.ops.higher_order.with_effects(getitem_2, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_2 = None 976 getitem_4 = with_effects_2[0] 977 getitem_5 = with_effects_2[1]; with_effects_2 = None 978 with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_4 = None 979 getitem_6 = with_effects_3[0]; with_effects_3 = None 980 sub = torch.ops.aten.sub.Tensor(getitem_5, 1); getitem_5 = None 981 with_effects_4 = torch.ops.higher_order.with_effects(getitem_6, torch.ops._TorchScriptTesting.queue_pop.default, arg1_1); getitem_6 = None 982 getitem_8 = with_effects_4[0] 983 getitem_9 = with_effects_4[1]; with_effects_4 = None 984 with_effects_5 = torch.ops.higher_order.with_effects(getitem_8, torch.ops._TorchScriptTesting.queue_size.default, arg1_1); getitem_8 = None 985 getitem_10 = with_effects_5[0]; with_effects_5 = None 986 add = torch.ops.aten.add.Tensor(getitem_9, 0); getitem_9 = None 987 return (getitem_10, sub, add, arg1_1)""", # noqa: B950 988 ) 989 990 def test_export_inplace_custom_op(self): 991 class Model(torch.nn.Module): 992 def forward(self, tq: torch.ScriptObject, x: torch.Tensor) -> None: 993 torch.ops._TorchScriptTesting.queue_push(tq, x) 994 return tq 995 996 mod = Model() 997 ep = self._test_export_same_as_eager( 998 mod, 999 (_empty_tensor_queue(), torch.randn(3, 3)), 1000 strict=False, 1001 pre_dispatch=True, 1002 ) 1003 self.assertExpectedInline( 1004 ep.module().code.strip(), 1005 """\ 1006def forward(self, tq, x): 1007 tq, x, = fx_pytree.tree_flatten_spec(([tq, x], {}), self._in_spec) 1008 queue_push_default = torch.ops._TorchScriptTesting.queue_push.default(tq, x); x = queue_push_default = None 1009 return pytree.tree_unflatten((tq,), self._out_spec)""", 1010 ) 1011 self.assertExpectedInline( 1012 ep.graph_module.code.strip(), 1013 """\ 1014def forward(self, token, tq, x): 1015 with_effects = torch.ops.higher_order.with_effects(token, torch.ops._TorchScriptTesting.queue_push.default, tq, x); token = x = None 1016 getitem = with_effects[0]; with_effects = None 1017 return (getitem, tq)""", # noqa: B950 1018 ) 1019 self.assertExpectedInline( 1020 str(ep.graph_module.graph).strip(), 1021 """\ 1022graph(): 1023 %tq : [num_users=2] = placeholder[target=tq] 1024 %x : [num_users=1] = placeholder[target=x] 1025 %queue_push_default : [num_users=0] = call_function[target=torch.ops._TorchScriptTesting.queue_push.default](args = (%tq, %x), kwargs = {}) 1026 return (tq,)""", # noqa: B950 1027 ) 1028 1029 1030class TestCompileTorchbind(TestCase): 1031 def setUp(self): 1032 init_torchbind_implementations() 1033 1034 @torch._library.register_fake_class("_TorchScriptTesting::_TensorQueue") 1035 class FakeTensorQueue: 1036 def __init__(self, queue): 1037 self.queue = queue 1038 1039 @classmethod 1040 def __obj_unflatten__(cls, flattened_ctx): 1041 return cls(**dict(flattened_ctx)) 1042 1043 def push(self, x): 1044 self.queue.append(x) 1045 1046 def pop(self): 1047 return self.queue.pop(0) 1048 1049 def size(self): 1050 return len(self.queue) 1051 1052 @torch._library.register_fake_class("_TorchScriptTesting::_FlattenWithTensorOp") 1053 class FakeFlatten: 1054 def __init__(self, t): 1055 self.t = t 1056 1057 def get(self): 1058 return self.t 1059 1060 @classmethod 1061 def __obj_unflatten__(cls, flattened_ctx): 1062 return cls(**dict(flattened_ctx)) 1063 1064 torch._dynamo.reset() 1065 1066 def tearDown(self): 1067 torch._dynamo.reset() 1068 1069 @parametrize("backend", ["eager", "aot_eager"]) 1070 def test_compile_script_object_input(self, backend): 1071 if backend == "eager": 1072 backend = EagerAndRecordGraphs() 1073 1074 class Model(torch.nn.Module): 1075 def __init__(self) -> None: 1076 super().__init__() 1077 self.check_tq_is_fake = True 1078 1079 def forward(self, tq, x): 1080 tq.push(x.cos()) 1081 tq.push(x.sin()) 1082 x_sin = tq.pop() - tq.size() 1083 return x_sin, tq 1084 1085 mod = Model() 1086 tq1 = torch.classes._TorchScriptTesting._TensorQueue( 1087 torch.empty( 1088 0, 1089 ).fill_(-1) 1090 ) 1091 tq2 = torch.classes._TorchScriptTesting._TensorQueue( 1092 torch.empty( 1093 0, 1094 ).fill_(-1) 1095 ) 1096 tq3 = torch.classes._TorchScriptTesting._TensorQueue( 1097 torch.empty( 1098 0, 1099 ).fill_(-1) 1100 ) 1101 tq4 = torch.classes._TorchScriptTesting._TensorQueue( 1102 torch.empty( 1103 0, 1104 ).fill_(-1) 1105 ) 1106 x = torch.randn(2, 3) 1107 ret = torch.compile(mod, backend=backend)(tq1, x) 1108 eager_ret = mod(tq2, x) 1109 _assertEqualSkipScriptObject(self, ret, eager_ret) 1110 self.assertEqual(ret[1].size(), eager_ret[1].size()) 1111 self.assertEqual(ret[1].pop(), eager_ret[1].pop()) 1112 # Note that dynamo captured graph 1113 # does not return L_tq_ as output. This is because it's able 1114 # to detect that L_tq_ is an input therefore don't return 1115 # it as graph output. Related logic is in dynamo/codegen.py 1116 if backend == "eager": 1117 self.assertExpectedInline( 1118 backend.graphs[0].code.strip(), 1119 """\ 1120 def forward(self, L_tq_ : torch.ScriptObject, L_x_ : torch.Tensor): 1121 l_tq_ = L_tq_ 1122 l_x_ = L_x_ 1123 cos = l_x_.cos() 1124 call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', cos); cos = None 1125 sin = l_x_.sin(); l_x_ = None 1126 call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', sin); sin = None 1127 call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'pop') 1128 call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None 1129 x_sin = call_torchbind_2 - 1; call_torchbind_2 = None 1130 return (x_sin,)""", 1131 ) 1132 1133 @parametrize("backend", ["eager", "aot_eager"]) 1134 def test_compile_script_object_input_guards(self, backend): 1135 class Model(torch.nn.Module): 1136 def __init__(self) -> None: 1137 super().__init__() 1138 self.check_tq_is_fake = True 1139 1140 def forward(self, tq, x): 1141 tq.push(x.cos()) 1142 tq.push(x.sin()) 1143 x_sin = tq.pop() - tq.size() 1144 return x_sin, tq 1145 1146 mod = Model() 1147 cnt = torch._dynamo.testing.CompileCounterWithBackend(backend) 1148 x = torch.randn(2, 3) 1149 1150 tq1 = _empty_tensor_queue() 1151 torch.compile(mod, backend=cnt)(tq1, x) 1152 self.assertEqual(cnt.frame_count, 1) 1153 1154 tq2 = _empty_tensor_queue() 1155 for _ in range(10): 1156 tq2.push(torch.randn(4, 5, requires_grad=False)) 1157 torch.compile(mod, backend=cnt)(tq2, x) 1158 # Queue length change causes re-compile 1159 self.assertEqual(cnt.frame_count, 2) 1160 1161 tq3 = _empty_tensor_queue() 1162 tq3.push(torch.randn(2, 3, requires_grad=False)) 1163 torch.compile(mod, backend=cnt)(tq3, x) 1164 # Tensor in queue changes shape causes re-compile 1165 self.assertEqual(cnt.frame_count, 3) 1166 1167 tq4 = _empty_tensor_queue() 1168 tq4.push(torch.randn(2, 3, requires_grad=False)) 1169 torch.compile(mod, backend=cnt)(tq4, x) 1170 # No recompile 1171 self.assertEqual(cnt.frame_count, 3) 1172 1173 tq5 = _empty_tensor_queue() 1174 tq5.push(torch.randn(2, 3, requires_grad=True)) 1175 torch.compile(mod, backend=cnt)(tq5, x) 1176 # Tensor in queue changes dispatch key causes re-compile 1177 self.assertEqual(cnt.frame_count, 4) 1178 1179 tq6 = _empty_tensor_queue() 1180 tq6.push(torch.randn(2, 3, requires_grad=True, dtype=torch.float64)) 1181 torch.compile(mod, backend=cnt)(tq6, x) 1182 # Tensor in queue changes dtype causes re-compile 1183 self.assertEqual(cnt.frame_count, 5) 1184 1185 def test_compile_script_object_input_automatic_dynamic_shape(self): 1186 class Model(torch.nn.Module): 1187 def __init__(self) -> None: 1188 super().__init__() 1189 self.check_tq_is_fake = True 1190 1191 def forward(self, tq, x): 1192 tq.push(x.cos()) 1193 tq.push(x.sin()) 1194 x_sin = tq.pop() - tq.size() 1195 return x_sin, tq 1196 1197 mod = Model() 1198 cnt = torch._dynamo.testing.CompileCounter() 1199 x = torch.randn(2, 3) 1200 1201 tq1 = _empty_tensor_queue() 1202 tq1.push(torch.randn(2, 3, requires_grad=False)) 1203 torch.compile(mod, backend=cnt)(tq1, x) 1204 self.assertEqual(cnt.frame_count, 1) 1205 1206 tq2 = _empty_tensor_queue() 1207 # make first tensor's secon dim dynamic 1208 tq2.push(torch.randn(2, 4, requires_grad=False)) 1209 torch.compile(mod, backend=cnt)(tq2, x) 1210 self.assertEqual(cnt.frame_count, 2) 1211 1212 tq3 = _empty_tensor_queue() 1213 tq3.push(torch.randn(2, 5, requires_grad=False)) 1214 # should have no-recompilation 1215 torch.compile(mod, backend=cnt)(tq3, x) 1216 self.assertEqual(cnt.frame_count, 2) 1217 1218 @parametrize("backend", ["eager", "aot_eager"]) 1219 def test_compile_error_on_input_aliasing_contents(self, backend): 1220 if backend == "eager": 1221 backend = EagerAndRecordGraphs() 1222 1223 class Model(torch.nn.Module): 1224 def __init__(self) -> None: 1225 super().__init__() 1226 self.check_tq_is_fake = True 1227 1228 def forward(self, tq, x): 1229 return x.sin(), tq.pop().cos() 1230 1231 x = torch.randn(2, 3) 1232 mod = Model() 1233 1234 tq1 = _empty_tensor_queue() 1235 tq1.push(x) 1236 with self.assertRaisesRegex(RuntimeError, "is alising"): 1237 torch.compile(mod, backend=backend)(tq1, x) 1238 1239 @parametrize("backend", ["eager", "aot_eager"]) 1240 def test_compile_error_on_script_obj_setattr(self, backend): 1241 if backend == "eager": 1242 backend = EagerAndRecordGraphs() 1243 1244 def setattr_f(tq): 1245 tq.a = 1 1246 return tq 1247 1248 with self.assertRaisesRegex( 1249 RuntimeError, "call method __setattr__ on script object is not safe" 1250 ): 1251 torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) 1252 1253 @parametrize("backend", ["eager", "aot_eager"]) 1254 def test_compile_error_on_script_obj_missing_attr(self, backend): 1255 if backend == "eager": 1256 backend = EagerAndRecordGraphs() 1257 1258 def setattr_f(tq): 1259 return tq._not_defined_attr 1260 1261 with self.assertRaisesRegex( 1262 RuntimeError, "doesn't define method _not_defined_attr" 1263 ): 1264 torch.compile(setattr_f, backend=backend)(_empty_tensor_queue()) 1265 1266 @parametrize("backend", ["eager", "aot_eager"]) 1267 def test_compile_body_aliasing_contents(self, backend): 1268 if backend == "eager": 1269 backend = EagerAndRecordGraphs() 1270 1271 def f(tq, x): 1272 x1 = x.view(-1) 1273 x2 = x.permute(1, 0) 1274 tq.push(x1) 1275 tq.push(x2) 1276 return x1 - tq.size(), x2 + tq.size(), tq 1277 1278 x = torch.randn(2, 3) 1279 _assertEqualScriptObject( 1280 self, 1281 f(_empty_tensor_queue(), x), 1282 torch.compile(f, backend=backend)(_empty_tensor_queue(), x), 1283 ) 1284 if not torch._dynamo.is_compiling() and backend == "eager": 1285 self.assertExpectedInline( 1286 backend.graphs[0].code.strip(), 1287 """\ 1288def forward(self, L_x_ : torch.Tensor, L_tq_ : torch.ScriptObject): 1289 l_x_ = L_x_ 1290 l_tq_ = L_tq_ 1291 x1 = l_x_.view(-1) 1292 x2 = l_x_.permute(1, 0); l_x_ = None 1293 call_torchbind = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x1) 1294 call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_tq_, 'push', x2) 1295 call_torchbind_2 = torch.ops.higher_order.call_torchbind(l_tq_, 'size') 1296 sub = x1 - 2; x1 = None 1297 call_torchbind_3 = torch.ops.higher_order.call_torchbind(l_tq_, 'size'); l_tq_ = None 1298 add = x2 + 2; x2 = None 1299 return (sub, add)""", 1300 ) 1301 1302 @parametrize("backend", ["eager", "aot_eager"]) 1303 def test_compile_tensor_op_in_tensor_flatten(self, backend): 1304 test_obj = torch.classes._TorchScriptTesting._FlattenWithTensorOp( 1305 torch.randn(3, 2) 1306 ) 1307 1308 class TestMod(torch.nn.Module): 1309 def forward(self, obj, x): 1310 return obj.get() + x 1311 1312 mod = TestMod() 1313 1314 torch.compile(mod, backend=backend, fullgraph=True)(test_obj, torch.randn(3, 1)) 1315 ep = torch.export.export(mod, (test_obj, torch.randn(3, 1)), strict=False) 1316 self.assertExpectedInline( 1317 ep.graph_module.code.strip(), 1318 """\ 1319def forward(self, token, obj, x): 1320 with_effects = torch.ops.higher_order.with_effects(token, torch.ops.higher_order.call_torchbind, obj, 'get'); token = obj = None 1321 getitem = with_effects[0] 1322 getitem_1 = with_effects[1]; with_effects = None 1323 add_3 = torch.ops.aten.add.Tensor(getitem_1, x); getitem_1 = x = None 1324 return (getitem, add_3)""", # noqa: B950 1325 ) 1326 1327 @parametrize("backend", ["eager", "aot_eager"]) 1328 def test_compile_error_on_non_fakified_method(self, backend): 1329 if backend == "eager": 1330 backend = EagerAndRecordGraphs() 1331 1332 def f(tq, x): 1333 x1 = x.view(-1) 1334 x2 = x.permute(1, 0) 1335 tq.push(x1) 1336 tq.push(x2) 1337 # though real tensor queue implemented a method clone_queue, 1338 # The fakified version doesn't. 1339 flat_obj = tq.clone_queue() 1340 return flat_obj 1341 1342 x = torch.randn(2, 3) 1343 with self.assertRaisesRegex( 1344 RuntimeError, "FakeScriptObject doesn't define method" 1345 ): 1346 torch.compile(f, backend=backend)(_empty_tensor_queue(), x) 1347 1348 @parametrize("backend", ["eager", "aot_eager"]) 1349 def test_compile_obj_as_hop_input(self, backend): 1350 def f(tq, x): 1351 def fn(tq, x): 1352 tq.push(x) 1353 return x.sin() 1354 1355 return wrap(fn, tq, x) 1356 1357 x = torch.randn(2, 3) 1358 _assertEqualScriptObject( 1359 self, 1360 f(_empty_tensor_queue(), x), 1361 torch.compile(f, backend=backend)(_empty_tensor_queue(), x), 1362 ) 1363 1364 @parametrize("backend", ["eager", "aot_eager"]) 1365 def test_compile_obj_closure(self, backend): 1366 def f(x): 1367 def inner_f(x): 1368 tq.push(x.sin()) 1369 1370 inner_f(x) 1371 return tq.pop(), tq 1372 1373 opt_f = torch.compile(f, backend="eager") 1374 1375 tq = _empty_tensor_queue() 1376 x = torch.randn(3, 2) 1377 _assertEqualScriptObject(self, f(x), opt_f(x)) 1378 1379 @parametrize("backend", ["eager", "aot_eager"]) 1380 def test_compile_global_obj(self, backend): 1381 global _TENSOR_QUEUE_GLOBAL_TEST 1382 _TENSOR_QUEUE_GLOBAL_TEST = _empty_tensor_queue() 1383 1384 def f(x): 1385 _TENSOR_QUEUE_GLOBAL_TEST.push(x.sin()) 1386 return _TENSOR_QUEUE_GLOBAL_TEST.pop(), _TENSOR_QUEUE_GLOBAL_TEST 1387 1388 opt_f = torch.compile(f, backend=backend) 1389 x = torch.randn(3, 2) 1390 eager_ret = f(x) 1391 opt_ret = opt_f(x) 1392 _assertEqualScriptObject(self, eager_ret, opt_ret) 1393 1394 def test_compile_obj_graph_breaks(self): 1395 cnt = torch._dynamo.testing.CompileCounter() 1396 1397 def f(tq, x): 1398 tq.push(x.sin()) 1399 tq.push(x.sin()) 1400 torch._dynamo.graph_break() 1401 tq.pop() 1402 torch._dynamo.graph_break() 1403 tq.push(x.cos() + tq.size()) 1404 torch._dynamo.graph_break() 1405 tq.push(x.cos() - tq.size()) 1406 return x, tq.pop(), tq 1407 1408 opt_f = torch.compile(f, backend=cnt) 1409 x = torch.randn(3, 2) 1410 _assertEqualScriptObject( 1411 self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) 1412 ) 1413 self.assertEqual(cnt.frame_count, 4) 1414 1415 @parametrize("backend", ["eager", "aot_eager"]) 1416 def test_compile_obj_attributes(self, backend): 1417 if backend == "eager": 1418 backend = EagerAndRecordGraphs() 1419 1420 class Model(torch.nn.Module): 1421 def __init__(self) -> None: 1422 super().__init__() 1423 self.tq = _empty_tensor_queue() 1424 1425 def forward(self, x): 1426 self.tq.push(x) 1427 return self.tq.pop() 1428 1429 x = torch.randn(2, 3) 1430 opt_f = torch.compile(Model(), backend=backend) 1431 _assertEqualScriptObject(self, Model()(x), opt_f(x)) 1432 if backend == "eager": 1433 self.assertEqual(len(backend.graphs), 1) 1434 # lifted as input. In the future, we would want to cosolidate this 1435 # with non-strict behavior, where they're set as attributes. 1436 self.assertExpectedInline( 1437 backend.graphs[0].code.strip(), 1438 """\ 1439 def forward(self, L_self_tq : torch.ScriptObject, L_x_ : torch.Tensor): 1440 l_self_tq = L_self_tq 1441 l_x_ = L_x_ 1442 call_torchbind = torch.ops.higher_order.call_torchbind(l_self_tq, 'push', l_x_); l_x_ = None 1443 call_torchbind_1 = torch.ops.higher_order.call_torchbind(l_self_tq, 'pop'); l_self_tq = None 1444 return (call_torchbind_1,)""", 1445 ) 1446 1447 @parametrize("backend", ["eager", "aot_eager"]) 1448 def test_compile_obj_torchbind_op(self, backend): 1449 def f(tq, x): 1450 torch.ops._TorchScriptTesting.queue_push(tq, x.cos()) 1451 torch.ops._TorchScriptTesting.queue_push(tq, x.cos() + 1) 1452 torch.ops._TorchScriptTesting.queue_pop(tq) 1453 torch.ops._TorchScriptTesting.queue_push(tq, x.sin()) 1454 return tq.pop(), tq.pop() + tq.size(), tq 1455 1456 opt_f = torch.compile(f, backend=backend) 1457 x = torch.randn(2) 1458 _assertEqualScriptObject( 1459 self, f(_empty_tensor_queue(), x), opt_f(_empty_tensor_queue(), x) 1460 ) 1461 1462 1463@skipIfTorchDynamo("torchbind not supported with dynamo yet") 1464class TestRegisterFakeClass(TestCase): 1465 def setUp(self): 1466 init_torchbind_implementations() 1467 1468 def tearDown(self): 1469 torch._library.fake_class_registry.global_fake_class_registry.clear() 1470 1471 def test_register_fake_class_no_torch_bind_class(self): 1472 with self.assertRaisesRegex(RuntimeError, "Tried to instantiate class"): 1473 1474 @torch._library.register_fake_class("_TorchScriptTesting::NOT_A_VALID_NAME") 1475 class Invalid: 1476 pass 1477 1478 def test_register_fake_class_no_from_real(self): 1479 with self.assertRaisesRegex( 1480 RuntimeError, "define a classmethod __obj_unflatten__" 1481 ): 1482 1483 @torch._library.register_fake_class("_TorchScriptTesting::_Foo") 1484 class InvalidFakeFoo: 1485 def __init__(self) -> None: 1486 pass 1487 1488 def test_register_fake_class_from_real_not_classmethod(self): 1489 with self.assertRaisesRegex(RuntimeError, "is not a classmethod"): 1490 1491 @torch._library.register_fake_class("_TorchScriptTesting::_Foo") 1492 class FakeFoo: 1493 def __init__(self, x, y): 1494 self.x = x 1495 self.y = y 1496 1497 def __obj_unflatten__(cls, flattend_foo): # noqa: B902 1498 return cls(**dict(flattend_foo)) 1499 1500 def test_register_fake_class_valid(self): 1501 class FakeFoo: 1502 def __init__(self, x, y): 1503 self.x = x 1504 self.y = y 1505 1506 @classmethod 1507 def __obj_unflatten__(cls, flattend_foo): 1508 return cls(**dict(flattend_foo)) 1509 1510 torch._library.register_fake_class("_TorchScriptTesting::_Foo", FakeFoo) 1511 1512 1513instantiate_parametrized_tests(TestExportTorchbind) 1514instantiate_parametrized_tests(TestCompileTorchbind) 1515 1516if __name__ == "__main__": 1517 run_tests() 1518