# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import operator
from typing import Optional

import torch
from executorch.exir import ExportedProgram
from torch._export.utils import (
    get_buffer,
    get_lifted_tensor_constant,
    get_param,
    is_buffer,
    is_lifted_tensor_constant,
    is_param,
)


def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` is a ``torch.fx.Node`` whose op is ``"get_attr"``.

    Only the node kind is checked; the type of the fetched attribute is not
    inspected. Non-``Node`` inputs (e.g. ``None``) yield False.
    """
    return isinstance(node, torch.fx.Node) and node.op == "get_attr"


def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Return True if ``node`` refers to constant model state.

    That is the case when ``node`` is a ``get_attr`` node, or when the
    ``torch._export.utils`` helpers classify it as a parameter, a buffer or a
    lifted tensor constant of ``exp_prog``.
    """
    return (
        is_get_attr_node(node)
        or is_param(exp_prog, node)
        or is_buffer(exp_prog, node)
        or is_lifted_tensor_constant(exp_prog, node)
    )


def get_param_tensor(
    exp_prog: ExportedProgram, node: Optional[torch.fx.Node]
) -> Optional[torch.Tensor]:
    """
    Return the constant value backing ``node``.

    Args:
        exp_prog: Exported program whose parameters, buffers and lifted
            constants are searched.
        node: Graph node to resolve. ``None`` is accepted and yields ``None``.

    Returns:
        The parameter, buffer, lifted constant or ``get_attr`` target
        referred to by ``node``, or ``None`` if ``node`` is ``None``.

    Raises:
        RuntimeError: If ``node`` is none of the supported kinds.
        AttributeError: If a ``get_attr`` target cannot be resolved on either
            the node's owning module or ``exp_prog.graph_module``.
    """
    if node is None:
        return None
    elif is_param(exp_prog, node):
        return get_param(exp_prog, node)
    elif is_buffer(exp_prog, node):
        return get_buffer(exp_prog, node)
    elif is_lifted_tensor_constant(exp_prog, node):
        return get_lifted_tensor_constant(exp_prog, node)
    elif is_get_attr_node(node):
        # get_attr targets may be fully-qualified dotted paths
        # (e.g. "submodule.weight"); attrgetter walks each component, whereas
        # plain getattr would look up the literal dotted name and fail.
        fetch = operator.attrgetter(node.target)
        # This is a hack to support both lifted and unlifted graph: try the
        # module owning the node's graph first, then fall back to the
        # exported program's graph module.
        try:
            return fetch(node.graph.owning_module)
        except AttributeError:
            return fetch(exp_prog.graph_module)
    raise RuntimeError(f"unsupported param type, {node.op}.")