# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

import torch
from executorch.exir import ExportedProgram

from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Returns true if the given node is a get attr node for a tensor of the model
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` refers to constant model state.

    A node qualifies if it is a ``get_attr`` node, or if ``exp_prog`` classifies
    it as a parameter, a buffer, or a lifted tensor constant.
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def _get_attr_by_qualified_name(root: object, qualified_name: str) -> object:
    """
    Resolve a possibly dotted attribute path (e.g. ``"sub.weight"``) on ``root``.

    ``get_attr`` node targets are fully-qualified names in the module hierarchy,
    so a plain ``getattr(root, "sub.weight")`` would fail for nested attributes.
    Raises AttributeError if any path component is missing (same as getattr).
    """
    obj = root
    for atom in qualified_name.split("."):
        obj = getattr(obj, atom)
    return obj


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Return the tensor backing a parameter-like ``node`` of ``exp_prog``.

    Args:
        exp_prog: The exported program that owns the state.
        node: A parameter, buffer, lifted tensor constant, or ``get_attr`` node.
            ``None`` is accepted and yields ``None``.

    Returns:
        The backing tensor, or ``None`` if ``node`` is ``None``.

    Raises:
        RuntimeError: If ``node`` is not one of the supported kinds.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # This is a hack to support both lifted and unlifted graph: look the
        # attribute up on the node's own owning module first, then fall back
        # to the exported program's top-level graph module.
        try:
            return _get_attr_by_qualified_name(
                node.graph.owning_module, node.target
            )
        except AttributeError:
            return _get_attr_by_qualified_name(exp_prog.graph_module, node.target)
    raise RuntimeError(f"unsupported param type, {node.op}.")