1# Copyright (c) Meta Platforms, Inc. and affiliates. 2# All rights reserved. 3# 4# This source code is licensed under the BSD-style license found in the 5# LICENSE file in the root directory of this source tree. 6 7import copy 8import logging 9from typing import Optional, Set 10 11import torch 12from torch._export.utils import get_buffer, get_lifted_tensor_constant, get_param 13 14from torch.export import ExportedProgram 15from torch.export.exported_program import InputSpec, TensorArgument 16from torch.export.graph_signature import InputKind 17 18 19def _get_attribute_or_constants( 20 exported_program: ExportedProgram, node: torch.fx.Node 21) -> Optional[torch.Tensor]: 22 # get either attribute node or constant constant 23 maybe_param = get_param(exported_program, node) 24 maybe_buffer = get_buffer(exported_program, node) 25 maybe_lifted_tensor = get_lifted_tensor_constant(exported_program, node) 26 27 constant_or_attribute = None 28 if maybe_param is not None: 29 constant_or_attribute = maybe_param 30 elif maybe_buffer is not None: 31 constant_or_attribute = maybe_buffer 32 elif maybe_lifted_tensor is not None: 33 constant_or_attribute = maybe_lifted_tensor 34 return constant_or_attribute 35 36 37# TODO: add other passes to duplicate call_function nodes 38def duplicate_constant_node( 39 exported_program: ExportedProgram, candidate_node: str 40) -> Set[str]: 41 """ 42 A pass to duplicate the attributes/constants node (the candidate_node) in the graph. Mostly used for duplicating light-weight data. 43 If the data is too large, try tag it with "no_copy" to prevent high memory usage and make it as part of the output. 44 45 Args: 46 exported_program: the exported program to be modified. If constants nodes are copied, they will be added as new 47 placeholder and the state_dict will be updated 48 candidate_node: the name of the constant node to be duplicated 49 50 Returns: 51 The set of the names of the new constant nodes 52 """ 53 to_be_copied = [ 54 node 55 for node in exported_program.graph.nodes 56 if node.name == candidate_node and node.op == "placeholder" 57 ] 58 if len(to_be_copied) == 0: 59 logging.info("no constant node to be copied") 60 return set() 61 new_input_specs = [] 62 old_signature = exported_program.graph_signature 63 copied_nodes = set() 64 for idx, node in enumerate(exported_program.graph.nodes): 65 if node.op != "placeholder": 66 continue 67 old_input_spec = old_signature.input_specs[idx] 68 old_input_spec_copy = copy.deepcopy(old_input_spec) 69 if node == to_be_copied[0]: 70 constant_or_attribute_node = node 71 constant_or_attribute = _get_attribute_or_constants(exported_program, node) 72 if constant_or_attribute is None: 73 raise RuntimeError( 74 f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} " 75 ) 76 users = list(node.users.keys()) 77 for ith in range(1, len(node.users)): 78 copy_constant_or_attribute_fqn = node.name + f"_copy_{ith - 1}" 79 with exported_program.graph.inserting_before( 80 constant_or_attribute_node 81 ): 82 copied_constant_or_attribute_node = ( 83 exported_program.graph.placeholder( 84 copy_constant_or_attribute_fqn 85 ) 86 ) 87 copied_nodes.add(copy_constant_or_attribute_fqn) 88 logging.info( 89 f"Copying constant nodes {node.name} and creating {copy_constant_or_attribute_fqn}" 90 ) 91 for k, v in node.meta.items(): 92 copied_constant_or_attribute_node.meta[k] = v 93 copied_constant_or_attribute_node.meta["val"] = ( 94 constant_or_attribute_node.meta["val"] 95 ) 96 new_args = tuple( 97 [ 98 ( 99 arg 100 if arg != constant_or_attribute_node 101 else copied_constant_or_attribute_node 102 ) 103 for arg in users[ith].args 104 ] 105 ) 106 new_kwargs = dict( 107 { 108 ( 109 key, 110 ( 111 value 112 if value != constant_or_attribute_node 113 else copied_constant_or_attribute_node 114 ), 115 ) 116 for key, value in users[ith].kwargs 117 } 118 ) 119 users[ith].args = new_args 120 users[ith].kwargs = new_kwargs 121 if old_input_spec.kind == InputKind.CONSTANT_TENSOR: 122 exported_program.constants[copy_constant_or_attribute_fqn] = ( 123 copy.deepcopy(constant_or_attribute) 124 ) 125 elif ( 126 old_input_spec.kind == InputKind.BUFFER 127 and old_input_spec.persistent is False 128 ): 129 # non persistent buffer will be in the .constants 130 exported_program.constants[copy_constant_or_attribute_fqn] = ( 131 copy.deepcopy(constant_or_attribute) 132 ) 133 else: 134 exported_program.state_dict[copy_constant_or_attribute_fqn] = ( 135 copy.deepcopy(constant_or_attribute) 136 ) 137 new_input_specs.append( 138 InputSpec( 139 kind=old_input_spec.kind, 140 arg=TensorArgument(name=copy_constant_or_attribute_fqn), 141 target=old_input_spec.target, 142 persistent=old_input_spec.persistent, 143 ) 144 ) 145 # Ensure we add the original input spec to the last one, because all the copied nodes 146 # are inserted before the candidate node. 147 new_input_specs.append(old_input_spec_copy) 148 149 exported_program.graph_signature.input_specs = new_input_specs 150 exported_program.graph_module.recompile() 151 exported_program._validate() 152 return copied_nodes 153