xref: /aosp_15_r20/external/executorch/examples/apple/coreml/scripts/export.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright © 2023 Apple Inc. All rights reserved.
2#
3# Please refer to the license found in the LICENSE file in the root directory of the source tree.
4
5import argparse
6import copy
7
8import pathlib
9import sys
10
11import coremltools as ct
12
13import executorch.exir as exir
14
15import torch
16
17# pyre-fixme[21]: Could not find module `executorch.backends.apple.coreml.compiler`.
18from executorch.backends.apple.coreml.compiler import CoreMLBackend
19
20# pyre-fixme[21]: Could not find module `executorch.backends.apple.coreml.partition`.
21from executorch.backends.apple.coreml.partition import CoreMLPartitioner
22from executorch.devtools.etrecord import generate_etrecord
23from executorch.exir import to_edge
24
25from executorch.exir.backend.backend_api import to_backend
26
27from torch.export import export
28
29REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent.parent.parent.parent
30EXAMPLES_DIR = REPO_ROOT / "examples"
31sys.path.append(str(EXAMPLES_DIR.absolute()))
32
33from executorch.examples.models import MODEL_NAME_TO_MODEL
34from executorch.examples.models.model_factory import EagerModelFactory
35
36# Script to export a model with coreml delegation.
37
# Edge compile config shared by both lowering paths (partitioner and manual to_backend).
_EDGE_COMPILE_CONFIG = exir.EdgeCompileConfig(
    _check_ir_validity=False,  # delegated graphs are not required to pass core-ATen IR validation
    _skip_dim_order=True,  # TODO(T182928844): enable dim_order in backend
)
42
43
def parse_args() -> argparse.Namespace:
    """Parse the command-line arguments for the CoreML export script.

    Returns:
        argparse.Namespace with the attributes: model_name, compute_unit,
        compute_precision, compile, use_partitioner, generate_etrecord,
        save_processed_bytes.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "-m",
        "--model_name",
        required=True,
        help=f"Provide model name. Valid ones: {list(MODEL_NAME_TO_MODEL.keys())}",
    )

    parser.add_argument(
        "-c",
        "--compute_unit",
        required=False,
        default=ct.ComputeUnit.ALL.name.lower(),
        help=f"Provide compute unit for the model. Valid ones: {[[compute_unit.name.lower() for compute_unit in ct.ComputeUnit]]}",
    )

    parser.add_argument(
        "-precision",
        "--compute_precision",
        required=False,
        default=ct.precision.FLOAT16.value,
        help=f"Provide compute precision for the model. Valid ones: {[[precision.value for precision in ct.precision]]}",
    )

    parser.add_argument(
        "--compile",
        action=argparse.BooleanOptionalAction,
        required=False,
        default=False,
    )
    parser.add_argument("--use_partitioner", action=argparse.BooleanOptionalAction)
    parser.add_argument("--generate_etrecord", action=argparse.BooleanOptionalAction)
    parser.add_argument("--save_processed_bytes", action=argparse.BooleanOptionalAction)

    # Fixed: the function returns a Namespace, not an ArgumentParser; the old
    # annotation was wrong and needed a pyre-fixme suppression.
    return parser.parse_args()
83
84
def partition_module_to_coreml(module):
    # NOTE(review): this function only switches the module to eval mode and
    # discards the result — it performs no partitioning and returns None.
    # It appears vestigial (main() uses CoreMLPartitioner directly); confirm
    # with callers before removing or completing it.
    module = module.eval()
87
88
def lower_module_to_coreml(module, compile_specs, example_inputs):
    """Lower an eager module to a CoreML-delegated module.

    Args:
        module: eager torch module to lower (switched to eval mode here).
        compile_specs: CoreML backend compile specs.
        example_inputs: example inputs used to export the module.

    Returns:
        (lowered_module, edge_copy) — the CoreML-delegated module and a deep
        copy of the pre-delegation edge program kept for ETRecord debugging.
    """
    module = module.eval()
    edge_manager = to_edge(
        export(module, example_inputs), compile_config=_EDGE_COMPILE_CONFIG
    )
    # Delegation (and later to_executorch) mutate the edge program in place,
    # so snapshot the original edge dialect graph before lowering; the copy
    # is later serialized into an etrecord.
    pre_delegation_copy = copy.deepcopy(edge_manager)

    delegated_module = to_backend(
        CoreMLBackend.__name__,
        edge_manager.exported_program(),
        compile_specs,
    )

    return delegated_module, pre_delegation_copy
105
106
def export_lowered_module_to_executorch_program(lowered_module, example_inputs):
    """Wrap an already-lowered CoreML module into an ExecuTorch program.

    The lowered module is invoked once on the example inputs before being
    re-exported, then converted with delegate segments extracted.
    """
    lowered_module(*example_inputs)
    edge_manager = to_edge(
        export(lowered_module, example_inputs),
        compile_config=_EDGE_COMPILE_CONFIG,
    )
    backend_config = exir.ExecutorchBackendConfig(extract_delegate_segments=True)
    return edge_manager.to_executorch(config=backend_config)
114
115
def save_executorch_program(exec_prog, model_name, compute_unit):
    """Write the serialized ExecuTorch program to ``<model>_coreml_<unit>.pte``.

    Args:
        exec_prog: object exposing a ``buffer`` attribute holding the
            serialized program bytes (an ExecutorchProgramManager).
        model_name: model name used to build the output filename.
        compute_unit: compute-unit name used to build the output filename.
    """
    buffer = exec_prog.buffer
    filename = f"{model_name}_coreml_{compute_unit}.pte"
    # Fixed: the log line previously printed a literal placeholder instead of
    # the actual destination path.
    print(f"Saving exported program to {filename}")
    with open(filename, "wb") as file:
        file.write(buffer)
123
124
def save_processed_bytes(processed_bytes, model_name, compute_unit):
    """Write the raw CoreML delegate payload to ``<model>_coreml_<unit>.bin``.

    Args:
        processed_bytes: bytes produced by the CoreML backend for the
            lowered module.
        model_name: model name used to build the output filename.
        compute_unit: compute-unit name used to build the output filename.
    """
    filename = f"{model_name}_coreml_{compute_unit}.bin"
    # Fixed: the log line previously printed a literal placeholder instead of
    # the actual destination path.
    print(f"Saving processed bytes to {filename}")
    with open(filename, "wb") as file:
        file.write(processed_bytes)
131
132
def generate_compile_specs_from_args(args):
    """Translate parsed CLI arguments into CoreML backend compile specs."""
    # --compile asks the backend to embed a precompiled model rather than a
    # plain one.
    if args.compile:
        model_type = CoreMLBackend.MODEL_TYPE.COMPILED_MODEL
    else:
        model_type = CoreMLBackend.MODEL_TYPE.MODEL

    return CoreMLBackend.generate_compile_specs(
        compute_precision=ct.precision(args.compute_precision),
        compute_unit=ct.ComputeUnit[args.compute_unit.upper()],
        model_type=model_type,
    )
146
147
def main():
    """Export the requested model with CoreML delegation and save artifacts."""
    args = parse_args()

    if args.model_name not in MODEL_NAME_TO_MODEL:
        raise RuntimeError(
            f"Model {args.model_name} is not a valid name. "
            f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}."
        )

    valid_compute_units = [compute_unit.name.lower() for compute_unit in ct.ComputeUnit]
    if args.compute_unit not in valid_compute_units:
        raise RuntimeError(
            f"{args.compute_unit} is invalid. "
            f"Valid compute units are {valid_compute_units}."
        )

    model, example_inputs, _, _ = EagerModelFactory.create_model(
        *MODEL_NAME_TO_MODEL[args.model_name]
    )

    compile_specs = generate_compile_specs_from_args(args)
    lowered_module = None

    if args.use_partitioner:
        model.eval()
        exir_program_aten = torch.export.export(model, example_inputs)

        edge_program_manager = exir.to_edge(exir_program_aten)
        # Keep a pre-delegation copy of the edge program for the etrecord.
        edge_copy = copy.deepcopy(edge_program_manager)
        partitioner = CoreMLPartitioner(
            skip_ops_for_coreml_delegation=None, compile_specs=compile_specs
        )
        delegated_program_manager = edge_program_manager.to_backend(partitioner)
        exec_program = delegated_program_manager.to_executorch(
            config=exir.ExecutorchBackendConfig(extract_delegate_segments=True)
        )
    else:
        lowered_module, edge_copy = lower_module_to_coreml(
            module=model,
            example_inputs=example_inputs,
            compile_specs=compile_specs,
        )
        exec_program = export_lowered_module_to_executorch_program(
            lowered_module,
            example_inputs,
        )

    model_name = f"{args.model_name}_compiled" if args.compile else args.model_name
    save_executorch_program(exec_program, model_name, args.compute_unit)

    # Fixed: the --generate_etrecord flag was parsed but never consulted, so
    # the etrecord was written unconditionally. Honor the flag.
    if args.generate_etrecord:
        generate_etrecord(
            f"{args.model_name}_coreml_etrecord.bin", edge_copy, exec_program
        )

    if args.save_processed_bytes and lowered_module is not None:
        save_processed_bytes(
            lowered_module.processed_bytes, args.model_name, args.compute_unit
        )
203
204
205if __name__ == "__main__":
206    main()
207