# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""Traces and exports delegated nn.Modules to ExecuTorch .pte program files.

Creates two versions of each file:
- <module-name>.pte: Delegate data stored in segments outside of the flatbuffer data.
- <module-name>-nosegments.pte: Delegate data is stored directly in the flatbuffer data.

Each of those is additionally written with a large delegate alignment, using
an extra "-da<alignment>" suffix (e.g., <module-name>-da1024.pte).

This tool mainly exists to export programs for C++ tests, but can also
be used to export models manually.
"""

import argparse
import inspect
import os
import sys
from typing import Dict, final, Optional, Sequence, Type

import executorch.exir as exir

import torch
from executorch.exir import to_edge
from executorch.exir.backend.backend_api import to_backend
from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult
from executorch.exir.backend.test.backend_with_compiler_demo import (
    BackendWithCompilerDemo,
)
from torch import nn
from torch.export import export

#
# Modules
#


class ModuleAddMul(nn.Module):
    """Small test module: matrix-multiplies two inputs, then adds a third."""

    def __init__(self):
        super().__init__()

    def forward(
        self, a: torch.Tensor, x: torch.Tensor, b: torch.Tensor
    ) -> torch.Tensor:
        """Return `torch.mm(a, x) + b`."""
        y: torch.Tensor = torch.mm(a, x)
        z: torch.Tensor = torch.add(y, b)
        return z

    def get_random_inputs(self) -> Sequence[torch.Tensor]:
        """Return example inputs for `forward()`.

        Despite the name, the values are fixed: three 2x2 tensors filled
        with 1, 2, and 3 respectively.
        """
        return (torch.ones(2, 2), 2 * torch.ones(2, 2), 3 * torch.ones(2, 2))


#
# Backends
#


@final
class StubBackend(BackendDetails):
    """No-op backend to test serialization/init."""

    @staticmethod
    def preprocess(*args, **kwargs) -> PreprocessResult:
        """Ignore all arguments and return a fixed placeholder payload."""
        return PreprocessResult(processed_bytes=b"StubBackend:data")


#
# Program logic
#


def export_module_to_program(
    module_class: Type[nn.Module],
    *,
    backend_id: str,
    extract_delegate_segments: bool,
    constant_tensor_alignemnt: Optional[int] = None,
    delegate_alignment: Optional[int] = None,
    method: str = "forward",
) -> bytes:
    """Lower one method of a module to a backend and serialize the program.

    Args:
        module_class: nn.Module subclass to export. It is instantiated with
            no arguments and put in eval mode. If the instance has a
            `get_random_inputs()` method, its result is used as the example
            inputs; otherwise an empty tuple is used.
        backend_id: Backend identifier passed to `to_backend()`.
        extract_delegate_segments: Forwarded to `ExecutorchBackendConfig`.
        constant_tensor_alignemnt: Forwarded to `ExecutorchBackendConfig` as
            `constant_tensor_alignment`. NOTE: the parameter name is
            misspelled; it is kept as-is so existing keyword callers
            continue to work.
        delegate_alignment: Forwarded to `ExecutorchBackendConfig`.
        method: Name of the method on the module instance to export.

    Returns:
        The serialized program data (`.buffer` of the ExecuTorch program).
    """
    eager_module = module_class().eval()
    inputs = ()
    if hasattr(eager_module, "get_random_inputs"):
        # pyre-fixme[29]: `Union[nn.modules.module.Module, torch._tensor.Tensor]` is
        #  not a function.
        inputs = eager_module.get_random_inputs()

    class WrapperModule(torch.nn.Module):
        """Exposes an arbitrary callable as this module's `forward()`."""

        def __init__(self, fn):
            super().__init__()
            self.fn = fn

        def forward(self, *args, **kwargs):
            return self.fn(*args, **kwargs)

    # Wrap the bound method so that methods other than `forward` can be
    # exported through the same path.
    edge: exir.EdgeProgramManager = to_edge(
        export(WrapperModule(getattr(eager_module, method)), args=inputs)
    )

    lowered_module = to_backend(backend_id, edge.exported_program(), compile_specs=[])

    class CompositeModule(nn.Module):
        """Wraps the lowered module so it can be exported again."""

        def __init__(self):
            super().__init__()
            self.lowered_module = lowered_module

        def forward(self, *args, **kwargs):
            return self.lowered_module(*args, **kwargs)

    composite_module = CompositeModule()
    # Run once eagerly with the example inputs; the result is discarded.
    # NOTE(review): presumably a sanity check before re-exporting -- confirm.
    composite_module(*inputs)

    executorch_program = to_edge(export(composite_module, args=inputs)).to_executorch(
        config=exir.ExecutorchBackendConfig(
            extract_delegate_segments=extract_delegate_segments,
            constant_tensor_alignment=constant_tensor_alignemnt,
            delegate_alignment=delegate_alignment,
        )
    )

    return executorch_program.buffer


