1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import copy
8import logging
9from typing import Optional, Set
10
11import torch
12from torch._export.utils import get_buffer, get_lifted_tensor_constant, get_param
13
14from torch.export import ExportedProgram
15from torch.export.exported_program import InputSpec, TensorArgument
16from torch.export.graph_signature import InputKind
17
18
19def _get_attribute_or_constants(
20    exported_program: ExportedProgram, node: torch.fx.Node
21) -> Optional[torch.Tensor]:
22    # get either attribute node or constant constant
23    maybe_param = get_param(exported_program, node)
24    maybe_buffer = get_buffer(exported_program, node)
25    maybe_lifted_tensor = get_lifted_tensor_constant(exported_program, node)
26
27    constant_or_attribute = None
28    if maybe_param is not None:
29        constant_or_attribute = maybe_param
30    elif maybe_buffer is not None:
31        constant_or_attribute = maybe_buffer
32    elif maybe_lifted_tensor is not None:
33        constant_or_attribute = maybe_lifted_tensor
34    return constant_or_attribute
35
36
37# TODO: add other passes to duplicate call_function nodes
def duplicate_constant_node(
    exported_program: ExportedProgram, candidate_node: str
) -> Set[str]:
    """
    A pass to duplicate an attribute/constant placeholder (the candidate_node) so
    that each of its users reads from its own placeholder.

    The first user keeps the original placeholder. For every additional user a new
    placeholder is inserted right before the original one, the user is rewired to
    it (including uses nested inside lists/tuples and keyword arguments), a deep
    copy of the tensor is stored in the exported program and a matching input spec
    is added to the graph signature. Mostly used for duplicating light-weight data.
    If the data is too large, try tagging it with "no_copy" to prevent high memory
    usage and make it part of the output.

    Args:
        exported_program: the exported program to be modified in place. If the
        constant node is copied, new placeholders are added, the graph signature
        is replaced and the state_dict / constants are updated
        candidate_node: the name of the constant placeholder to be duplicated

    Returns:
        The set of the names of the new constant nodes (empty when no placeholder
        is named candidate_node or when it has at most one user)

    Raises:
        RuntimeError: if the candidate placeholder is neither a parameter, a
        buffer nor a lifted tensor constant
    """
    graph = exported_program.graph
    # Snapshot the placeholders before mutating the graph. The signature's
    # input_specs are indexed by placeholder position, which is why the copies
    # (inserted before the candidate) get their specs appended before its own.
    placeholders = [node for node in graph.nodes if node.op == "placeholder"]
    target_node = next(
        (node for node in placeholders if node.name == candidate_node), None
    )
    if target_node is None:
        logging.info("no constant node to be copied")
        return set()

    old_input_specs = exported_program.graph_signature.input_specs
    new_input_specs = []
    copied_nodes: Set[str] = set()
    for idx, node in enumerate(placeholders):
        old_input_spec = old_input_specs[idx]
        if node is target_node:
            constant_or_attribute = _get_attribute_or_constants(exported_program, node)
            if constant_or_attribute is None:
                raise RuntimeError(
                    f"placeholder node for non-params, non-buffer, and non-tensor constants should not be tagged: {node} "
                )
            # Lifted tensor constants and non persistent buffers live in
            # .constants; everything else lives in the state_dict.
            stored_in_constants = old_input_spec.kind == InputKind.CONSTANT_TENSOR or (
                old_input_spec.kind == InputKind.BUFFER
                and old_input_spec.persistent is False
            )
            storage = (
                exported_program.constants
                if stored_in_constants
                else exported_program.state_dict
            )
            # Snapshot the users: rewiring them below mutates node.users. The
            # first user keeps the original placeholder.
            extra_users = list(node.users)[1:]
            for copy_idx, user in enumerate(extra_users):
                with graph.inserting_before(node):
                    copied_node = graph.placeholder(f"{node.name}_copy_{copy_idx}")
                # Use the name the graph actually assigned so that the signature
                # and the returned names match the node even if the requested
                # name had to be adjusted (e.g. to keep node names unique).
                copy_name = copied_node.name
                logging.info(
                    "Copying constant nodes %s and creating %s", node.name, copy_name
                )
                # Shallow copy of the metadata (includes "val").
                copied_node.meta.update(node.meta)
                # Redirects every occurrence of the original placeholder in the
                # user's args and kwargs, nested containers included.
                user.replace_input_with(node, copied_node)
                storage[copy_name] = copy.deepcopy(constant_or_attribute)
                # NOTE(review): the copied spec keeps the original target while the
                # duplicated tensor is stored under the copy's name - confirm which
                # key downstream consumers are expected to read.
                new_input_specs.append(
                    InputSpec(
                        kind=old_input_spec.kind,
                        arg=TensorArgument(name=copy_name),
                        target=old_input_spec.target,
                        persistent=old_input_spec.persistent,
                    )
                )
                copied_nodes.add(copy_name)
        # Ensure we add the original input spec to the last one, because all the copied nodes
        # are inserted before the candidate node.
        new_input_specs.append(copy.deepcopy(old_input_spec))

    exported_program.graph_signature.input_specs = new_input_specs
    exported_program.graph_module.recompile()
    exported_program._validate()
    return copied_nodes
153