# mypy: allow-untyped-defs
from __future__ import annotations

import itertools
import logging
import weakref
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape

from . import config


aten = torch.ops.aten
prims = torch.ops.prims

log = logging.getLogger(__name__)


def replace_params_with_constants(
    gm: torch.fx.GraphModule,
    flat_params: List[Any],
    fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
    """
    Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
    Returns a list of indices representing the input parameters that were not converted to constants.
    """
    params = gm.graph.find_nodes(op="placeholder")
    fake_inp_nodes = params[: len(flat_params)]
    preserved_arg_indices = []
    aliased_input_args = [
        out_info.base_idx
        for out_info in fw_metadata.output_info
        if out_info.base_idx is not None
    ]

    # TODO (tmanlaibaatar) figure out why this is different
    # from mutated_inp_runtime_indices
    mutated_inps = [
        i
        for i, m in enumerate(fw_metadata.input_info)
        if m.mutation_type
        in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
    ]

    for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
        if i in mutated_inps or i in aliased_input_args:
            # mutated or aliased parameters cannot be baked in as constants
            preserved_arg_indices.append(i)
            continue
        replace_node_with_constant(gm, node, real_input)
    # add on non-parameter inputs
    preserved_arg_indices.extend(range(len(flat_params), len(params)))
    # is this necessary?
    gm.recompile()
    return preserved_arg_indices


def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Inlines parameters that are not mutated into constants and optimizes the graph
    through constant propagation and other techniques. If enabled, also discards
    the original parameters of the module for memory efficiency.

    Assumes that this function is run during Dynamo tracing, after aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo-constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd-constructed GraphModule to be frozen.
        example_inputs (List[torch.Tensor]): A list of example input tensors to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # Conv weights may have been converted to channels-last, which can make .view
    # calls fail during fake_tensor_prop, so we need to convert view to reshape
    # first. See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    if tracing_context := torch._guards.TracingContext.try_get():
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
        inputs = aot_autograd_gm.graph.find_nodes(op="placeholder")
        preserved_arg_indices = list(range(len(inputs)))

    # TODO - further restrict cse? right now needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, aot_example_inputs)

    constant_fold(aot_autograd_gm)
    # invalidate nn Modules
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    log.debug(
        "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True)
    )

    return aot_autograd_gm, preserved_arg_indices


class ErasedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod) -> None:
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        erased_tensors = [
            e
            for e in pytree.arg_tree_leaves(*args, **kwargs)
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            "Trying to run PyTorch eager module after Dynamo freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )
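

# A hedged illustration of the failure mode (hypothetical session): once a
# module's parameters have been discarded, any eager call on the original
# module raises, e.g.
#
#     frozen = torch.compile(mod)  # with freezing + freezing_discard_parameters
#     frozen(x)                    # fine: runs the frozen graph
#     mod(x)                       # RuntimeError: Trying to run PyTorch eager
#                                  # module after Dynamo freezing...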


def invalidate_eager_modules():
    with torch.utils._python_dispatch._disable_current_modes():
        tracing_context = torch._guards.TracingContext.get()
        for mod in tracing_context.module_context.nn_modules.values():
            if not isinstance(mod, torch.nn.Module):
                continue

            # swap each param/buffer for an ErasedTensor so that any later eager
            # use of the module fails loudly in ErasedTensor.__torch_dispatch__
            for attr_name, tensor in list(
                itertools.chain(
                    mod.named_parameters(recurse=False),
                    mod.named_buffers(recurse=False),
                )
            ):
                with torch._dispatch.python.no_python_dispatcher():
                    e_t = ErasedTensor(tensor, attr_name, mod)
                if isinstance(tensor, torch.nn.Parameter):
                    e_t.requires_grad_(True)
                    e_t._is_param = True  # type: ignore[attr-defined]
                setattr(mod, attr_name, e_t)


def discard_traced_gm_params(mod: torch.fx.GraphModule):
    with torch.utils._python_dispatch._disable_current_modes():
        # same erasure as invalidate_eager_modules, applied to the Dynamo-traced
        # GraphModule's own parameters and buffers
        for attr_name, tensor in list(
            itertools.chain(
                mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
            )
        ):
            with torch._dispatch.python.no_python_dispatcher():
                e_t = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                e_t.requires_grad_(True)
                e_t._is_param = True  # type: ignore[attr-defined]
            setattr(mod, attr_name, e_t)


def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Make sure the output nodes' layouts do not change due to compiler optimizations
    by inserting prims.inductor_force_stride_order nodes with the expected strides.

    Only used for inference so we can assume all graph outputs are model outputs.
    """
    *_, output_node = gm.graph.nodes
    out_list = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        for n in out_list:
            if not isinstance(
                n.meta["val"], torch.Tensor
            ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]):
                continue

            # add a node to enforce eager layout
            ft = n.meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n, ft.stride())
            )

            # cannot call n.replace_all_uses_with(new_node),
            # since that would also replace the usage of n in new_node itself.
            output_node.replace_input_with(n, new_node)

    gm.graph.lint()
    gm.recompile()


def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Make sure the layout of an as_strided node's input does not change due to
    compiler optimizations, because the as_strided stride info depends on the
    input tensor's stride info.
    """

    as_strided_ops = [
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    ]
    strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
    for n in strided_nodes:
        with gm.graph.inserting_before(n):
            # add a node to enforce eager layout
            ft = n.args[0].meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
            )
            n.replace_input_with(n.args[0], new_node)

    gm.graph.lint()
    gm.recompile()


def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Convert 4d convolution weight tensors to channels-last format.

    This pass is performed before freezing so that the added nodes can be
    constant-folded by freezing.
    """
    with dynamo_timed("convert_conv_weights_to_channels_last"):
        convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
        for conv in convs:
            weight_node = conv.args[1]
            if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
                "val"
            ].is_contiguous(memory_format=torch.channels_last):
                # not a 4d tensor or already channels last, skip
                continue

            with gm.graph.inserting_before(conv):
                new_node = gm.graph.call_function(
                    aten.clone.default,
                    (weight_node,),
                    {"memory_format": torch.channels_last},
                )
                conv.replace_input_with(weight_node, new_node)

        enforce_as_strided_input_layout(gm)
        enforce_output_layout(gm)
270