# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pye-strict

import copy
import unittest
from typing import Any, Dict

import torch
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
from executorch.exir.backend.test.op_partitioner_demo import (
    AddMulPartitionerDemo,
    NonDecompTestPartitioner,
)
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.error import ExportError
from executorch.exir.lowered_backend_module import get_lowered_submodules
from executorch.exir.pass_base import ExportPass
from executorch.exir.passes import MemoryPlanningPass
from executorch.exir.program._program import (
    EdgeProgramManager,
    ExecutorchProgramManager,
    to_edge,
    to_edge_transform_and_lower,
    to_edge_with_preserved_ops,
)
from executorch.exir.tracer import _default_decomposition_table
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier

from executorch.extension.pybindings.portable_lib import (
    _load_for_executorch_from_buffer,
)
from torch.export import Dim, export, ExportedProgram
from torch.export._trace import _export

from torch.library import impl, Library
from torch.nn import functional as F


class TestLinear(torch.nn.Module):
    """A single ``Linear(32, 16)`` layer with bias."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=True)

    def forward(self, x):
        return self.linear(x)

    @classmethod
    def _get_random_inputs(cls):
        x = torch.rand(8, 32)
        return (x,)


class TestSDPA(torch.nn.Module):
    """Calls ``aten.scaled_dot_product_attention`` directly on its inputs."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        return torch.ops.aten.scaled_dot_product_attention.default(query, key, value)

    @classmethod
    def _get_random_inputs(cls):
        d_k = 64
        batch = 16
        seq_len = 10
        query = torch.rand(batch, seq_len, d_k)
        key = torch.rand(batch, seq_len, d_k)
        value = torch.rand(batch, seq_len, d_k)
        return (query, key, value)


class TestLinearSDPACombined(torch.nn.Module):
    """Runs a linear layer and an SDPA call side by side, returning both."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=True)

    def forward(self, x, query, key, value):
        x = self.linear(x)
        return (
            x,
            torch.ops.aten.scaled_dot_product_attention.default(query, key, value),
        )

    @classmethod
    def _get_random_inputs(cls):
        # Inputs are the concatenation of TestLinear's and TestSDPA's inputs.
        return TestLinear._get_random_inputs() + TestSDPA._get_random_inputs()


class TestUpsample(torch.nn.Module):
    """2x nearest-neighbour upsampling via ``F.interpolate``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        return x

    @classmethod
    def _get_random_inputs(cls):
        x = torch.randn(1, 1, 8, 8)
        return (x,)


class TestLSTM(torch.nn.Module):
    """A single batch-first LSTM layer (input_size=8, hidden_size=16)."""

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        return self.lstm(x)

    @classmethod
    def _get_random_inputs(cls):
        return (torch.rand(1, 10, 8),)


