# xref: /aosp_15_r20/external/executorch/exir/program/test/test_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pye-strict
8
9import copy
10import unittest
11from typing import Any, Dict
12
13import torch
14from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig
15from executorch.exir.backend.test.op_partitioner_demo import (
16    AddMulPartitionerDemo,
17    NonDecompTestPartitioner,
18)
19from executorch.exir.dialects._ops import ops as exir_ops
20from executorch.exir.error import ExportError
21from executorch.exir.lowered_backend_module import get_lowered_submodules
22from executorch.exir.pass_base import ExportPass
23from executorch.exir.passes import MemoryPlanningPass
24from executorch.exir.program._program import (
25    EdgeProgramManager,
26    ExecutorchProgramManager,
27    to_edge,
28    to_edge_transform_and_lower,
29    to_edge_with_preserved_ops,
30)
31from executorch.exir.tracer import _default_decomposition_table
32from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
33
34from executorch.extension.pybindings.portable_lib import (
35    _load_for_executorch_from_buffer,
36)
37from torch.export import Dim, export, ExportedProgram
38from torch.export._trace import _export
39
40from torch.library import impl, Library
41from torch.nn import functional as F
42
43
class TestLinear(torch.nn.Module):
    """Single fully-connected layer mapping 32 input features to 16 outputs."""

    def __init__(self):
        super().__init__()
        # Biased linear layer: in_features=32, out_features=16.
        self.linear = torch.nn.Linear(in_features=32, out_features=16, bias=True)

    def forward(self, x):
        out = self.linear(x)
        return out

    @classmethod
    def _get_random_inputs(cls):
        """Return a 1-tuple holding an (8, 32) uniform-random tensor."""
        sample = torch.rand(8, 32)
        return (sample,)
56
57
class TestSDPA(torch.nn.Module):
    """Bare call to ``aten.scaled_dot_product_attention`` (no learned weights)."""

    def __init__(self):
        super().__init__()

    def forward(self, query, key, value):
        sdpa = torch.ops.aten.scaled_dot_product_attention.default
        return sdpa(query, key, value)

    @classmethod
    def _get_random_inputs(cls):
        """Return ``(query, key, value)``, each a (16, 10, 64) uniform-random tensor."""
        batch, seq_len, d_k = 16, 10, 64
        shape = (batch, seq_len, d_k)
        # Drawn in query, key, value order.
        return tuple(torch.rand(*shape) for _ in range(3))
74
75
class TestLinearSDPACombined(torch.nn.Module):
    """Runs a linear layer and an independent SDPA call side by side.

    ``forward`` returns ``(linear(x), sdpa(query, key, value))``.
    """

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(32, 16, bias=True)

    def forward(self, x, query, key, value):
        projected = self.linear(x)
        attention = torch.ops.aten.scaled_dot_product_attention.default(
            query, key, value
        )
        return (projected, attention)

    @classmethod
    def _get_random_inputs(cls):
        """Concatenate the sample inputs of ``TestLinear`` and ``TestSDPA``."""
        linear_inputs = TestLinear._get_random_inputs()
        sdpa_inputs = TestSDPA._get_random_inputs()
        return linear_inputs + sdpa_inputs
91
92
class TestUpsample(torch.nn.Module):
    """Nearest-neighbour 2x spatial upsampling via ``F.interpolate``."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=2, mode="nearest")

    @classmethod
    def _get_random_inputs(cls):
        """Return a 1-tuple holding a (1, 1, 8, 8) standard-normal tensor."""
        return (torch.randn(1, 1, 8, 8),)
105
106
class TestLSTM(torch.nn.Module):
    """Single-layer, batch-first LSTM: 8 input features, 16 hidden units."""

    def __init__(self):
        super().__init__()
        self.lstm = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)

    def forward(self, x):
        # nn.LSTM's (output, (h_n, c_n)) result is passed through unchanged.
        return self.lstm(x)

    @classmethod
    def _get_random_inputs(cls):
        """Return a 1-tuple holding a (1, 10, 8) uniform-random tensor."""
        batch, seq_len, features = 1, 10, 8
        return (torch.rand(batch, seq_len, features),)
118
119
class WrapperModule(torch.nn.Module):
    """Adapt a plain callable into an ``nn.Module`` (e.g. so it can be exported).

    ``forward`` forwards all positional and keyword arguments to ``fn``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, *args, **kwargs):
        wrapped = self.fn
        return wrapped(*args, **kwargs)
