xref: /aosp_15_r20/external/executorch/test/models/export_delegated_program.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import inspect
9import os
10import sys
11from typing import Dict, final, Optional, Sequence, Type
12
13import executorch.exir as exir
14
15import torch
16from executorch.exir import to_edge
17from executorch.exir.backend.backend_api import to_backend
18from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
19from executorch.exir.backend.test.backend_with_compiler_demo import (
20    BackendWithCompilerDemo,
21)
22from torch import nn
23from torch.export import export
24
25"""Traces and exports delegated nn.Modules to ExecuTorch .pte program files.
26
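Creates four versions of each file:
- <module-name>.pte: Delegate data stored in segments outside of the flatbuffer data.
- <module-name>-nosegments.pte: Delegate data is stored directly in the flatbuffer data.
- <module-name>-da1024.pte and <module-name>-nosegments-da1024.pte: Same as the
  two files above, but exported with a delegate alignment of 1024.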
30
31This tool mainly exists to export programs for C++ tests, but can also
32be used to export models manually.
33"""
34
35#
36# Modules
37#
38
39
class ModuleAddMul(nn.Module):
    """Small test module that computes ``torch.mm(a, x) + b``."""

    def __init__(self):
        super().__init__()

    def forward(
        self, a: torch.Tensor, x: torch.Tensor, b: torch.Tensor
    ) -> torch.Tensor:
        """Return the matrix product of `a` and `x` with `b` added to it."""
        product: torch.Tensor = torch.mm(a, x)
        return torch.add(product, b)

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        """Return example `(a, x, b)` inputs for `forward`.

        Despite the name, the values are fixed: 2x2 tensors filled with
        1, 2, and 3 respectively.
        """
        base = torch.ones(2, 2)
        return (base, 2 * base, 3 * base)
53
54
55#
56# Backends
57#
58
59
@final
class StubBackend(BackendDetails):
    """No-op backend to test serialization/init."""

    @staticmethod
    def preprocess(*args, **kwargs) -> PreprocessResult:
        """Ignore all arguments and return a fixed, recognizable payload."""
        stub_payload = b"StubBackend:data"
        return PreprocessResult(processed_bytes=stub_payload)
67
68
69#
70# Program logic
71#
72
73
def export_module_to_program(
    module_class: Type[nn.Module],
    *,
    backend_id: str,
    extract_delegate_segments: bool,
    constant_tensor_alignemnt: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
    method: str = "forward",
) -> bytes:
    """Export `module_class` as a program delegated to `backend_id`.

    Args:
        module_class: The nn.Module subclass to instantiate (with no
            arguments) and export. If the instance has a
            `get_random_inputs()` method, its return value is used as the
            example inputs; otherwise no inputs are passed.
        backend_id: Backend identifier handed to `to_backend()`.
        extract_delegate_segments: Forwarded to `ExecutorchBackendConfig`.
        constant_tensor_alignemnt: Forwarded to `ExecutorchBackendConfig` as
            `constant_tensor_alignment`. (The misspelling is part of this
            function's keyword interface.)
        delegate_alignment: Forwarded to `ExecutorchBackendConfig`.
        method: Name of the method on the module instance to export.

    Returns:
        The `buffer` of the resulting ExecuTorch program.
    """

    class WrapperModule(torch.nn.Module):
        # Adapts an arbitrary bound method so it can be exported via forward().
        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    eager_module = module_class().eval()
    if hasattr(eager_module, "get_random_inputs"):
        # pyre-fixme[29]: `Union[nn.modules.module.Module, torch._tensor.Tensor]` is
        #  not a function.
        inputs = eager_module.get_random_inputs()
    else:
        inputs = ()

    # Export the requested method and lower the entire program to the backend.
    method_wrapper = WrapperModule(getattr(eager_module, method))
    edge: exir.EdgeProgramManager = to_edge(export(method_wrapper, args=inputs))
    delegate = to_backend(backend_id, edge.exported_program(), compile_specs=[])

    class CompositeModule(nn.Module):
        # Holds the lowered module as a submodule and forwards every call to it.
        def __init__(self):
            super().__init__()
            self.lowered_module = delegate

        def forward(self, *args, **kwargs):
            return self.lowered_module(*args, **kwargs)

    composite_module = CompositeModule()
    # Call the composite once with the example inputs before exporting it;
    # the result is discarded.
    composite_module(*inputs)

    backend_config = exir.ExecutorchBackendConfig(
        extract_delegate_segments=extract_delegate_segments,
        constant_tensor_alignment=constant_tensor_alignemnt,
        delegate_alignment=delegate_alignment,
    )
    composite_edge = to_edge(export(composite_module, args=inputs))
    executorch_program = composite_edge.to_executorch(config=backend_config)

    return executorch_program.buffer
124
125
def main() -> None:
    """Parse command-line arguments and export the requested modules.

    For every module named in `--modules`, writes four .pte files into
    `--outdir`: with and without extracted delegate segments, each with the
    default delegate alignment and with a delegate alignment of 1024.

    Raises:
        NameError: If a requested name is not an nn.Module class defined in
            this module.
    """
    known_backend_ids = [
        BackendWithCompilerDemo.__name__,
        StubBackend.__name__,
    ]

    # These args are optimized for genrule usage. There's a lot of startup
    # overhead for this tool, so it's faster to export multiple models at once
    # when possible.
    parser = argparse.ArgumentParser(
        prog="export_delegated_program",
        description="Exports delegated nn.Module models to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--modules",
        help="Comma-separated list of model class names to export; "
        + "e.g., '--modules=ModuleOne,ModuleTwo'",
        type=lambda s: [item.strip() for item in s.split(",")],
        # Without this, omitting the flag leaves args.modules as None and the
        # lookup loop below fails with an unhelpful TypeError.
        required=True,
    )
    parser.add_argument(
        "--backend_id",
        type=str,
        default=StubBackend.__name__,
        help="ID of the backend to use for delegation; "
        + f"one of {known_backend_ids}",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write <classname>[-<suffix>[...]].pte "
        + "files to.",
    )
    args = parser.parse_args()

    # Find the classes to export. Only looks in this module for now, but could
    # be extended to look in other modules if helpful.
    module_names_to_classes: Dict[str, Type[nn.Module]] = {}
    for module in args.modules:
        module_class = getattr(sys.modules[__name__], module, None)
        if not (inspect.isclass(module_class) and issubclass(module_class, nn.Module)):
            raise NameError(f"Could not find nn.Module class named '{module}'")
        module_names_to_classes[module] = module_class

    # Export and write to the output files.
    os.makedirs(args.outdir, exist_ok=True)
    for module_name, module_class in module_names_to_classes.items():
        for extract_delegate_segments in (True, False):
            segments_suffix = "" if extract_delegate_segments else "-nosegments"
            # Create files with the default alignment, and a large alignment.
            # This alignment should be so large that it's extremely unlikely for
            # the data to accidentally be aligned to it in the default case.
            for delegate_alignment in (None, 1024):
                # Build the suffix from scratch on every iteration so that
                # alignment tags never accumulate across iterations.
                alignment_suffix = (
                    f"-da{delegate_alignment}" if delegate_alignment else ""
                )
                suffix = segments_suffix + alignment_suffix
                outfile = os.path.join(args.outdir, f"{module_name}{suffix}.pte")
                # Export before opening the file so that a failed export does
                # not leave an empty or truncated .pte file behind.
                program_data = export_module_to_program(
                    module_class,
                    backend_id=args.backend_id,
                    extract_delegate_segments=extract_delegate_segments,
                    delegate_alignment=delegate_alignment,
                )
                with open(outfile, "wb") as fp:
                    fp.write(program_data)
                print(f"Exported {module_name} and wrote program data to {outfile}")
191
192
193if __name__ == "__main__":
194    main()
195