def main() -> None:
    """Parse the command line and write the requested .pte files.

    For every requested module class, writes one file per combination of
    segment extraction (on/off) and delegate alignment (default/1024) into
    `--outdir`.

    Raises:
        NameError: If a requested name is not an nn.Module class defined in
            this module.
    """
    known_backend_ids = [
        BackendWithCompilerDemo.__name__,
        StubBackend.__name__,
    ]

    # These args are optimized for genrule usage. There's a lot of startup
    # overhead for this tool, so it's faster to export multiple models at once
    # when possible.
    parser = argparse.ArgumentParser(
        prog="export_delegated_program",
        description="Exports delegated nn.Module models to ExecuTorch .pte files",
    )
    parser.add_argument(
        "--modules",
        # Without this flag there is nothing to export; make argparse report
        # a usage error instead of failing later while iterating over None.
        required=True,
        help="Comma-separated list of model class names to export; "
        + "e.g., '--modules=ModuleOne,ModuleTwo'",
        type=lambda s: [item.strip() for item in s.split(",")],
    )
    parser.add_argument(
        "--backend_id",
        type=str,
        default=StubBackend.__name__,
        help="ID of the backend to use for delegation; "
        + f"one of {known_backend_ids}",
    )
    parser.add_argument(
        "--outdir",
        type=str,
        required=True,
        help="Path to the directory to write <classname>[-<suffix>[...]].pte "
        + "files to.",
    )
    args = parser.parse_args()

    # Find the classes to export. Only looks in this module for now, but could
    # be extended to look in other modules if helpful.
    module_names_to_classes: Dict[str, Type[nn.Module]] = {}
    for module in args.modules:
        module_class = getattr(sys.modules[__name__], module, None)
        if not (inspect.isclass(module_class) and issubclass(module_class, nn.Module)):
            raise NameError(f"Could not find nn.Module class named '{module}'")
        module_names_to_classes[module] = module_class

    # Export and write to the output files.
    os.makedirs(args.outdir, exist_ok=True)
    for module_name, module_class in module_names_to_classes.items():
        for extract_delegate_segments in (True, False):
            segments_suffix = "" if extract_delegate_segments else "-nosegments"
            # Create files with the default alignment, and a large alignment.
            # This alignment should be so large that it's extremely unlikely for
            # the data to accidentally be aligned to it in the default case.
            for delegate_alignment in (None, 1024):
                # Build the suffix from scratch on every iteration so that
                # alignment tags never accumulate across iterations.
                alignment_suffix = (
                    f"-da{delegate_alignment}" if delegate_alignment else ""
                )
                outfile = os.path.join(
                    args.outdir,
                    f"{module_name}{segments_suffix}{alignment_suffix}.pte",
                )
                # Export before opening the file so that a failed export does
                # not leave an empty or truncated .pte file behind.
                program_data = export_module_to_program(
                    module_class,
                    backend_id=args.backend_id,
                    extract_delegate_segments=extract_delegate_segments,
                    delegate_alignment=delegate_alignment,
                )
                with open(outfile, "wb") as fp:
                    fp.write(program_data)
                print(f"Exported {module_name} and wrote program data to {outfile}")


if __name__ == "__main__":
    main()