127
128
lib = Library("exir_program_test_op", "DEF")

# Define a fake custom operator for testing.
# ``foo`` takes two tensors and returns their elementwise sum. It exists only so
# tests can check that the edge-dialect verifier tolerates custom ops.
lib.define("foo(Tensor self, Tensor other) -> Tensor")


@impl(lib, "foo", "CPU")
def foo(a, b):
    # Eager CPU kernel: elementwise sum of the two inputs.
    return a + b


@impl(lib, "foo", "Meta")
def foo_meta(a, b):
    # Meta (shape-only) kernel: the output takes the shape and dtype of ``a``.
    # NOTE(review): this does not model broadcasting or type promotion from
    # ``b``. That is fine for the same-shape inputs used in this file, where
    # ``a`` is float32 and ``b`` is int32.
    return torch.empty_like(a)
146
147
148def get_exported_programs() -> Dict[str, ExportedProgram]:
149    class Forward(torch.nn.Module):
150        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
151            z = torch.mul(x, y)
152            return torch.add(z, x)
153
154    forward = Forward()
155
156    class Foo(torch.nn.Module):
157        def forward(self, x: torch.Tensor) -> torch.Tensor:
158            return torch.add(x, torch.ones(1))
159
160    foo = Foo()
161
162    programs = {}
163    programs["forward"] = export(
164        forward,
165        args=(
166            torch.ones(1),
167            torch.zeros(1),
168        ),
169    ).run_decompositions()
170    programs["foo"] = export(
171        foo,
172        (torch.ones(1),),
173    ).run_decompositions()
174    return programs
175
176
def get_config_methods() -> Dict[str, Any]:
    """Return the constant config methods attached to the test programs.

    Maps ``"bam"`` to the int ``3`` and ``"bar"`` to the string ``"bar"``.
    """
    values: Dict[str, Any] = {}
    values["bam"] = 3
    values["bar"] = "bar"
    return values
185
186
class AddToMulPassEdge(ExportPass):
    """Pass that rewrites every edge ``aten.add.Tensor`` call into ``aten.mul.Tensor``.

    All other operators are forwarded to ``ExportPass`` unchanged. Arguments,
    kwargs and metadata are passed through as-is in both cases.
    """

    def call_operator(self, op, args, kwargs, meta):
        target = (
            exir_ops.edge.aten.mul.Tensor
            if op == exir_ops.edge.aten.add.Tensor
            else op
        )
        return super().call_operator(target, args, kwargs, meta)
