#  Copyright © 2023 Apple Inc. All rights reserved.

# CoreML backend for delegating an EdgeProgram to CoreML.

import json
import logging

import shutil
import uuid
from dataclasses import asdict, dataclass
from enum import Enum

from pathlib import Path

from typing import Any, Dict, final, List, Optional, Tuple

import coremltools as ct
import coremltools.optimize as cto
import executorchcoreml

from executorch.exir.backend.backend_details import (
    BackendDetails,
    ExportedProgram,
    PreprocessResult,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec

logger = logging.getLogger(__name__)
logger.setLevel(logging.WARNING)


class COMPILE_SPEC_KEYS(Enum):
    COMPUTE_UNITS = "compute_units"
    MODEL_TYPE = "model_type"
    MIN_DEPLOYMENT_TARGET = "min_deployment_target"
    MODEL_COMPUTE_PRECISION = "model_compute_precision"
    OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config"


class MODEL_PATHS(Enum):
    MODEL = "model.mlpackage"
    COMPILED_MODEL = "model.mlmodelc"
    METADATA = "metadata.json"
    DEBUG_INFO = "debug_info.json"


@dataclass
class ModelMetadata:
    # The model input names.
    inputNames: List[str]
    # The model output names.
    outputNames: List[str]
    # The model identifier.
    identifier: str


@dataclass
class ModelDebugInfo:
    # Version info.
    versionInfo: Dict[str, str]
    # Mapping from debug symbol to operation path.
    debugSymbolToOperationPath: Dict[str, List[Dict[str, str]]]
    # Mapping from debug symbol to handle.
    debugSymbolToHandles: Dict[str, List[int]]


@final
class CoreMLBackend(BackendDetails):
    class MODEL_TYPE(Enum):
        MODEL = "model"
        COMPILED_MODEL = "compiled_model"

    @staticmethod
    def generate_model_type_compile_spec(model_type: MODEL_TYPE) -> CompileSpec:
        """
        Returns the compile spec representing the given model type.

        If the model type is ``MODEL_TYPE.MODEL`` then the ``CoreMLBackend`` returns
        the in-memory representation of the ``mlpackage`` contents.

        If the model type is ``MODEL_TYPE.COMPILED_MODEL`` then the ``CoreMLBackend`` compiles the model
        and returns the in-memory representation of ``mlmodelc`` (compiled model) contents.
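
        Example (illustrative)::

            compile_spec = CoreMLBackend.generate_model_type_compile_spec(
                CoreMLBackend.MODEL_TYPE.COMPILED_MODEL
            )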
83        """
84        return CompileSpec(
85            COMPILE_SPEC_KEYS.MODEL_TYPE.value, model_type.value.encode("utf-8")
86        )
87
88    @staticmethod
89    def model_type_from_compile_specs(compile_specs: List[CompileSpec]) -> MODEL_TYPE:
90        """
91        Returns the model type by parsing the list of compile specs.
92        """
93        for compile_spec in compile_specs:
94            if compile_spec.key == COMPILE_SPEC_KEYS.MODEL_TYPE.value:
95                return CoreMLBackend.MODEL_TYPE(compile_spec.value.decode("utf-8"))
96
97        return CoreMLBackend.MODEL_TYPE.MODEL
98
99    @staticmethod
100    def generate_compute_precision_compile_spec(
101        compute_precision: ct.precision,
102    ) -> CompileSpec:
103        """
        Returns the compile spec representing the model compute precision. For additional
        details, refer to the documentation for ``coremltools.precision``.
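
        Example (illustrative)::

            compile_spec = CoreMLBackend.generate_compute_precision_compile_spec(
                ct.precision.FLOAT32
            )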
106        """
107        return CompileSpec(
108            COMPILE_SPEC_KEYS.MODEL_COMPUTE_PRECISION.value,
109            compute_precision.value.encode("utf-8"),
110        )
111
112    @staticmethod
113    def model_compute_precision_from_compile_specs(
114        compile_specs: List[CompileSpec],
115    ) -> ct.precision:
116        """
117        Returns the model's compute precision by parsing the list of compile specs.
118        """
119        for compile_spec in compile_specs:
120            if compile_spec.key == COMPILE_SPEC_KEYS.MODEL_COMPUTE_PRECISION.value:
121                return ct.precision(compile_spec.value.decode("utf-8"))
122
123        return ct.precision.FLOAT16
124
125    @staticmethod
126    def generate_minimum_deployment_target_compile_spec(
127        min_deployment_target: ct.target,
128    ) -> CompileSpec:
129        """
        Returns the compile spec representing the minimum deployment target on which the model
        can run. For additional details, refer to the documentation for ``coremltools.target``.
132        """
133        return CompileSpec(
134            COMPILE_SPEC_KEYS.MIN_DEPLOYMENT_TARGET.value,
135            str(min_deployment_target.value).encode("utf-8"),
136        )
137
138    @staticmethod
139    def min_deployment_target_from_compile_specs(
140        compile_specs: List[CompileSpec],
141    ) -> ct.target:
142        """
143        Returns the minimum deployment target by parsing the list of compile specs.
144        """
145        for compile_spec in compile_specs:
146            if compile_spec.key == COMPILE_SPEC_KEYS.MIN_DEPLOYMENT_TARGET.value:
147                compile_spec_value: int = int(compile_spec.value.decode("utf-8"))
148                return ct.target(compile_spec_value)
149
150        return ct.target.iOS15
151
152    @staticmethod
153    def compute_unit_from_compile_specs(
154        compile_specs: List[CompileSpec],
155    ) -> ct.ComputeUnit:
156        """
        Returns the compute units by parsing the list of compile specs.
158        """
159        for compile_spec in compile_specs:
160            if compile_spec.key == COMPILE_SPEC_KEYS.COMPUTE_UNITS.value:
161                return ct.ComputeUnit[compile_spec.value.decode("utf-8").upper()]
162
163        return ct.ComputeUnit.ALL
164
165    @staticmethod
166    def generate_compute_unit_compile_spec(
167        compute_unit: ct.ComputeUnit,
168    ) -> CompileSpec:
169        """
        Returns the compile spec representing the compute units on which the model can run.
        For additional details, refer to the documentation for ``coremltools.ComputeUnit``.
172        """
173        return CompileSpec(
174            COMPILE_SPEC_KEYS.COMPUTE_UNITS.value,
175            compute_unit.name.lower().encode("utf-8"),
176        )
177
178    @staticmethod
179    def generate_op_linear_quantizer_config_compile_spec(
180        op_linear_quantizer_config: Dict,
181    ) -> CompileSpec:
182        """
        Returns the compile spec representing the post-conversion quantization settings:
        a dict used to construct a ``cto.coreml.OpLinearQuantizerConfig``.
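
        Example (illustrative; the exact keys accepted are defined by
        ``cto.coreml.OpLinearQuantizerConfig``)::

            compile_spec = CoreMLBackend.generate_op_linear_quantizer_config_compile_spec(
                {"mode": "linear_symmetric", "dtype": "int8", "weight_threshold": 512}
            )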
185        """
186        str_representation = json.dumps(op_linear_quantizer_config)
187        byte_representation = str_representation.encode("utf-8")
188        return CompileSpec(
189            COMPILE_SPEC_KEYS.OP_LINEAR_QUANTIZER_CONFIG.value,
190            byte_representation,
191        )
192
193    @staticmethod
194    def op_linear_quantizer_config_from_compile_specs(
195        compile_specs: List[CompileSpec],
    ) -> Optional[cto.coreml.OpLinearQuantizerConfig]:
197        """
        Returns the post-conversion quantization config by parsing the list of compile specs,
        or ``None`` if no such spec is present.
199        """
200        for compile_spec in compile_specs:
201            if compile_spec.key == COMPILE_SPEC_KEYS.OP_LINEAR_QUANTIZER_CONFIG.value:
202                config_dict_str = compile_spec.value.decode("utf-8")
203                config_dict = json.loads(config_dict_str)
204                config = cto.coreml.OpLinearQuantizerConfig._from_dict(config_dict)
205                return config
206
207        return None
208
209    @staticmethod
210    def generate_compile_specs(
211        compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL,
212        minimum_deployment_target: ct.target = ct.target.iOS15,
213        compute_precision: ct.precision = ct.precision.FLOAT16,
214        model_type: MODEL_TYPE = MODEL_TYPE.MODEL,
215        op_linear_quantizer_config: Optional[Dict] = None,
216    ) -> List[CompileSpec]:
217        """
        Returns the list of compile specs that is used by ``CoreMLBackend`` to lower the module.
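
        Example (illustrative; the argument values are arbitrary choices, not
        requirements)::

            compile_specs = CoreMLBackend.generate_compile_specs(
                compute_unit=ct.ComputeUnit.CPU_AND_NE,
                minimum_deployment_target=ct.target.iOS17,
                compute_precision=ct.precision.FLOAT16,
                model_type=CoreMLBackend.MODEL_TYPE.MODEL,
            )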
219        """
220        compile_specs: List[CompileSpec] = []
221        compile_specs.append(
222            CoreMLBackend.generate_compute_unit_compile_spec(compute_unit)
223        )
224        compile_specs.append(
225            CoreMLBackend.generate_minimum_deployment_target_compile_spec(
226                minimum_deployment_target
227            )
228        )
229        compile_specs.append(
230            CoreMLBackend.generate_compute_precision_compile_spec(compute_precision)
231        )
232        compile_specs.append(CoreMLBackend.generate_model_type_compile_spec(model_type))
233        if op_linear_quantizer_config is not None:
234            compile_specs.append(
235                CoreMLBackend.generate_op_linear_quantizer_config_compile_spec(
236                    op_linear_quantizer_config
237                )
238            )
239
240        return compile_specs
241
242    @staticmethod
243    def model_metadata_from_spec(
244        model_spec: ct.proto.Model_pb2, identifier: str  # pyre-ignore
245    ) -> ModelMetadata:
246        input_names: List[str] = [input.name for input in model_spec.description.input]
247        output_names = [output.name for output in model_spec.description.output]
248
249        return ModelMetadata(
250            inputNames=input_names, outputNames=output_names, identifier=identifier
251        )
252
253    @staticmethod
254    def get_debug_symbol(operation_path: List[Dict[str, str]]) -> Optional[str]:
255        if len(operation_path) == 0:
256            return None
257
258        operator_name: Optional[str] = operation_path[-1].get("Operator", None)
259        output_name: Optional[str] = operation_path[-1].get("Output", None)
260        if output_name is None or operator_name is None:
261            return None
262
263        return output_name + ":" + operator_name
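
    # Illustrative example (hypothetical data): if the last entry of an operation
    # path is {"Operator": "conv", "Output": "var_1"}, get_debug_symbol returns
    # the debug symbol "var_1:conv".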

    @staticmethod
    def get_model_debug_info(model_package_dir: Path) -> Optional[ModelDebugInfo]:
        delegate_info_file = model_package_dir / "executorch_debug_handle_mapping.json"

        if not delegate_info_file.is_file():
            return None

        delegate_info: Optional[Dict[str, Any]] = None

        try:
            with open(delegate_info_file) as f:
                delegate_info = json.load(f)
        except ValueError:
            return None

        if delegate_info is None:
            return None

        debug_handle_to_operation_path_mapping: Optional[Dict[str, Any]] = (
            delegate_info.get("mapping", None)
        )

        if debug_handle_to_operation_path_mapping is None:
            return None

        debug_symbol_to_operation_path: Dict[str, List[Dict[str, str]]] = {}
        debug_symbol_to_handles: Dict[str, List[int]] = {}
        for (
            debug_handle,
            operation_paths,
        ) in debug_handle_to_operation_path_mapping.items():
            debug_handle_value: Optional[int] = None
            try:
                debug_handle_value = int(debug_handle)
            except ValueError:
                debug_handle_value = None

            if debug_handle_value is None:
                continue

            for operation_path in operation_paths:
                debug_symbol: Optional[str] = CoreMLBackend.get_debug_symbol(
                    operation_path=operation_path
                )

                if debug_symbol is None:
                    continue

                debug_handle_values: List[int] = debug_symbol_to_handles.get(
                    debug_symbol, []
                )
                debug_handle_values.append(debug_handle_value)
                debug_symbol_to_handles[debug_symbol] = debug_handle_values

                debug_symbol_to_operation_path[debug_symbol] = operation_path

        version_info: Dict[str, str] = delegate_info.get("version", {})

        return ModelDebugInfo(
            versionInfo=version_info,
            debugSymbolToOperationPath=debug_symbol_to_operation_path,
            debugSymbolToHandles=debug_symbol_to_handles,
        )
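
    # Illustrative shape of "executorch_debug_handle_mapping.json" as this parser
    # expects it (hypothetical values; only the "version" and "mapping" keys are
    # read, and each operation path is a list of dicts):
    #
    #     {
    #       "version": {"coremltools": "8.0"},
    #       "mapping": {
    #         "42": [[{"Operator": "conv", "Output": "var_1"}]]
    #       }
    #     }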

    @staticmethod
    def save_model_metadata(model_metadata: ModelMetadata, model_dir_path: Path):
        # Store model metadata.
        model_metadata_path = Path(model_dir_path) / MODEL_PATHS.METADATA.value
        model_metadata_json = json.dumps(asdict(model_metadata))
        with open(model_metadata_path, "w") as outfile:
            outfile.write(model_metadata_json)

    @staticmethod
    def save_model_debug_info(model_debug_info: ModelDebugInfo, model_dir_path: Path):
        # Store model debug info.
        model_debug_info_path = Path(model_dir_path) / MODEL_PATHS.DEBUG_INFO.value
        model_debug_info_json = json.dumps(asdict(model_debug_info))
        with open(model_debug_info_path, "w") as outfile:
            outfile.write(model_debug_info_json)

    @staticmethod
    def preprocess_model(
        mlmodel: ct.models.MLModel, model_type: MODEL_TYPE
    ) -> PreprocessResult:
        identifier = "executorch_" + str(uuid.uuid4())
        dir_path: Path = Path("tmp") / identifier
        model_dir_path: Path = dir_path / "lowered_module"
        model_spec: ct.proto.Model_pb2 = mlmodel.get_spec()
        model_metadata: ModelMetadata = CoreMLBackend.model_metadata_from_spec(
            model_spec=model_spec,
            identifier=identifier,
        )

        # Save model.
        model_path = model_dir_path / MODEL_PATHS.MODEL.value
        mlmodel.save(str(model_path))
        # Extract delegate mapping file.
        model_debug_info: Optional[ModelDebugInfo] = CoreMLBackend.get_model_debug_info(
            model_path
        )

        match model_type:
            case CoreMLBackend.MODEL_TYPE.COMPILED_MODEL:
                shutil.rmtree(str(model_path.resolve()))
                model_path = model_dir_path / MODEL_PATHS.COMPILED_MODEL.value
                compiled_model_path = mlmodel.get_compiled_model_path()
                shutil.move(
                    compiled_model_path,
                    str(model_path.resolve()),
                )

            case _:
                pass

        CoreMLBackend.save_model_metadata(
            model_metadata=model_metadata, model_dir_path=model_dir_path
        )
        if model_debug_info is not None:
            CoreMLBackend.save_model_debug_info(
                model_debug_info=model_debug_info, model_dir_path=model_dir_path
            )

        processed_bytes: bytes = (
            executorchcoreml.flatten_directory_contents(str(model_dir_path.resolve()))
            or b""
        )

        debug_handle_map: Optional[Dict[str, Tuple[int]]] = None
        if model_debug_info is not None:
            debug_handle_map = {
                key: tuple(value)
                for key, value in model_debug_info.debugSymbolToHandles.items()
            }

        shutil.rmtree(str(dir_path.resolve()))
        return PreprocessResult(
            processed_bytes=processed_bytes,
            debug_handle_map=debug_handle_map,
        )
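
    # Layout of the directory that is flattened into ``processed_bytes`` (derived
    # from MODEL_PATHS; debug_info.json is written only when a delegate mapping
    # file was produced during conversion):
    #
    #     tmp/executorch_<uuid>/lowered_module/
    #         model.mlpackage   (model.mlmodelc for MODEL_TYPE.COMPILED_MODEL)
    #         metadata.json
    #         debug_info.json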

    @staticmethod
    def preprocess(
        edge_program: ExportedProgram,
        compile_specs: List[CompileSpec],
    ) -> PreprocessResult:
        model_type: CoreMLBackend.MODEL_TYPE = (
            CoreMLBackend.model_type_from_compile_specs(
                compile_specs,
            )
        )
        model_compute_precision: ct.precision = (
            CoreMLBackend.model_compute_precision_from_compile_specs(compile_specs)
        )
        minimum_deployment_target: ct.target = (
            CoreMLBackend.min_deployment_target_from_compile_specs(compile_specs)
        )
        compute_units: ct.ComputeUnit = CoreMLBackend.compute_unit_from_compile_specs(
            compile_specs
        )
        op_linear_quantizer_config = (
            CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs)
        )

        # Load the model if MODEL_TYPE is 'COMPILED_MODEL'. This step is necessary because
        # get_compiled_model_path() requires a loaded model.
        skip_model_load = model_type != CoreMLBackend.MODEL_TYPE.COMPILED_MODEL
        mlmodel = ct.convert(
            model=edge_program,
            source="pytorch",
            convert_to="mlprogram",
            pass_pipeline=ct.PassPipeline.DEFAULT,
            skip_model_load=skip_model_load,
            compute_precision=model_compute_precision,
            minimum_deployment_target=minimum_deployment_target,
            compute_units=compute_units,
        )

        if op_linear_quantizer_config is not None:
            logger.warning(
                "Core ML Backend op_linear_quantizer_config API is experimental"
            )
            config = cto.coreml.OptimizationConfig(
                global_config=op_linear_quantizer_config,
                # skip embedding
                op_type_configs={"gather": None},
            )
            mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config=config)

        return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type)
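
# Example end-to-end usage (illustrative sketch, not part of this module; assumes
# a torch.export-able model and the standard ExecuTorch lowering flow, in which
# to_backend invokes CoreMLBackend.preprocess with the compile specs):
#
#     import torch
#     import coremltools as ct
#     from executorch.exir import to_edge
#     from executorch.exir.backend.backend_api import to_backend
#
#     class Model(torch.nn.Module):
#         def forward(self, x):
#             return torch.sin(x)
#
#     exported = torch.export.export(Model().eval(), (torch.randn(4),))
#     edge = to_edge(exported)
#     compile_specs = CoreMLBackend.generate_compile_specs(
#         minimum_deployment_target=ct.target.iOS17
#     )
#     lowered = to_backend("CoreMLBackend", edge.exported_program(), compile_specs)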