# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
import warnings
from collections import OrderedDict
from typing import Callable, Dict, FrozenSet, List, Tuple

import executorch.backends.qualcomm.python.PyQnnManagerAdaptor as PyQnnManagerAdaptor

import executorch.exir as exir

import torch
from executorch.backends.qualcomm._passes.annotate_and_quant_scalar import (
    AnnotateAndQuantScalar,
)
from executorch.backends.qualcomm._passes.annotate_decomposed import AnnotateDecomposed
from executorch.backends.qualcomm._passes.annotate_quant_attrs import AnnotateQuantAttrs
from executorch.backends.qualcomm._passes.convert_binary_op_with_scalar import (
    ConvertBinaryOpsWithScalar,
)
from executorch.backends.qualcomm._passes.convert_bmm_to_matmul import (
    ConvertBmmToMatmul,
)
from executorch.backends.qualcomm._passes.convert_interpolate_with_upsample2d import (
    ConvertInterpolateWithUpsample2D,
)
from executorch.backends.qualcomm._passes.convert_prelu import ConvertPReLU
from executorch.backends.qualcomm._passes.convert_to_linear import ConvertToLinear
from executorch.backends.qualcomm._passes.expand_broadcast_tensor_shape import (
    ExpandBroadcastTensorShape,
)
from executorch.backends.qualcomm._passes.fold_qdq import FoldQDQ
from executorch.backends.qualcomm._passes.i64_to_i32 import I64toI32
from executorch.backends.qualcomm._passes.layout_transform import LayoutTransform
from executorch.backends.qualcomm._passes.recompose_pixel_unshuffle import (
    RecomposePixelUnshuffle,
)
from executorch.backends.qualcomm._passes.recompose_rms_norm import RecomposeRmsNorm
from executorch.backends.qualcomm._passes.remove_redundancy import RemoveRedundancy
from executorch.backends.qualcomm._passes.replace_index_put_input import (
    ReplaceIndexPutInput,
)

from executorch.backends.qualcomm.builders.node_visitor import (
    QNN_QUANT_TYPE_MAP,
    QNN_TENSOR_TYPE_MAP,
)
from executorch.backends.qualcomm.builders.qnn_constants import OpContextLoader
from executorch.backends.qualcomm.partition.qnn_partitioner import (
    generate_qnn_executorch_option,
    QnnPartitioner,
)
from executorch.backends.qualcomm.serialization.qc_schema import (
    _soc_info_table,
    HtpArch,
    QcomChipset,
    QnnExecuTorchBackendOptions,
    QnnExecuTorchBackendType,
    QnnExecuTorchHtpBackendOptions,
    QnnExecuTorchHtpPerformanceMode,
    QnnExecuTorchHtpPrecision,
    QnnExecuTorchLogLevel,
    QnnExecuTorchOptions,
    QnnExecuTorchProfileLevel,
)
from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
    flatbuffer_to_option,
    option_to_flatbuffer,
)
from executorch.backends.qualcomm.utils.constants import (
    QCOM_PASS_EXPAND_BROADCAST_SHAPE,
    QCOM_PASS_SKIP_ADVANCED_REQUANT,
    QCOM_QNN_COMPILE_SPEC,
    QCOM_QUANTIZED_IO,
)

from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchProgramManager,
    ExirExportedProgram,
    to_edge,
)
from executorch.exir.backend.compile_spec_schema import CompileSpec
from executorch.exir.capture import ExecutorchBackendConfig
from executorch.exir.lowered_backend_module import LoweredBackendModule
from executorch.exir.program._program import _get_updated_graph_signature
from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions
from torch.export.exported_program import ExportedProgram
from torch.fx import passes
from torch.fx.passes.operator_support import OperatorSupportBase
from torch.library import Library


class _AnnotationSkipper(OperatorSupportBase):
    """
    Class used to partition out unwanted graph nodes,
    e.g. - nodes that should be excluded from quantization annotation
         - nodes that have been grouped together as a submodule

    Attributes
    ----------
    fp_node_id_set : set
        a set containing names of nodes to be kept in fp precision
    fp_node_op_set : set
        a set containing targets (aten dialect) of nodes to be kept in fp precision
    skip_annotated_submodule : bool
        flag indicating whether to skip annotated submodules

    Methods
    -------
    should_delegate(n: torch.fx.Node)
        identify residual nodes that have not been lowered with fixed precision
    should_skip(n: torch.fx.Node)
        identify nodes that should be kept out of fixed precision
    is_node_supported(_, node: torch.fx.Node)
        overridden method for graph partitioning
    """

    def __init__(
        self,
        fp_node_id_set: set = None,
        fp_node_op_set: set = None,
        skip_annotated_submodule: bool = False,
    ):
        self.fp_node_id_set = fp_node_id_set
        self.fp_node_op_set = fp_node_op_set
        self.skip_annotated_submodule = skip_annotated_submodule

    def should_delegate(self, n: torch.fx.Node):
        return n.op == "call_function" and n.target != operator.getitem

    def should_skip(self, n: torch.fx.Node):
        return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set

    def is_node_supported(self, _, node: torch.fx.Node) -> bool:
        if self.skip_annotated_submodule:
            if node.op == "get_attr":
                return all(self.should_delegate(user) for user in node.users)
            return self.should_delegate(node)

        if any(
            [
                node.op in ("placeholder", "output"),
                self.should_skip(node),
                # check if parameters belong to an operator that falls back
                (
                    node.op == "get_attr"
                    and all(self.should_skip(user) for user in node.users)
                ),
            ]
        ):
            print(f"[QNN Quantizer Annotation]: {node.name} | Skipped")
            return False

        return True


def qnn_capture_config():
    return exir.CaptureConfig(enable_aot=True)


def qnn_edge_config() -> exir.EdgeCompileConfig:
    return exir.EdgeCompileConfig(
        _check_ir_validity=False,
        _skip_dim_order=True,  # TODO(T182928844): Delegate dim order op to backend.
    )


def convert_linear_to_conv2d(module: torch.nn.Module):
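    """
    Recursively replace every ``torch.nn.Linear`` inside ``module`` with an
    equivalent 1x1 ``torch.nn.Conv2d`` wrapper.

    A minimal usage sketch (``MyModel`` is a placeholder, not part of this module):

        model = convert_linear_to_conv2d(MyModel().eval())
    """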
    class Conv2D(torch.nn.Module):
        def __init__(self, weight, bias=None):
            super().__init__()
            use_bias = bias is not None
            self.conv = torch.nn.Conv2d(
                in_channels=weight.shape[0],
                out_channels=weight.shape[1],
                kernel_size=1,
                padding=0,
                bias=use_bias,
            )
            self.conv.weight = torch.nn.Parameter(weight.reshape(*weight.shape, 1, 1))
            if use_bias:
                self.conv.bias = torch.nn.Parameter(bias)

        def forward(self, x):
            rank = x.dim()
            x = x.unsqueeze(-1) if rank == 3 else x.reshape(1, *x.shape, 1)
            x = torch.transpose(x, 1, 2)
            res = self.conv(x)
            res = torch.transpose(res, 1, 2)
            res = res.squeeze(-1) if rank == 3 else res.reshape(*res.shape[1:3])
            return res

    def replace_linear(module: torch.nn.Module):
        attr_strs = dir(module)
        if isinstance(module, torch.nn.ModuleList):
            attr_strs += [str(i) for i in range(len(module))]

        for attr_str in attr_strs:
            target_attr = getattr(module, attr_str)
            if isinstance(target_attr, torch.nn.Linear):
                setattr(module, attr_str, Conv2D(target_attr.weight, target_attr.bias))

        for _, sub_module in module.named_children():
            sub_module = replace_linear(sub_module)
        return module

    return replace_linear(module)


def update_spill_fill_size(
    exported_program: ExportedProgram | List[LoweredBackendModule],
):
    # check if the user specified multi_contexts;
    # this is a generic approach in case multiple backends exist
    def get_program_info(program):
        def process_exported_program(prog):
            max_sf_buf_size, module_map = 0, {}
            for _, m in prog.graph_module._modules.items():
                # currently only 1 compile spec is expected in each partition
                options = flatbuffer_to_option(m.compile_specs[0].value)
                if (
                    options.backend_options.backend_type
                    == QnnExecuTorchBackendType.kHtpBackend
                    and options.backend_options.htp_options.use_multi_contexts
                ):
                    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                        m.compile_specs[0].value, m.processed_bytes
                    )
                    assert qnn_mgr.Init().value == 0, "failed to load context binary"
                    max_sf_buf_size = max(
                        max_sf_buf_size, qnn_mgr.GetSpillFillBufferSize()
                    )
                    module_map[m] = options
                    qnn_mgr.Destroy()
            return max_sf_buf_size, module_map

        def process_lowered_module(module):
            qnn_mgr = PyQnnManagerAdaptor.QnnManager(
                module.compile_specs[0].value, module.processed_bytes
            )
            assert qnn_mgr.Init().value == 0, "failed to load context binary"
            spill_fill_size = qnn_mgr.GetSpillFillBufferSize()
            qnn_mgr.Destroy()
            return spill_fill_size, {
                module: flatbuffer_to_option(module.compile_specs[0].value)
            }

        dispatch = {
            ExportedProgram: process_exported_program,
            LoweredBackendModule: process_lowered_module,
        }
        return dispatch[type(program)](program)

    def update_program(max_sf_buf_size, module_map):
        def set_spec(module, options):
            spec = CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(options))
            if isinstance(module, ExportedProgram):
                module.compile_specs[0] = spec
            else:
                module._compile_specs[0] = spec

        for module, options in module_map.items():
            options.backend_options.htp_options.max_sf_buf_size = max_sf_buf_size
            set_spec(module, options)

    if isinstance(exported_program, list):
        max_sf_size, modules_map = 0, {}
        for prog in exported_program:
            max_sf_buf_size, module_map = get_program_info(prog)
            max_sf_size = max(max_sf_size, max_sf_buf_size)
            modules_map.update(module_map)
        update_program(max_sf_size, modules_map)
    else:
        update_program(*get_program_info(exported_program))


def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]:
    source_decompositions = torch_core_aten_decompositions()
    # The ops below are supported directly by QNN, so keep them from being decomposed
    remove_decompositions = [
        torch.ops.aten.pixel_shuffle.default,
        torch.ops.aten.pixel_unshuffle.default,
        torch.ops.aten.hardsigmoid.default,
        torch.ops.aten.hardswish.default,
        torch.ops.aten._safe_softmax.default,
    ]

    for key in remove_decompositions:
        source_decompositions.pop(key)

    return source_decompositions


def _transform(
    edge_program: ExportedProgram, custom_pass_config: FrozenSet[str] = frozenset()
) -> ExportedProgram:
    # currently ExirExportedProgram.transform does not accept a change in the
    # number of inputs, which FoldQDQ introduces,
    # so apply the passes one by one here to avoid an IR capture failure
    graph_module = edge_program.graph_module
    RemoveRedundancy()(graph_module)
    RecomposePixelUnshuffle()(graph_module)
    RecomposeRmsNorm()(graph_module)
    ConvertToLinear()(graph_module)
    ConvertPReLU(edge_program)(graph_module)
    ConvertBmmToMatmul()(graph_module)
    ConvertInterpolateWithUpsample2D()(graph_module)
    I64toI32(edge_program)(graph_module)
    AnnotateQuantAttrs(
        edge_program, QCOM_PASS_SKIP_ADVANCED_REQUANT in custom_pass_config
    )(graph_module)
    AnnotateAndQuantScalar(edge_program)(graph_module)
    AnnotateDecomposed(edge_program)(graph_module)
    FoldQDQ()(graph_module)
    # this pass is not necessary for networks without layout-sensitive ops;
    # enabling it by default would introduce overhead from extra view_copy nodes
    if QCOM_PASS_EXPAND_BROADCAST_SHAPE in custom_pass_config:
        ExpandBroadcastTensorShape()(graph_module)
    LayoutTransform(edge_program)(graph_module)
    ReplaceIndexPutInput(edge_program)(graph_module)

    # Since QDQ nodes are stripped, update the graph signature again to validate the program
    edge_program._graph_signature = _get_updated_graph_signature(
        edge_program.graph_signature,
        edge_program.graph_module,
    )
    edge_program._validate()
    return edge_program


def capture_program(
    module: torch.nn.Module,
    inputs: Tuple[torch.Tensor],
    custom_pass_config: FrozenSet[str] = frozenset(),
) -> exir.ExirExportedProgram:
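    """
    Export ``module``, apply QNN-specific decompositions and passes, and return
    an edge-dialect program that is ready to be partitioned and lowered.

    A minimal usage sketch (``MyModel`` and ``compiler_specs`` are placeholders;
    ``to_backend`` comes from ``executorch.exir.backend.backend_api``):

        edge_ep = capture_program(MyModel().eval(), (torch.randn(1, 16),))
        delegated = to_backend(edge_ep.exported_program, QnnPartitioner(compiler_specs))
    """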
    ep = torch.export.export(module, inputs)
    decomposed_ep = ep.run_decompositions(get_decomp_table())
    # We choose call_operator by target in ConvertBinaryOpsWithScalar
    # because it is the same source_fn_stack for MultiheadAttention
    # TODO: Should modify the scalar op in the op builder instead of
    #       using transformation
    core_ep = ExirExportedProgram(decomposed_ep, False)
    core_ep.transform(ConvertBinaryOpsWithScalar())
    edge_ep = core_ep.to_edge(qnn_edge_config())
    _transform(edge_ep.exported_program, custom_pass_config)
    return edge_ep


def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn):
    from torch.fx.passes.utils.fuser_utils import (
        erase_nodes,
        fuse_as_graphmodule,
        insert_subgm,
        legalize_graph,
        topo_sort,
    )

    partitions = ptn.propose_partitions()
    # insert meta for each partition group
    for i, partition in enumerate(partitions):
        for node in partition.nodes:
            node.meta[subgm_tag] = i

    for i in range(len(partitions)):
        # find nodes with same group id in current graph
        node_list = [
            node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i
        ]
        # fuse group nodes into submodule
        sorted_nodes = topo_sort(node_list)
        submodule_name = f"{subgm_tag}_{i}"
        subgm, orig_inputs, orig_outputs = fuse_as_graphmodule(
            gm, sorted_nodes, submodule_name
        )
        # insert submodule & trim group nodes
        gm = insert_subgm(
            gm,
            subgm_cb(subgm, submodule_name),
            orig_inputs,
            orig_outputs,
        )
        erase_nodes(gm, sorted_nodes)
        legalize_graph(gm)

    gm.recompile()
    return gm


def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn):
    from executorch.exir.backend.backend_api import to_backend

    # return the lowered programs so the user can debug them
    exported_progs = []
    # partition each submodule which went through convert_pt2e
    for node in gm.graph.nodes:
        if node.op == "call_module" and subgm_tag in node.name:
            # obtain sample inputs through meta
            subgm_input = [
                torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype)
                for arg in node.args
            ]
            # re-capture so the program meets the QNN backend requirements
            sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input))
            # start lowering with the given partitioner
            exported_progs.append(to_backend(sub_prog.exported_program, ptn))
            # replace the submodule with the lowered module
            gm.set_submodule(
                node.name,
                exported_progs[-1].graph_module,
            )
            # if a node has multiple outputs, getitem nodes will be generated by default
            if all(n.target != operator.getitem for n in node.users):
                with gm.graph.inserting_after(node):
                    getitem_node = gm.graph.call_function(
                        operator.getitem,
                        (node, 0),
                    )
                    getitem_node.meta = node.meta
                    node.replace_all_uses_with(
                        replace_with=getitem_node,
                        delete_user_cb=lambda user: user.target != operator.getitem,
                    )

    gm.recompile()
    return gm, exported_progs


def skip_annotation(
    nn_module: torch.nn.Module,
    quantizer,
    partitioner,
    sample_input: Tuple[torch.Tensor, ...],
    calibration_cb: Callable[[torch.fx.GraphModule], None],
    fp_node_id_set: set = None,
    fp_node_op_set: set = None,
    fallback_to_cpu: bool = True,
):
    r"""
    Exclude specific operators from quantizer annotation.
    Skipped operators stay on CPU by default; set 'fallback_to_cpu'
    to False to try delegating them with FP16 precision.

    e.g.: consider the following graph:
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \_______     _______/
                         add_1
             (torch.ops.aten.add.default)
                           |
                         output

    If the user wants to skip the convolution op by name with
    'fp_node_id_set' = {"conv2d_1"}
    "bias_1 / weight_1 / input_1 / input_2 / conv2d_1"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   input_2
      | (placeholder) |          |
       \      |      /           |
        \     |     /            |
         \    |    /             |
           conv2d_1              |
              \                 /
               \               /
                \             /
               lowered_module_1
            (QNN fixed precision)
                      |
                    output

    If the user wants to skip the convolution op by target with
    'fp_node_op_set' = {torch.ops.aten.conv2d.default}
    "bias_1 / weight_1 / input_1 / conv2d_1,
     bias_2 / weight_2 / input_2 / conv2d_2"
    will be partitioned out and not annotated / lowered with QNN.

    [Generated graph]
    bias_1 weight_1 input_1   bias_2 weight_2 input_2
      | (placeholder) |         | (placeholder) |
       \      |      /           \      |      /
        \     |     /             \     |     /
         \    |    /               \    |    /
           conv2d_1                 conv2d_2
           (torch.ops.aten.conv2d.default)
               \                       /
                \                     /
                 \__               __/
                    lowered_module_1
                 (QNN fixed precision)
                           |
                         output

    If the user wants to delegate the skipped conv2d from the graph above
    with 'fallback_to_cpu' = False:

    [Generated graph]
       input_1         input_2
    (placeholder)   (placeholder)
          |               |
          \               /
          lowered_module_2
         (QNN fp16 precision)
                  |
                  |
          lowered_module_1
         (QNN fixed precision)
                  |
                output

    Args:
        nn_module (torch.nn.Module): The module to be lowered.
        quantizer (QnnQuantizer): Instance of QnnQuantizer.
        partitioner (QnnPartitioner): Instance of QnnPartitioner.
        sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting.
        calibration_cb (callable): Callback function for user-defined calibration.
        fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision.
        fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision.
        fallback_to_cpu (bool): Whether to keep skipped operators on CPU (True)
            or try lowering them to QNN fp16 (False).

    Returns:
        exported_programs: List of programs lowered to QnnBackend (quantized graphs,
            plus fp16 graphs when 'fallback_to_cpu' is False).
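
    Example (an illustrative sketch; ``MyModel``, ``quantizer`` and the inputs
    below are placeholders, not part of this module):

        backend_options = generate_htp_compiler_spec(use_fp16=False)
        compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=QcomChipset.SM8650,
            backend_options=backend_options,
        )
        graph_module, exported_progs = skip_annotation(
            nn_module=MyModel(),
            quantizer=quantizer,
            partitioner=QnnPartitioner(compiler_specs),
            sample_input=(torch.randn(1, 3, 224, 224),),
            calibration_cb=lambda gm: gm(torch.randn(1, 3, 224, 224)),
            fp_node_id_set={"conv2d_1"},
        )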
    """
    from executorch.backends.qualcomm.serialization.qc_schema import (
        QnnExecuTorchHtpPrecision,
    )
    from executorch.backends.qualcomm.serialization.qc_schema_serialize import (
        flatbuffer_to_option,
    )
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner

    def prepare_subgm(subgm, subgm_name):
        # prepare the current submodule for quantization annotation
        subgm_prepared = prepare_pt2e(subgm, quantizer)
        # overwrite this attribute, otherwise the name will be set to "GraphModule"
        # and we could not tell the submodules apart
        subgm_prepared.__class__.__name__ = subgm_name
        return subgm_prepared

    fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set()
    fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set()
    graph_module = torch.export.export(nn_module, sample_input).module()
    # define node support type
    capability_partitioner = CapabilityBasedPartitioner(
        graph_module,
        _AnnotationSkipper(fp_node_id_set, fp_node_op_set),
        allows_single_node_partition=True,
    )
    subgm_tag = "annotated_group"
    graph_module = _partition_graph_into_submodules(
        gm=graph_module,
        subgm_tag=subgm_tag,
        subgm_cb=prepare_subgm,
        ptn=capability_partitioner,
    )
    # perform calibration
    calibration_cb(graph_module)
    # convert submodules which went through prepare_pt2e
    for node in graph_module.graph.nodes:
        if node.op == "call_module":
            graph_module.set_submodule(
                node.name, convert_pt2e(graph_module.get_submodule(node.name))
            )
    # canonicalize the graph for lowering again
    graph_module, exported_progs = _canonicalize_graph_with_lowered_module(
        gm=graph_module,
        subgm_tag=subgm_tag,
        ptn=partitioner,
    )

    if not fallback_to_cpu:
        try:
            from executorch.exir.backend.partitioner import DelegationSpec

            # change the HTP compiler spec to enable fp16 on hardware
            qnn_option = generate_qnn_executorch_option(
                partitioner.compiler_specs_snapshot
            )
            compile_option = flatbuffer_to_option(qnn_option)
            htp_options = compile_option.backend_options.htp_options
            htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16
            partitioner.delegation_spec = DelegationSpec(
                "QnnBackend",
                [
                    CompileSpec(
                        QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(compile_option)
                    )
                ],
            )
        except Exception:
            print(
                "Failed to change HTP compiler spec with 'use_fp16' as True,"
                " skipped operators will fall back to cpu."
            )
            return graph_module, exported_progs

        # try lowering the skipped operators in fp16
        capability_partitioner = CapabilityBasedPartitioner(
            graph_module,
            _AnnotationSkipper(skip_annotated_submodule=True),
            allows_single_node_partition=True,
        )
        subgm_tag = "skipped_group"
        graph_module = _partition_graph_into_submodules(
            gm=graph_module,
            subgm_tag=subgm_tag,
            subgm_cb=lambda subgm, _: subgm,
            ptn=capability_partitioner,
        )
        graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module(
            gm=graph_module,
            subgm_tag=subgm_tag,
            ptn=partitioner,
        )
        exported_progs.extend(exported_progs_fp)

    return graph_module, exported_progs


def from_context_binary(  # noqa: C901
    ctx_path: str | bytes,
    op_name: str,
    soc_model: QcomChipset = QcomChipset.SM8650,
    custom_info: Dict = None,
):
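    """
    Wrap a pre-built QNN context binary in an ExecuTorch program whose custom
    op dispatches into that binary at runtime.

    A minimal usage sketch ("model_ctx.bin" and the op name are placeholders):

        bundle = from_context_binary("model_ctx.bin", "ctx_loader")
        exec_prog = bundle["edge_program_manager"].to_executorch()
    """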
    from pathlib import Path

    def implement_op(custom_op, op_name, outputs):
        @torch.library.impl(
            custom_op, str(op_name), dispatch_key="CompositeExplicitAutograd"
        )
        def op_impl(inputs: List[torch.Tensor]):
            return tuple(
                torch.zeros(tuple(v.shape), device="meta", dtype=v.dtype)
                for v in outputs.values()
            )

    def build_graph(inputs, outputs):
        # custom op declaration
        inputs_str = "Tensor[] inputs"
        func_proto = f"{op_name}({inputs_str}) -> Any"
        custom_op = Library(OpContextLoader.namespace, "FRAGMENT")
        custom_op.define(func_proto)
        # custom op implementation
        implement_op(custom_op, op_name, outputs)

        # model architecture mimicking the context binary
        class Model(torch.nn.Module):
            def forward(self, *inputs):
                return getattr(
                    getattr(torch.ops, OpContextLoader.namespace), op_name
                ).default(inputs)

        model = Model()
        prog = torch.export.export(model, tuple(inputs.values()))
        # bookkeeping for the variables' life cycle
        return {
            "custom_op": custom_op,
            "custom_module": model,
            "exported_program": prog,
        }

    def build_tensor(tensors, dtype_map):
        ret = OrderedDict()
        for t in tensors:
            dtype = t.GetDataType()
            dtype_torch = dtype_map.get(dtype, None)
            assert dtype_torch is not None, f"unknown qnn data type {dtype}"
            ret[t.GetName()] = torch.zeros(tuple(t.GetDims()), dtype=dtype_torch)

        return ret

    def preprocess_binary(ctx_bin, compiler_specs):
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
        )
        return bytes(qnn_mgr.MakeBinaryInfo(ctx_bin))

    # a dummy compiler spec is fine, since we're not compiling
    backend_options = generate_htp_compiler_spec(use_fp16=False)
    compiler_specs = generate_qnn_executorch_compiler_spec(
        soc_model=soc_model,
        backend_options=backend_options,
        is_from_context_binary=True,
    )

    ctx_bin = (
        ctx_path
        if not isinstance(ctx_path, str)
        else preprocess_binary(Path(f"{ctx_path}").read_bytes(), compiler_specs)
    )

    dtype_map = {}
    for type_map in (QNN_QUANT_TYPE_MAP, QNN_TENSOR_TYPE_MAP):
        for k, v in type_map.items():
            dtype_map.setdefault(v, k)

    if custom_info is not None:
        # since some context binaries might fail to open on the host
        # if they were compiled with special flags (e.g. weight sharing),
        # use the custom information here instead
        inputs = build_tensor(custom_info["graph_inputs"], dtype_map)
        outputs = build_tensor(custom_info["graph_outputs"], dtype_map)
        graph_name = custom_info["graph_name"]
    else:
        # get the context binary's IO tensor info through the QNN manager
        qnn_mgr = PyQnnManagerAdaptor.QnnManager(
            generate_qnn_executorch_option(compiler_specs),
            ctx_bin,
        )
        assert qnn_mgr.Init().value == 0, "failed to load context binary"
        # assume there is only one graph in the current context
        graph_name = qnn_mgr.GetGraphNames()[0]
        qnn_mgr.AllocateTensor(graph_name)
        inputs = build_tensor(qnn_mgr.GetGraphInputs(graph_name), dtype_map)
        outputs = build_tensor(qnn_mgr.GetGraphOutputs(graph_name), dtype_map)
        qnn_mgr.Destroy()

    # generate a graph specialized for loading the context
    bundle_prog = build_graph(inputs, outputs)
    bundle_prog.update({"inputs": inputs, "outputs": outputs})
    edge_prog_mgr = to_edge(
        programs={graph_name: bundle_prog["exported_program"]},
        # do not alter the name of the custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # update meta with the context binary
    for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
        if n.op == "call_function" and OpContextLoader.namespace in str(n.target):
            n.meta[OpContextLoader.meta_ctx_bin] = ctx_bin
            break

    bundle_prog["edge_program_manager"] = edge_prog_mgr.to_backend(
        QnnPartitioner(compiler_specs)
    )
    return bundle_prog


def draw_graph(title, path, graph_module: torch.fx.GraphModule):
    graph = passes.graph_drawer.FxGraphDrawer(graph_module, title)
    with open(f"{path}/{title}.svg", "wb") as f:
        f.write(graph.get_dot_graph().create_svg())


def generate_multi_graph_program(
    compiler_specs: List[CompileSpec],
    processed_bytes: List[bytes],
    backend_config: ExecutorchBackendConfig = None,
) -> ExecutorchProgramManager:
    # compile multiple graphs in qcir into a single context binary
    graph_inputs, graph_outputs = {}, {}
    qnn_mgr = PyQnnManagerAdaptor.QnnManager(
        generate_qnn_executorch_option(compiler_specs), processed_bytes
    )
    assert qnn_mgr.Init().value == 0, "failed to load processed bytes"
    binary_info = bytes(qnn_mgr.Compile())
    assert len(binary_info) != 0, "failed to generate QNN context binary"
    graph_names = qnn_mgr.GetGraphNames()
    for graph_name in graph_names:
        graph_inputs[graph_name] = qnn_mgr.GetGraphInputs(graph_name)
        graph_outputs[graph_name] = qnn_mgr.GetGraphOutputs(graph_name)
    qnn_mgr.Destroy()

    # build custom ops with different graph signatures
    compiler_options = flatbuffer_to_option(compiler_specs[0].value)
    bundle_progs = [
        from_context_binary(
            ctx_path=binary_info,
            op_name=f"loader_{graph_name}",
            soc_model=compiler_options.soc_info.soc_model,
            custom_info={
                "graph_inputs": graph_inputs[graph_name],
                "graph_outputs": graph_outputs[graph_name],
                "graph_name": graph_name,
            },
        )
        for graph_name in graph_names
    ]
    # leverage ExecutorchProgramManager to generate a pte with multiple methods
    edge_prog_mgr = to_edge(
        programs={
            graph_name: bundle_prog["exported_program"]
            for graph_name, bundle_prog in zip(graph_names, bundle_progs)
        },
        # do not alter the name of the custom op
        compile_config=EdgeCompileConfig(_use_edge_ops=False),
    )
    # restore metadata lost while generating the EdgeProgramManager
    for graph_name in graph_names:
        for n in edge_prog_mgr._edge_programs[graph_name].graph.nodes:
            if graph_name in n.name:
                n.meta[OpContextLoader.meta_ctx_bin] = binary_info
                break

    return edge_prog_mgr.to_backend(QnnPartitioner(compiler_specs)).to_executorch(
        config=backend_config or ExecutorchBackendConfig()
    )


def generate_htp_compiler_spec(
    use_fp16: bool,
    use_dlbc: bool = False,
    use_multi_contexts: bool = False,
) -> QnnExecuTorchBackendOptions:
    """
    Helper function generating backend options for QNN HTP.

    Args:
        use_fp16: If true, the model is compiled for the QNN HTP fp16 runtime.
            Note that not all SoCs support QNN HTP fp16. Only premium-tier SoCs
            such as Snapdragon 8 Gen 1 or newer support HTP fp16.
        use_dlbc: Deep Learning Bandwidth Compression allows inputs to be
            compressed, so that the processing bandwidth can be lowered.
        use_multi_contexts: When multiple contexts are generated inside the same
            pte, it is possible to reserve a single spill-fill allocation that
            can be reused across all the splits.

    Returns:
        QnnExecuTorchBackendOptions: backend options for QNN HTP.
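
    Example (an illustrative sketch; typically paired with
    generate_qnn_executorch_compiler_spec below):

        backend_options = generate_htp_compiler_spec(use_fp16=True)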
    """
    htp_options = QnnExecuTorchHtpBackendOptions()
    htp_options.precision = (
        QnnExecuTorchHtpPrecision.kHtpFp16
        if use_fp16
        else QnnExecuTorchHtpPrecision.kHtpQuantized
    )
    # This is not actually an option that affects the compiled blob,
    # but we have no other place to pass it to the execution stage.
    # TODO: enable a voting mechanism in runtime and make this an option
    htp_options.performance_mode = QnnExecuTorchHtpPerformanceMode.kHtpBurst
    htp_options.use_multi_contexts = use_multi_contexts
    htp_options.use_dlbc = use_dlbc
    return QnnExecuTorchBackendOptions(
        backend_type=QnnExecuTorchBackendType.kHtpBackend,
        htp_options=htp_options,
    )


def generate_qnn_executorch_compiler_spec(
    soc_model: QcomChipset,
    backend_options: QnnExecuTorchBackendOptions,
    debug: bool = False,
    saver: bool = False,
    online_prepare: bool = False,
    dump_intermediate_outputs: bool = False,
    profile: bool = False,
    optrace: bool = False,
    shared_buffer: bool = False,
    is_from_context_binary: bool = False,
    multiple_graphs: bool = False,
    graph_name: str = "forward",
) -> List[CompileSpec]:
    """
    Helper function generating compiler specs for Qualcomm AI Engine Direct.

    Args:
        soc_model: The SoC you plan to run the compiled model on. Please check
            QcomChipset for supported SoCs, e.g.
            SM8450 (Snapdragon 8 Gen 1)
            SM8475 (Snapdragon 8 Gen 1+)
            SM8550 (Snapdragon 8 Gen 2)
            SM8650 (Snapdragon 8 Gen 3)
        backend_options: Options required by different backends.
        debug: Enable verbose logging. Disclaimer: this option is expected to
            change in the near future.
        online_prepare: Compose the QNN graph on device if set to True.
        saver: Instead of compiling the model, run QNN Saver. Please check the
            documentation of the Qualcomm AI Engine Direct SDK. This feature is
            usually used for debugging.
        dump_intermediate_outputs: If tensor dump is enabled, all intermediate
            tensor outputs will be dumped. This option exists for debugging
            accuracy issues.
        profile: Enable per-operator performance profiling. Note that for now
            only kProfileDetailed is supported, which profiles each operator in
            cycle units.
        optrace: Enable QNN optrace profiling; takes precedence over 'profile'.
        shared_buffer: Enable usage of a shared buffer between the application
            and the backend for graph I/O.
        is_from_context_binary: True if the current graph comes from a pre-built
            context binary.
        multiple_graphs: True if multiple methods are expected to live in a
            single .pte file. Please see the test cases for a post-processing
            example.
        graph_name: Assign a unique graph name if 'multiple_graphs' is used.

    Returns:
        List[CompileSpec]: Compiler specs for Qualcomm AI Engine Direct.

    Raises:
        ValueError: The given QcomChipset is currently not supported.
        ValueError: Conflict between compiler specs.
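
    Example (an illustrative sketch; the chipset choice is a placeholder):

        backend_options = generate_htp_compiler_spec(use_fp16=True)
        compiler_specs = generate_qnn_executorch_compiler_spec(
            soc_model=QcomChipset.SM8650,
            backend_options=backend_options,
        )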
    """
    _supported_soc_models = {soc_model.value for soc_model in QcomChipset}
    if soc_model not in _supported_soc_models:
        raise ValueError(f"unknown SoC model for QNN: {soc_model}")

    if profile and dump_intermediate_outputs:
        warnings.warn(
            "It is not recommended to turn on both profiling and dump_intermediate_outputs at the same time"
            ", because dump_intermediate_outputs will cause a performance drop.",
            stacklevel=1,
        )

    qnn_executorch_options = QnnExecuTorchOptions(
        _soc_info_table[soc_model], backend_options
    )
    qnn_executorch_options.graph_name = graph_name
    qnn_executorch_options.log_level = (
        QnnExecuTorchLogLevel.kLogLevelDebug
        if debug
        else QnnExecuTorchLogLevel.kLogLevelWarn
    )

    qnn_executorch_options.dump_intermediate_outputs = dump_intermediate_outputs

    if saver:
        qnn_executorch_options.library_path = "libQnnSaver.so"

    if optrace:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOptrace
    elif profile:
        qnn_executorch_options.profile_level = (
            QnnExecuTorchProfileLevel.kProfileDetailed
        )
    else:
        qnn_executorch_options.profile_level = QnnExecuTorchProfileLevel.kProfileOff

    if (
        online_prepare
        and backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend
        and backend_options.htp_options.use_multi_contexts
    ):
        raise ValueError(
            "'use_multi_contexts' does not work in online prepare mode, "
            "please set 'online_prepare' to False"
        )

    qnn_executorch_options.shared_buffer = shared_buffer
    qnn_executorch_options.online_prepare = online_prepare
    qnn_executorch_options.is_from_context_binary = is_from_context_binary
    qnn_executorch_options.multiple_graphs = multiple_graphs

    if multiple_graphs:
        # enable the weight sharing mechanism if multiple graphs appear
        if backend_options.backend_type == QnnExecuTorchBackendType.kHtpBackend:
            backend_options.htp_options.use_weight_sharing = True

    return [
        CompileSpec(QCOM_QNN_COMPILE_SPEC, option_to_flatbuffer(qnn_executorch_options))
    ]


def get_soc_to_arch_map():
    return {
        "SSG2115P": HtpArch.V73,
        "SM8650": HtpArch.V75,
        "SM8550": HtpArch.V73,
        "SM8475": HtpArch.V69,
        "SM8450": HtpArch.V69,
        "SA8295": HtpArch.V68,
    }


def get_soc_to_chipset_map():
    return {
        "SSG2115P": QcomChipset.SSG2115P,
        "SM8650": QcomChipset.SM8650,
        "SM8550": QcomChipset.SM8550,
        "SM8475": QcomChipset.SM8475,
        "SM8450": QcomChipset.SM8450,
        "SA8295": QcomChipset.SA8295,
    }


def tag_quant_io(gm: torch.fx.GraphModule, get_quant_io_dtype_fn: Callable):
    """
    Tag IO nodes which consume or produce quantized tensors, so there is no need
    to insert Q/DQ ops in qnn_preprocess.
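
    Example (an illustrative sketch; the dtype rule below is a placeholder):

        def quant_io_dtype(node):
            if node.op == "placeholder" and "input" in node.name:
                return torch.uint8
            return None

        tag_quant_io(graph_module, quant_io_dtype)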
    """
    for node in gm.graph.nodes:
        if dtype := get_quant_io_dtype_fn(node):
            node.meta[QCOM_QUANTIZED_IO] = dtype