195
196
class TestProgramManagers(unittest.TestCase):
    """End-to-end tests for the program-manager APIs in ``exir.program``.

    Covers ``to_edge``, ``to_edge_transform_and_lower`` and
    ``to_edge_with_preserved_ops``. Also covers the ``EdgeProgramManager`` /
    ``ExecutorchProgramManager`` objects they return: ``transform``,
    ``to_backend``, ``to_executorch``, the emitted program and the runtime
    buffer.
    """

    def test_edge_manager_basic_api(self):
        """``to_edge`` keeps method/config names and yields edge-dialect graphs."""
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        # test basic apis
        self.assertEqual(edge_manager.methods, {"forward", "foo"})
        self.assertEqual(edge_manager.config_methods, {"bam", "bar"})

        # test dialect is correct
        try:
            EXIREdgeDialectVerifier()(
                edge_manager.exported_program("forward").graph_module
            )
            EXIREdgeDialectVerifier()(edge_manager.exported_program("foo").graph_module)
        except ExportError as e:
            # str(e) works for any exception; it does not depend on the error
            # exposing a ``msg`` attribute.
            self.fail("Graph not in edge dialect : " + str(e))

    def test_executorch_manager_basic_api(self):
        """``to_executorch`` emits every method and a loadable, runnable buffer."""
        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch()

        # test basic apis
        self.assertEqual(executorch_manager.methods, {"forward", "foo"})
        self.assertEqual(executorch_manager.config_methods, {"bam", "bar"})

        # test that the emitted output is correct: 2 methods + 2 config methods
        self.assertEqual(
            len(executorch_manager._emitter_output.program.execution_plan), 4
        )

        # test that the buffer is correct
        executorch_module = _load_for_executorch_from_buffer(executorch_manager.buffer)
        self.assertEqual(
            executorch_module.run_method("forward", (torch.ones(1), torch.zeros(1)))[0],
            torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("foo", (torch.ones(1),))[0],
            torch.ones(1) + torch.ones(1),
        )
        self.assertEqual(
            executorch_module.run_method("bar", ())[0],
            "bar",
        )
        self.assertEqual(
            executorch_module.run_method("bam", ())[0],
            3,
        )

    def test_executorch_manager_multi_config(self):
        """Per-method ``MemoryPlanningPass`` configs are honoured.

        ``forward`` is configured to plan graph inputs but not outputs, and
        ``foo`` the opposite. Both methods are looked up by name and checked.
        The previous version only inspected ``execution_plan[0]``, and its
        ``forward`` branch asserted the inverted expectations.
        """

        def get_executorch_memory_planning_passes() -> Dict[str, MemoryPlanningPass]:
            return {
                "forward": MemoryPlanningPass(
                    alloc_graph_input=True,
                    alloc_graph_output=False,
                ),
                "foo": MemoryPlanningPass(
                    alloc_graph_input=False,
                    alloc_graph_output=True,
                ),
            }

        executorch_manager: ExecutorchProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        ).to_executorch(
            ExecutorchBackendConfig(
                memory_planning_pass=get_executorch_memory_planning_passes()
            )
        )

        plans_by_name = {
            plan.name: plan
            for plan in executorch_manager._emitter_output.program.execution_plan
        }

        def assert_planned(method, value_indices, planned):
            # A memory-planned tensor carries allocation_info; an unplanned one
            # has None.
            for idx in value_indices:
                allocation_info = method.values[idx].val.allocation_info
                if planned:
                    self.assertIsNotNone(allocation_info)
                else:
                    self.assertIsNone(allocation_info)

        # forward: alloc_graph_input=True, alloc_graph_output=False
        forward_plan = plans_by_name["forward"]
        assert_planned(forward_plan, forward_plan.inputs, planned=True)
        assert_planned(forward_plan, forward_plan.outputs, planned=False)

        # foo: alloc_graph_input=False, alloc_graph_output=True
        foo_plan = plans_by_name["foo"]
        assert_planned(foo_plan, foo_plan.inputs, planned=False)
        assert_planned(foo_plan, foo_plan.outputs, planned=True)

    def test_no_getattr(self):
        """Edge lowering of a scalar multiply leaves no ``get_attr`` nodes."""

        class Mul(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x * 3.14

        mul = Mul()
        ep = to_edge(torch.export.export(mul, (torch.ones(1),))).exported_program()
        for node in ep.graph.nodes:
            self.assertNotEqual(node.op, "get_attr")
        # One user input plus one more placeholder (presumably the lifted 3.14
        # constant).
        self.assertEqual(
            len([node for node in ep.graph.nodes if node.op == "placeholder"]), 2
        )

    def test_constraint_present_after_dce(self):
        """``to_executorch`` succeeds on a graph with data-dependent ``item()`` checks."""

        class M(torch.nn.Module):
            def forward(self, x, y):
                z = y.item()
                torch._check(z > 0)
                torch._check(z < 4)
                return x[z : z + y.shape[0]]

        ep = torch.export.export(M(), (torch.randn(10), torch.tensor([3])))

        edge_manager = to_edge(
            ep, compile_config=EdgeCompileConfig(_check_ir_validity=False)
        )
        edge_manager.to_executorch()

    def test_edge_manager_transform(self):
        """``transform`` returns a new manager and leaves the original untouched."""
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )

        original_res = edge_manager.exported_program("forward").module()(
            torch.ones(1), torch.ones(1)
        )

        # perform transformation
        transformed_edge = edge_manager.transform(
            [
                AddToMulPassEdge(),
            ]
        )

        # still have all our methods
        self.assertEqual(len(transformed_edge.methods), 2)
        self.assertEqual(len(transformed_edge.config_methods), 2)

        # transformation was applied
        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        # original unchanged
        self.assertEqual(
            edge_manager.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            original_res,  # x * y + x
        )

    def test_issue_3659(self):
        """``to_edge`` with IR validation accepts a matmul with shared dynamic dims."""

        class Mul(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return torch.matmul(x, y)

            def get_eager_model(self) -> torch.nn.Module:
                return self

            def get_example_inputs(self):
                return (torch.randn(1, 3, 10), torch.randn(1, 10, 3))

            def get_dynamic_shapes(self):
                dim1_x = Dim("Dot_dim1_x", min=2, max=100)
                dim2_x = Dim("Dot_dim2_x", min=2, max=100)
                return {"x": {1: dim1_x, 2: dim2_x}, "y": {1: dim2_x, 2: dim1_x}}

        model = Mul()
        ep = torch.export.export(
            model, model.get_example_inputs(), dynamic_shapes=model.get_dynamic_shapes()
        )

        to_edge(
            ep,
            compile_config=EdgeCompileConfig(
                _check_ir_validity=True,
            ),
        )

    def test_transform_dict_api(self):
        """A dict passed to ``transform`` applies passes only to the named methods."""
        edge_manager = to_edge(get_exported_programs(), get_config_methods())

        transformed_edge = edge_manager.transform(
            {
                "forward": [
                    AddToMulPassEdge(),
                ]
            }
        )

        self.assertEqual(
            transformed_edge.exported_program("forward").module()(
                torch.ones(1), torch.ones(1)
            ),
            torch.ones(1),  # x * y * x
        )

        self.assertEqual(
            transformed_edge.exported_program("foo").module()(
                torch.ones(1),
            ),
            torch.ones(1) + 1,  # x + 1
        )

    def test_edge_to_backend_replaces_subgraph(self):
        """A single partitioner lowers add/mul in every method into a delegate."""
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            AddMulPartitionerDemo()
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        foo_program = delegate_manager.exported_program("foo")
        add_nodes = [
            node
            for node in foo_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        lowered_submods = get_lowered_submodules(foo_program.graph_module)
        self.assertEqual(len(lowered_submods), 1)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob each for forward and foo. Emit once and look plans
        # up by name rather than by position.
        plans_by_name = {
            plan.name: plan
            for plan in delegate_manager.to_executorch(
                ExecutorchBackendConfig()
            )._emitter_output.program.execution_plan
        }
        self.assertEqual(len(plans_by_name["forward"].delegates), 1)
        self.assertEqual(len(plans_by_name["foo"].delegates), 1)

    def test_edge_to_backend_selective(self):
        """A per-method partitioner dict lowers only the named method."""
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        delegate_manager: EdgeProgramManager = edge_manager.to_backend(
            {"forward": AddMulPartitionerDemo()}
        )

        forward_program = delegate_manager.exported_program("forward")
        self.assertEqual(
            forward_program.module()(torch.ones(1), torch.ones(1)),
            torch.ones(1) + 1,  # x * y + x
        )

        add_nodes = [
            node
            for node in forward_program.graph_module.graph.nodes
            if node.op == "call_function"
            and node.target == exir_ops.edge.aten.add.Tensor
        ]
        self.assertEqual(len(add_nodes), 0)

        # foo unchanged
        lowered_submods = get_lowered_submodules(
            delegate_manager.exported_program("foo").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # original unchanged
        lowered_submods = get_lowered_submodules(
            edge_manager.exported_program("forward").graph_module
        )
        self.assertEqual(len(lowered_submods), 0)

        # one delegate blob for forward only. Emit once and look plans up by
        # name rather than by position.
        plans_by_name = {
            plan.name: plan
            for plan in delegate_manager.to_executorch(
                ExecutorchBackendConfig(
                    extract_delegate_segments=False,
                )
            )._emitter_output.program.execution_plan
        }
        self.assertEqual(len(plans_by_name["foo"].delegates), 0)
        self.assertEqual(len(plans_by_name["forward"].delegates), 1)

    def test_edge_manager_dialect(self):
        """Programs produced by ``to_edge`` report the ``EDGE`` dialect."""
        edge_manager: EdgeProgramManager = to_edge(
            get_exported_programs(), get_config_methods()
        )
        self.assertEqual(edge_manager.exported_program().dialect, "EDGE")

    def _test_edge_dialect_verifier(
        self, callable, validate_ir=True, exception_list=None
    ):
        """Export ``callable`` and run ``to_edge`` on the result.

        The export inputs are (float32 ones(1), int32 ones(1)). A callable
        that is not an ``nn.Module`` is wrapped in ``WrapperModule`` first.
        Verifier errors propagate to the caller.
        """
        edge_compile_config = EdgeCompileConfig(
            _check_ir_validity=validate_ir,
            _core_aten_ops_exception_list=exception_list,
        )
        inputs = (
            torch.ones(1, dtype=torch.float),
            torch.ones(1, dtype=torch.int32),
        )
        module = (
            callable
            if isinstance(callable, torch.nn.Module)
            else WrapperModule(callable)
        )

        exported_foo = export(module, inputs)
        _ = to_edge(exported_foo, compile_config=edge_compile_config)

    def test_edge_dialect_custom_op(self):
        """Custom (non-aten) ops do not trip the edge-dialect verifier."""

        # We shouldn't error out if there's a custom op in the graph.
        def _use_foo_add(a: torch.Tensor, b: torch.Tensor):
            return torch.ops.exir_program_test_op.foo(a, b)

        from torch._export.verifier import SpecViolationError

        try:
            # This should not raise error
            self._test_edge_dialect_verifier(_use_foo_add)
            self._test_edge_dialect_verifier(_use_foo_add, False)
        except SpecViolationError:
            self.fail("Should not error out on custom op")

    def get_num_nondecomposed_ops(self, ep, partitioner):
        """Count the aten ops that ``partitioner`` asks to keep and would delegate.

        This runs ``run_decompositions()`` on a deep copy of ``ep``, with the
        partitioner's preserved ops removed from the default decomposition
        table. It then counts the call_function nodes whose target is one of
        those preserved ops and that pass the partitioner's ``filter_ops``
        predicate, if it supplies one.
        """
        reference_ep = copy.deepcopy(ep)
        aten_ops_not_decomposed, filter_ops = partitioner.ops_to_not_decompose(ep)
        table = _default_decomposition_table()
        for op in aten_ops_not_decomposed:
            table.pop(op, None)
        reference_decomp_ep = reference_ep.run_decompositions(decomp_table=table)
        num_non_decomposed_aten_ops = 0
        for node in reference_decomp_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target in aten_ops_not_decomposed
                and (filter_ops(node) if filter_ops else True)
            ):
                num_non_decomposed_aten_ops += 1
        return num_non_decomposed_aten_ops

    def _test_model_with_non_decomp_partitioner(self, model: torch.nn.Module):
        """Check that ``to_edge_transform_and_lower`` delegates every preserved op.

        Each preserved op should become its own delegate call, and no
        supported non-decomposed edge op may remain outside a delegate.
        """
        # This is the pre-dispatch export that we will be switching to primarily
        # in the near future. The input to to_edge_transform_and_lower needs to
        # be a graph generated by this pre dispatch export.
        # pyre-fixme[29]: `Union[torch._tensor.Tensor,
        #  torch.nn.modules.module.Module]` is not a function.
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        non_decomp_partitioner = NonDecompTestPartitioner()

        num_non_decomposed_aten_ops = self.get_num_nondecomposed_ops(
            ep, non_decomp_partitioner
        )

        # run to_edge_transform_and_lower
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )
        # Check that non_decomposed_edge_ops are all consumed by the delegate
        non_decomposed_edge_ops = (
            non_decomp_partitioner.supported_non_decomposed_edge_ops
        )
        for node in edge.exported_program().graph.nodes:
            if node.op == "call_function":
                self.assertNotIn(node.target, non_decomposed_edge_ops)

        # check that the number of call_delegate_nodes is equal to the number of
        # non_decomposed_aten_ops we found above
        num_call_delegates = 0
        for node in edge.exported_program().graph_module.graph.nodes:
            # Count the executorch_call_delegate nodes; each preserved op is
            # expected to be lowered into its own delegate.
            if (
                node.op == "call_function"
                and node.target == torch.ops.higher_order.executorch_call_delegate
            ):
                num_call_delegates += 1

        self.assertEqual(num_call_delegates, num_non_decomposed_aten_ops)

    def test_to_edge_transform_and_lower(self):
        """Run the non-decomp partitioner check over every sample model."""
        self._test_model_with_non_decomp_partitioner(TestLinear())

        self._test_model_with_non_decomp_partitioner(TestSDPA())

        self._test_model_with_non_decomp_partitioner(TestLinearSDPACombined())

        self._test_model_with_non_decomp_partitioner(TestUpsample())

        self._test_model_with_non_decomp_partitioner(TestLSTM())

    def test_to_edge_transform_and_lower_with_exception(self):
        """Ops the partitioner rejects are decomposed instead of delegated."""

        class LinearWithAndWithoutBias(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(32, 16, bias=True)
                self.linear_no_bias = torch.nn.Linear(32, 16, bias=False)

            def forward(self, x):
                return (self.linear(x), self.linear_no_bias(x))

            @classmethod
            def _get_random_inputs(cls):
                x = torch.rand(8, 32)
                return (x,)

        model = LinearWithAndWithoutBias()
        ep = _export(model, model._get_random_inputs(), pre_dispatch=True)
        edge = to_edge_transform_and_lower(
            ep,
            compile_config=EdgeCompileConfig(),
            partitioner=[NonDecompTestPartitioner()],
        )

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target == target:
                    count += 1
            return count

        # Expect 1 call_delegate node, plus 1 aten.mm.default node from the
        # bias-free linear. That linear was decomposed because the partitioner
        # reported it as unsupported.
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module,
                torch.ops.higher_order.executorch_call_delegate,
            ),
            1,
        )
        self.assertEqual(
            count_nodes(
                edge.exported_program().graph_module, exir_ops.edge.aten.mm.default
            ),
            1,
        )

    def test_edge_dialect_non_core_aten_ops(self):
        """Non-core aten ops fail IR validation unless listed as exceptions."""

        class LinalgNorm(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return torch.linalg.norm(x)

        from torch._export.verifier import SpecViolationError

        x = torch.arange(9, dtype=torch.float) - 4
        ep = torch.export.export(LinalgNorm(), (x,))

        # aten::linalg_norm is not a core op, so it should error out
        with self.assertRaises(SpecViolationError):
            _ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True))

        # with exception list, it should not error out
        try:
            # This should not raise error
            _ = to_edge(
                ep,
                compile_config=EdgeCompileConfig(
                    _check_ir_validity=True,
                    _core_aten_ops_exception_list=[
                        torch.ops.aten.linalg_vector_norm.default
                    ],
                ),
            )
        except SpecViolationError:
            self.fail("Should not error out on linalg_vector_norm op")

    def _test_to_edge_with_preserved_ops(
        self, program, preserved_ops, expected_preserved_ops
    ):
        """Check that each preserved aten op has a matching edge op after lowering.

        The count of ``preserved_ops`` nodes in ``program`` must equal the
        count of ``expected_preserved_ops`` nodes in the lowered edge program.
        """
        edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)

        def count_nodes(graph_module, target):
            count = 0
            for node in graph_module.graph.nodes:
                if node.op == "call_function" and node.target in target:
                    count += 1
            return count

        aten_ops_non_decomposed = count_nodes(
            program.graph_module,
            preserved_ops,
        )

        edge_ops_non_decomposed = count_nodes(
            edge.exported_program().graph_module,
            expected_preserved_ops,
        )

        self.assertEqual(aten_ops_non_decomposed, edge_ops_non_decomposed)

    def test_to_edge_with_single_preserved_op(self):
        """Preserving ``aten.linear`` keeps it as ``edge.aten.linear``."""
        model = TestLinear()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_partial_ops_preserved(self):
        """Preserving only ``linear`` in a linear+SDPA model keeps just ``linear``."""
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_multiple_ops_preserved(self):
        """Preserving both ``linear`` and SDPA keeps both as edge ops."""
        model = TestLinearSDPACombined()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
            torch.ops.aten.scaled_dot_product_attention.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
            exir_ops.edge.aten.scaled_dot_product_attention.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )

    def test_to_edge_with_preserved_ops_not_in_model(self):
        """Preserving an op the model never uses is harmless (both counts are 0)."""
        model = TestSDPA()
        program = torch.export.export(model, model._get_random_inputs())

        ops_not_to_decompose = [
            torch.ops.aten.linear.default,
        ]
        expected_non_decomposed_edge_ops = [
            exir_ops.edge.aten.linear.default,
        ]

        self._test_to_edge_with_preserved_ops(
            program, ops_not_to_decompose, expected_non_decomposed_edge_ops
        )
806