# Copyright (c) Qualcomm Innovation Center, Inc.
# All rights reserved
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import torch
from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class ReplaceIndexPutInput(ExportPass):
    """
    Index put input workaround for quantized module.

    For every ``aten.index_put`` node whose first user is an ``aten.copy``
    node, the first input of the ``index_put`` is replaced by the first
    argument of that ``copy`` node. When the replaced input carries
    ``QCOM_QUANT_ATTRS``, those attributes are copied onto the replacement
    node with the encoding switched from a dequantize op to the matching
    quantize op (see ``dq_q_map``).
    """

    # Maps each supported dequantize op to its quantize counterpart.
    dq_q_map = {
        # per tensor
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor,
        # per channel
        exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
    }

    def __init__(self, edge_program: torch.export.ExportedProgram):
        """Store the exported program this pass is associated with."""
        super().__init__()
        self.edge_program = edge_program

    def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
        """
        Rewire the first input of qualifying ``index_put`` nodes.

        Args:
            graph_module: Graph module whose graph is rewritten in place.

        Returns:
            A ``PassResult`` wrapping ``graph_module``; the modified flag is
            always ``True``.

        Raises:
            KeyError: If the replaced input has quant attributes whose
                encoding is not a key of ``dq_q_map``.
        """
        graph = graph_module.graph
        for node in graph.nodes:
            if node.target != exir_ops.edge.aten.index_put.default:
                continue

            # Only the first user is inspected; an index_put without any
            # users is skipped instead of raising IndexError.
            copy_node = next(iter(node.users), None)
            if (
                copy_node is None
                or copy_node.target != exir_ops.edge.aten.copy.default
            ):
                continue

            # presumably the mutable buffer the copy writes back to --
            # verify against the exported graph.
            m_buffer_node = copy_node.args[0]
            bad_frozen_node = node.args[0]

            # Already rewired (e.g. the pass ran before): nothing to replace,
            # and the encoding would no longer be a dequantize op.
            if bad_frozen_node is m_buffer_node:
                continue

            if QCOM_QUANT_ATTRS in bad_frozen_node.meta:
                # Shallow-copy so rewriting the encoding does not also change
                # the attributes still attached to the replaced node.
                quant_attrs = dict(bad_frozen_node.meta[QCOM_QUANT_ATTRS])
                quant_attrs[QCOM_ENCODING] = self.dq_q_map[
                    quant_attrs[QCOM_ENCODING]
                ]
                m_buffer_node.meta[QCOM_QUANT_ATTRS] = quant_attrs

            # replace_input_with only rewrites this node's arguments; no new
            # node is created, so no insertion point is required.
            node.replace_input_with(bad_frozen_node, m_buffer_node)

        graph.eliminate_dead_code()
        graph_module.recompile()
        return PassResult(graph_module, True)