# mypy: allow-untyped-defs
from __future__ import annotations

import itertools
import logging
import weakref
from typing import Any, List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code
from torch._functorch.aot_autograd import MutationType
from torch._functorch.compile_utils import fx_graph_cse
from torch._inductor.constant_folding import constant_fold, replace_node_with_constant
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
from torch._inductor.fx_passes.post_grad import view_to_reshape

from . import config


aten = torch.ops.aten
prims = torch.ops.prims

log = logging.getLogger(__name__)


def replace_params_with_constants(
    gm: torch.fx.GraphModule,
    flat_params: list[Any],
    fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
    """
    Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
    Returns a list of indices representing the input parameters that were not converted to constants.
    """
    params = gm.graph.find_nodes(op="placeholder")
    # parameters are lifted as the leading placeholders; the remaining
    # placeholders are regular (non-param) runtime inputs
    fake_inp_nodes = params[: len(flat_params)]
    preserved_arg_indices = []
    aliased_input_args = [
        out_info.base_idx
        for out_info in fw_metadata.output_info
        if out_info.base_idx is not None
    ]

    # TODO (tmanlaibaatar) figure out why this is different
    # from mutated_inp_runtime_indices
    mutated_inps = [
        i
        for i, m in enumerate(fw_metadata.input_info)
        if m.mutation_type
        in (MutationType.MUTATED_IN_GRAPH, MutationType.MUTATED_OUT_GRAPH)
    ]

    for i, (real_input, node) in enumerate(zip(flat_params, fake_inp_nodes)):
        if i in mutated_inps or i in aliased_input_args:
            preserved_arg_indices.append(i)
            continue
        replace_node_with_constant(gm, node, real_input)
    # add on non-param inputs
    preserved_arg_indices.extend(range(len(flat_params), len(params)))
    # is this recompile necessary?
    gm.recompile()
    return preserved_arg_indices
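

# Usage sketch (illustrative only, not executed here): for a module with
# parameters (weight, bias) and a single activation input x, the aot_autograd
# graph has placeholders [weight, bias, x]. If neither parameter is mutated in
# the graph or aliased by an output, both are folded into graph constants and
# the returned indices are [2], i.e. only x remains a runtime input:
#
#   preserved = replace_params_with_constants(gm, flat_params, fw_metadata)
#   runtime_inputs = [flat_args[i] for i in preserved]  # `flat_args` is hypothetical caller code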


def freeze(
    dynamo_gm: torch.fx.GraphModule,
    aot_autograd_gm: torch.fx.GraphModule,
    example_inputs: List[torch._subclasses.FakeTensor],
) -> Tuple[torch.fx.GraphModule, List[int]]:
    """
    Inlines parameters that are not mutated into constants and optimizes the graph
    through constant propagation and other techniques. If enabled, also discards
    the original parameters of the module for memory efficiency.

    Assumes that this function is run during Dynamo tracing, after aot_autograd.

    Args:
        dynamo_gm (torch.fx.GraphModule): The Dynamo constructed GraphModule.
        aot_autograd_gm (torch.fx.GraphModule): The aot_autograd constructed GraphModule to be frozen.
        example_inputs (List[torch._subclasses.FakeTensor]): A list of example inputs to be used in the freezing process.

    Returns:
        Tuple[torch.fx.GraphModule, List[int]]: A tuple containing the frozen GraphModule and a list of indices
        of the inputs that were preserved (not turned into constants).
    """
    # Conv weights may have been converted to channels last, which can make .view
    # fail during fake_tensor_prop, so convert view ops to reshape first.
    # See the details in fx_codegen_and_compile of compile_fx.py.
    view_to_reshape(aot_autograd_gm)

    if tracing_context := torch._guards.TracingContext.try_get():
        fw_metadata = tracing_context.fw_metadata
        params_flat = tracing_context.params_flat
        assert fw_metadata is not None and params_flat is not None

        preserved_arg_indices = replace_params_with_constants(
            aot_autograd_gm, params_flat, fw_metadata
        )
    else:
        inputs = aot_autograd_gm.graph.find_nodes(op="placeholder")
        preserved_arg_indices = list(range(len(inputs)))

    # TODO - further restrict cse? right now needed to dedup aliasing ops
    cse_graph = fx_graph_cse(aot_autograd_gm.graph)
    aot_autograd_gm.graph = cse_graph
    aot_autograd_gm.recompile()

    aot_example_inputs = [example_inputs[ind] for ind in preserved_arg_indices]
    freezing_passes(aot_autograd_gm, aot_example_inputs)

    constant_fold(aot_autograd_gm)
    # invalidate eager nn.Modules so stale parameter references fail loudly
    if config.freezing_discard_parameters:
        invalidate_eager_modules()
        discard_traced_gm_params(dynamo_gm)

    log.debug(
        "%s", lazy_format_graph_code("FROZEN GRAPH", aot_autograd_gm, colored=True)
    )

    return aot_autograd_gm, preserved_arg_indices


class ErasedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, name, owning_mod):
        return super().__new__(cls, elem.to(device="meta"))

    def __init__(self, elem, name: Optional[str], mod) -> None:
        self.erased_name = name
        self.owning_mod_ref = weakref.ref(mod)

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        erased_tensors = [
            e
            for e in pytree.arg_tree_leaves(*args, **kwargs)
            if isinstance(e, ErasedTensor)
        ]
        assert len(erased_tensors) > 0
        e = erased_tensors[0]

        raise RuntimeError(
            "Trying to run PyTorch eager module after Dynamo freezing. "
            "The original parameters have been discarded for memory efficiency. "
            f"Found in op {func} for erased parameter {e.erased_name} of {e.owning_mod_ref()}"
        )


def invalidate_eager_modules():
    with torch.utils._python_dispatch._disable_current_modes():
        for (
            mod
        ) in torch._guards.TracingContext.get().module_context.nn_modules.values():
            if not isinstance(mod, torch.nn.Module):
                continue

            for attr_name, tensor in list(
                itertools.chain(
                    mod.named_parameters(recurse=False),
                    mod.named_buffers(recurse=False),
                )
            ):
                with torch._dispatch.python.no_python_dispatcher():
                    e_t = ErasedTensor(tensor, attr_name, mod)
                if isinstance(tensor, torch.nn.Parameter):
                    e_t.requires_grad_(True)
                    e_t._is_param = True  # type: ignore[attr-defined]
                setattr(mod, attr_name, e_t)


def discard_traced_gm_params(mod: torch.fx.GraphModule):
    with torch.utils._python_dispatch._disable_current_modes():
        for attr_name, tensor in list(
            itertools.chain(
                mod.named_parameters(recurse=False), mod.named_buffers(recurse=False)
            )
        ):
            with torch._dispatch.python.no_python_dispatcher():
                e_t = ErasedTensor(tensor, attr_name, mod)
            if isinstance(tensor, torch.nn.Parameter):
                e_t.requires_grad_(True)
                e_t._is_param = True  # type: ignore[attr-defined]
            setattr(mod, attr_name, e_t)
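

# Illustrative sketch (assumes an inference-only torch.compile run with inductor
# freezing enabled and config.freezing_discard_parameters set): after freeze()
# has run, the parameters live only as constants inside the frozen graph, and
# the eager module's attributes have been swapped for ErasedTensor, so any later
# eager call fails with the RuntimeError above, e.g.
#
#   compiled = torch.compile(mod)
#   with torch.no_grad():
#       compiled(x)  # compiles, freezes, and erases mod's parameters
#       mod(x)       # RuntimeError from ErasedTensor.__torch_dispatch__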


def enforce_output_layout(gm: torch.fx.GraphModule):
    """
    Make sure the output node's layout does not change due to compiler optimizations
    by adding prims.inductor_force_stride_order nodes with the expected strides.

    Only used for inference so we can assume all graph outputs are model outputs.
    """
    *_, output_node = gm.graph.nodes
    out_list = output_node.args[0]
    with gm.graph.inserting_before(output_node):
        for n in out_list:
            if not isinstance(
                n.meta["val"], torch.Tensor
            ) or not torch._prims_common.is_non_overlapping_and_dense(n.meta["val"]):
                continue

            # add a node to enforce eager layout
            ft = n.meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n, ft.stride())
            )

            # We cannot call n.replace_all_uses_with(new_node) here since that
            # would also replace the use of n inside new_node itself.
            output_node.replace_input_with(n, new_node)

    gm.graph.lint()
    gm.recompile()


def enforce_as_strided_input_layout(gm: torch.fx.GraphModule):
    """
    Make sure the layout of an as_strided node's input does not change due to
    compiler optimizations, because the strides recorded on the as_strided node
    depend on the input tensor's strides.
    """

    as_strided_ops = [
        torch.ops.aten.as_strided.default,
        torch.ops.aten.as_strided_.default,
        torch.ops.aten.as_strided_scatter.default,
    ]
    strided_nodes = [n for n in gm.graph.nodes if n.target in as_strided_ops]
    for n in strided_nodes:
        with gm.graph.inserting_before(n):
            # add a node to enforce eager layout
            ft = n.args[0].meta["val"]
            new_node = gm.graph.call_function(
                prims.inductor_force_stride_order.default, (n.args[0], ft.stride())
            )
            n.replace_input_with(n.args[0], new_node)

    gm.graph.lint()
    gm.recompile()


def convert_conv_weights_to_channels_last(gm: torch.fx.GraphModule):
    """
    Convert 4d convolution weight tensors to channels-last format.

    This pass is performed before freezing so the added nodes can be constant
    folded by freezing.
    """
    with dynamo_timed("convert_conv_weights_to_channels_last"):
        convs = [n for n in gm.graph.nodes if n.target == aten.convolution.default]
        for conv in convs:
            weight_node = conv.args[1]
            if len(weight_node.meta["val"].size()) != 4 or weight_node.meta[
                "val"
            ].is_contiguous(memory_format=torch.channels_last):
                # not a 4d tensor or already channels last, skip
                continue

            with gm.graph.inserting_before(conv):
                new_node = gm.graph.call_function(
                    aten.clone.default,
                    (weight_node,),
                    {"memory_format": torch.channels_last},
                )
                conv.replace_input_with(weight_node, new_node)

    enforce_as_strided_input_layout(gm)
    enforce_output_layout(gm)
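

# Worked example (sketch of the rewrite performed above): for an
# aten.convolution node whose 4d weight is contiguous, the pass turns
#
#   out = aten.convolution(x, weight, ...)
#
# into
#
#   weight_cl = aten.clone(weight, memory_format=torch.channels_last)
#   out = aten.convolution(x, weight_cl, ...)
#
# The clone is later constant-folded by freezing, so the frozen graph carries a
# channels-last weight constant, and the enforce_* passes pin strides with
# prims.inductor_force_stride_order so as_strided inputs and graph outputs keep
# their eager layouts.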