class WrapperModule(torch.nn.Module):
    """Wrap a plain callable so it can be passed to ``torch.export``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


lib = Library("exir_program_test_op", "DEF")

# Fake an operator for testing.
# This operator takes two tensors as input; its CPU kernel returns their sum.
lib.define("foo(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CPU")
def foo(a, b):
    # CPU kernel: element-wise sum of the two inputs.
    return a + b


@impl(lib, "foo", "Meta")
def foo_meta(a, b):
    # Meta kernel: returns an uninitialized tensor with the same shape and
    # dtype as ``a``; no values are computed.
    return torch.empty_like(a)


def get_exported_programs() -> Dict[str, ExportedProgram]:
    """Return two decomposed exported programs keyed by method name.

    ``forward`` computes ``x * y + x`` and ``foo`` computes ``x + 1``.
    """

    class Forward(torch.nn.Module):
        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            z = torch.mul(x, y)
            return torch.add(z, x)

    forward = Forward()

    class Foo(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return torch.add(x, torch.ones(1))

    foo = Foo()

    programs = {}
    programs["forward"] = export(
        forward,
        args=(
            torch.ones(1),
            torch.zeros(1),
        ),
    ).run_decompositions()
    programs["foo"] = export(
        foo,
        (torch.ones(1),),
    ).run_decompositions()
    return programs


def get_config_methods() -> Dict[str, Any]:
    """Return constant config methods: ``bam`` -> 3 and ``bar`` -> "bar"."""

    def bam():
        return 3

    def bar():
        return "bar"

    return {"bam": bam(), "bar": bar()}


class AddToMulPassEdge(ExportPass):
    """Rewrite every edge-dialect ``aten.add.Tensor`` into ``aten.mul.Tensor``."""

    def call_operator(self, op, args, kwargs, meta):
        if op == exir_ops.edge.aten.add.Tensor:
            return super().call_operator(
                exir_ops.edge.aten.mul.Tensor, args, kwargs, meta
            )
        else:
            return super().call_operator(op, args, kwargs, meta)


class TestProgramManagers(unittest.TestCase):
    def test_edge_manager_basic_api(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        # test basic apis
        self.assertEqual(edge_manager.methods, {"forward", "foo"})
        self.assertEqual(edge_manager.config_methods, {"bam", "bar"})

        # test dialect is correct
        try:
            EXIREdgeDialectVerifier()(
                edge_manager.exported_program("forward").graph_module
            )
            EXIREdgeDialectVerifier()(edge_manager.exported_program("foo").graph_module)
        except ExportError as e:
            # str(e) works for any exception; a ``msg`` attribute is not
            # guaranteed to exist on it.
            self.fail("Graph not in edge dialect : " + str(e))

    def test_executorch_manager_basic_api(self):
        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch()

        # test basic apis
        self.assertEqual(executorch_manager.methods, {"forward", "foo"})
        self.assertEqual(executorch_manager.config_methods, {"bam", "bar"})

        # test that the emitted output is correct: 2 methods + 2 config methods
        self.assertEqual(
            len(executorch_manager._emitter_output.program.execution_plan), 4
        )

        # test that the buffer is correct
        executorch_module = _load_for_executorch_from_buffer(executorch_manager.buffer)
        self.assertEqual(
            executorch_module.run_method("forward", (torch.ones(1), torch.zeros(1)))[0],
            torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("foo", (torch.ones(1),))[0],
            torch.ones(1) + torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("bar", ())[0],
            "bar",
        )
        self.assertEqual(
            executorch_module.run_method("bam", ())[0],
            3,
        )

    def test_executorch_manager_multi_config(self):
        def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
            return {
                "forward": MemoryPlanningPass(
                    alloc_graph_input=True,
                    alloc_graph_output=False,
                ),
                "foo": MemoryPlanningPass(
                    alloc_graph_input=False,
                    alloc_graph_output=True,
                ),
            }

        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch(
            ExecutorchBackendConfig(
                memory_planning_pass=get_executorch_memory_planning_passes()
            )
        )

        # method name -> (inputs planned, outputs planned), mirroring the
        # alloc_graph_input / alloc_graph_output flags configured above.
        expected_planning = {
            "forward": (True, False),
            "foo": (False, True),
        }

        def assert_planned(method, value_ids, planned: bool) -> None:
            # A planned value carries an allocation_info; an unplanned one
            # does not.
            for value_id in value_ids:
                allocation_info = method.values[value_id].val.allocation_info
                if planned:
                    self.assertIsNotNone(allocation_info)
                else:
                    self.assertIsNone(allocation_info)

        checked_methods = set()
        for method in executorch_manager._emitter_output.program.execution_plan:
            if method.name not in expected_planning:
                # Only forward and foo have a planning config above; the
                # config methods (bam, bar) are skipped.
                continue
            inputs_planned, outputs_planned = expected_planning[method.name]
            with self.subTest(method=method.name):
                assert_planned(method, method.inputs, inputs_planned)
                assert_planned(method, method.outputs, outputs_planned)
            checked_methods.add(method.name)

        # Make sure each configured method was actually inspected.
        self.assertEqual(checked_methods, set(expected_planning))

    def test_no_getattr(self):
        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()
        ep = to_edge(torch.export.export(mul, (torch.ones(1),))).exported_program()
        # No get_attr nodes may remain in the edge graph.
        for node in ep.graph.nodes:
            self.assertNotEqual(node.op, "get_attr")
        # x plus one extra placeholder (presumably the lifted 3.14 constant).
        self.assertEqual(
            len([node for node in ep.graph.nodes if node.op == "placeholder"]), 2
        )

    def test_constraint_present_after_dce(self):
        class M(torch.nn.Module):
            def forward(self, x, y):
                z = y.item()
                torch._check(z > 0)
                torch._check(z < 4)
                return x[z : z + y.shape[0]]

        ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])))

        # Lowering a data-dependent slice guarded by torch._check must not fail.
        edge_manager = to_edge(
            ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
        )
        edge_manager.to_executorch()

    def test_edge_manager_transform(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        original_res = edge_manager.exported_program("forward").module()(
            torch.ones(1), torch.ones(1)
        )

        # perform transformation
        transformed_edge = edge_manager.transform(
            [
                AddToMulPassEdge(),
            ]
        )

        # still have all our methods
        self.assertEqual(len(transformed_edge.methods), 2)
        self.assertEqual(len(transformed_edge.config_methods), 2)

        # transformation was applied
        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        # original unchanged
        self.assertEqual(
            edge_manager.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            original_res,  # x * y + x
        )

    def test_issue_3659(self):
        # Regression test named after issue 3659: matmul whose operands share
        # dynamic dims must pass edge IR validity checks.
        class Mul(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.matmul(x, y)

            def get_eager_model(self) -> torch.nn.Module:
                return self

            def get_example_inputs(self):
                return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))

            def get_dynamic_shapes(self):
                dim1_x = Dim("Dot_dim1_x", min=2, max=100)
                dim2_x = Dim("Dot_dim2_x", min=2, max=100)
                return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}

        model = Mul()
        ep = torch.export.export(
            model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
        )

        to_edge(
            ep,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=True,
            ),
        )

    def test_transform_dict_api(self):
        edge_manager = to_edge(get_exported_programs(), get_config_methods())

        # Only "forward" is transformed; "foo" must be left as-is.
        transformed_edge = edge_manager.transform(
            {
                "forward": [
                    AddToMulPassEdge(),
                ]
            }
        )

        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        self.assertEqual(
            transformed_edge.exported_program("foo").module()(
                torch.ones(1),
            ),
            torch.ones(1) + 1,  # x + 1
        )

    def test_edge_to_backend_replaces_subgraph(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            AddMulPartitionerDemo()
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        foo_program = delegate_manager.exported_program("foo")
        add_nodes = [
            node
            for node in foo_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        lowered_submods = get_lowered_submodules(foo_program.graph_module)
        self.assertEqual(len(lowered_submods), 1)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # two delegate blobs for forward and foo; lower to ExecuTorch once and
        # inspect both execution plans.
        execution_plan = delegate_manager.to_executorch(
            ExecutorchBackendConfig()
        )._emitter_output.program.execution_plan
        self.assertEqual(len(execution_plan[0].delegates), 1)
        self.assertEqual(len(execution_plan[1].delegates), 1)

    def test_edge_to_backend_selective(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            {"forward": AddMulPartitionerDemo()}
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        # foo unchanged
        lowered_submods = get_lowered_submodules(
            delegate_manager.exported_program("foo").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob for forward; lower to ExecuTorch once and inspect
        # both execution plans.
        execution_plan = delegate_manager.to_executorch(
            ExecutorchBackendConfig(
                extract_delegate_segments=False,
            )
        )._emitter_output.program.execution_plan
        self.assertEqual(len(execution_plan[0].delegates), 0)  # foo
        self.assertEqual(len(execution_plan[1].delegates), 1)  # forward

    def test_edge_manager_dialect(self):
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        self.assertEqual(edge_manager.exported_program().dialect, "EDGE")

    def _test_edge_dialect_verifier(
        self, callable, validate_ir=True, exception_list=None
    ):
        """Export ``callable`` on (float32, int32) inputs and lower it to edge.

        Args:
            callable: An ``nn.Module`` or a plain function; plain functions are
                wrapped in ``WrapperModule`` before export.
            validate_ir: Forwarded as ``EdgeCompileConfig._check_ir_validity``.
            exception_list: Forwarded as
                ``EdgeCompileConfig._core_aten_ops_exception_list``.
        """
        edge_compile_config = EdgeCompileConfig(
            _check_ir_validity=validate_ir,
            _core_aten_ops_exception_list=exception_list,
        )
        one = torch.ones(1, dtype=torch.float)
        two = torch.ones(1, dtype=torch.int32)
        inputs = (
            one,
            two,
        )
        if not isinstance(callable, torch.nn.Module):
            callable = WrapperModule(callable)

        exported_foo = export(callable, inputs)
        _ = to_edge(exported_foo, compile_config=edge_compile_config)

    def test_edge_dialect_custom_op(self):
        # We shouldn't error out if there's a custom op in the graph.
        def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
            return torch.ops.exir_program_test_op.foo(a, b)

        from torch._export.verifier import SpecViolationError

        try:
            # This should not raise error
            self._test_edge_dialect_verifier(_use_foo_add)
            self._test_edge_dialect_verifier(_use_foo_add, False)
        except SpecViolationError:
            self.fail("Should not error out on custom op")

    def get_num_nondecomposed_ops(self, ep, partitioner):
        """Count the aten ops in ``ep`` that ``partitioner`` can delegate.

        ``ep`` is deep-copied and decomposed with the partitioner's preserved
        ops removed from the default decomposition table; the preserved aten
        ops that also pass the partitioner's ``filter_ops`` are then counted.
        """
        reference_ep = copy.deepcopy(ep)
        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
        table = _default_decomposition_table()
        for op in aten_ops_not_decomposed:
            table.pop(op, None)
        reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table)
        num_non_decomposed_aten_ops = 0
        for node in reference_decomp_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in aten_ops_not_decomposed
                and (filter_ops(node) if filter_ops else True)
            ):
                num_non_decomposed_aten_ops += 1
        return num_non_decomposed_aten_ops

    def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
        # This is the pre-dispatch export that we will be switching to primarily
        # in the near future. The input to to_edge_transform_and_lower needs to
        # be a graph generated by this pre dispatch export.
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        non_decomp_partitioner = NonDecompTestPartitioner()

        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
            ep, non_decomp_partitioner
        )

        # run to_edge_transform_and_lower
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )
        # Check that non_decomposed_edge_ops are all consumed by the delegate
        non_decomposed_edge_ops = (
            non_decomp_partitioner.supported_non_decomposed_edge_ops
        )
        for node in edge.exported_program().graph.nodes:
            if node.op == "call_function":
                self.assertTrue(node.target not in non_decomposed_edge_ops)

        # check that the number of call_delegate nodes is equal to the number of
        # non_decomposed_aten_ops we found above
        num_call_delegates = 0
        for node in edge.exported_program().graph_module.graph.nodes:
            # Count every executorch_call_delegate node in the lowered graph.
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.executorch_call_delegate
            ):
                num_call_delegates += 1

        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)

    def test_to_edge_transform_and_lower(self):
        self._test_model_with_non_decomp_partitioner(TestLinear())

        self._test_model_with_non_decomp_partitioner(TestSDPA())

        self._test_model_with_non_decomp_partitioner(TestLinearSDPACombined())

        self._test_model_with_non_decomp_partitioner(TestUpsample())

        self._test_model_with_non_decomp_partitioner(TestLSTM())

    def test_to_edge_transform_and_lower_with_exception(self):
        # Named distinctly so it does not shadow the module-level TestLinear.
        class LinearWithAndWithoutBias(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(32, 16, bias=True)
                self.linear_no_bias = torch.nn.Linear(32, 16, bias=False)

            def forward(self, x):
                return (self.linear(x), self.linear_no_bias(x))

            @classmethod
            def _get_random_inputs(cls):
                x = torch.rand(8, 32)
                return (x,)

        model = LinearWithAndWithoutBias()
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target == target:
                    count += 1
            return count

        # There should be 1 call_delegate node and 1 node for aten.mm.default for the
        # linear that doesn't have a bias which was decomposed as the partitioner
        # said this node wasn't supported.
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module,
                torch.ops.higher_order.executorch_call_delegate,
            ),
            1,
        )
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module, exir_ops.edge.aten.mm.default
            ),
            1,
        )

    def test_edge_dialect_non_core_aten_ops(self):
        class LinalgNorm(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.linalg.norm(x)

        from torch._export.verifier import SpecViolationError

        x = torch.arange(9, dtype=torch.float) - 4
        ep = torch.export.export(LinalgNorm(), (x,))

        # aten::linalg_norm is not a core op, so it should error out
        with self.assertRaises(SpecViolationError):
            _ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True))

        # with exception list, it should not error out
        try:
            # This should not raise error
            _ = to_edge(
                ep,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=True,
                    _core_aten_ops_exception_list=[
                        torch.ops.aten.linalg_vector_norm.default
                    ],
                ),
            )
        except SpecViolationError:
            self.fail("Should not error out on linalg_vector_norm op")

    def _test_to_edge_with_preserved_ops(
        self, program, preserved_ops, expected_preserved_ops
    ):
        """Assert ``to_edge_with_preserved_ops`` keeps every preserved op.

        The number of ``preserved_ops`` nodes in ``program`` must equal the
        number of ``expected_preserved_ops`` nodes in the resulting edge graph.
        """
        edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target in target:
                    count += 1
            return count

        aten_ops_non_decomposed = count_nodes(
            program.graph_module,
            preserved_ops,
        )

        edge_ops_non_decomposed = count_nodes(
            edge.exported_program().graph_module,
            expected_preserved_ops,
        )

        self.assertEqual(aten_ops_non_decomposed, edge_ops_non_decomposed)

    def test_to_edge_with_single_preserved_op(self):
        model = TestLinear()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_partial_ops_preserved(self):
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_multiple_ops_preserved(self):
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.scaled_dot_product_attention.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_preserved_ops_not_in_model(self):
        model = TestSDPA()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )