xref: /aosp_15_r20/external/executorch/backends/qualcomm/_passes/replace_index_put_input.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Qualcomm Innovation Center, Inc.
2# All rights reserved
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6import torch
7from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS
8from executorch.exir.dialects._ops import ops as exir_ops
9from executorch.exir.pass_base import ExportPass, PassResult
10
11
class ReplaceIndexPutInput(ExportPass):
    """
    Index put input workaround for quantized module.

    For every ``aten.index_put`` node whose first user is an ``aten.copy``
    node, the pass rewires the ``index_put`` to read from the copy's
    destination (``copy_node.args[0]``) instead of its original first
    argument. If the original first argument carries ``QCOM_QUANT_ATTRS``
    metadata, a copy of those attributes is attached to the destination node
    with its ``QCOM_ENCODING`` translated from the dequantize op to the
    matching quantize op (see ``dq_q_map``). The graph is then passed through
    ``eliminate_dead_code`` and recompiled.
    """

    # Dequantize op -> matching quantize op; used to translate the
    # QCOM_ENCODING carried over onto the copy destination node.
    dq_q_map = {
        # per tensor
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        # per channel
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    }

    def __init__(self, edge_program: torch.export.ExportedProgram):
        super(ReplaceIndexPutInput, self).__init__()
        # Kept for interface compatibility; not read by call() in this class.
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewire index_put inputs as described in the class docstring.

        Args:
            graph_module: module whose graph is modified in place.

        Returns:
            PassResult wrapping the same ``graph_module``; the modified flag
            is always reported as True.

        Raises:
            KeyError: if the carried-over QCOM_ENCODING is not a key of
                ``dq_q_map``.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target != exir_ops.edge.aten.index_put.default:
                continue

            # Only the first user is inspected. An index_put with no users
            # is skipped instead of raising IndexError.
            copy_node = next(iter(node.users), None)
            if (
                copy_node is None
                or copy_node.target != exir_ops.edge.aten.copy.default
            ):
                continue

            m_buffer_node = copy_node.args[0]
            bad_frozen_node = node.args[0]
            if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
                # Copy rather than alias: rewriting the encoding must not leak
                # back into bad_frozen_node's metadata. Otherwise a surviving
                # source node would be mislabeled, and a repeated lookup of its
                # (already quantize) encoding in dq_q_map would raise KeyError.
                quant_attrs = dict(bad_frozen_node.meta[QCOM_QUANT_ATTRS])
                quant_attrs[QCOM_ENCODING] = self.dq_q_map[
                    quant_attrs[QCOM_ENCODING]
                ]
                m_buffer_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

            # replace_input_with only rewires an existing edge and creates no
            # nodes, so no insertion-point context is required.
            node.replace_input_with(bad_frozen_node, m_buffer_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)
55