xref: /aosp_15_r20/external/executorch/backends/transforms/utils.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7from typing import Optional
8
9import torch
10from executorch.exir import ExportedProgram
11
12from torch._export.utils import (
13    get_buffer,
14    get_lifted_tensor_constant,
15    get_param,
16    is_buffer,
17    is_lifted_tensor_constant,
18    is_param,
19)
20
21
def is_get_attr_node(node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` is an FX node whose opcode is ``"get_attr"``.

    Any object that is not a ``torch.fx.Node`` (including ``None``) yields
    ``False`` rather than raising.
    """
    if not isinstance(node, torch.fx.Node):
        return False
    return node.op == "get_attr"
27
28
def is_param_node(exp_prog: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Tell whether ``node`` refers to a parameter-like value of the program.

    A node qualifies when it is a ``get_attr`` node, or when one of
    ``is_param``, ``is_buffer`` or ``is_lifted_tensor_constant`` reports true
    for it against ``exp_prog``. The checks run in that order and stop at the
    first match.
    """
    # Unlifted graphs reference their tensors through get_attr nodes.
    if is_get_attr_node(node):
        return True

    # Otherwise ask each exported-program predicate in turn.
    for predicate in (is_param, is_buffer, is_lifted_tensor_constant):
        if predicate(exp_prog, node):
            return True
    return False
36
37
38def get_param_tensor(
39    exp_prog: ExportedProgram, node: torch.fx.Node
40) -> Optional[torch.Tensor]:
41    if node is None:
42        return None
43    elif is_param(exp_prog, node):
44        return get_param(exp_prog, node)
45    elif is_buffer(exp_prog, node):
46        return get_buffer(exp_prog, node)
47    elif is_lifted_tensor_constant(exp_prog, node):
48        return get_lifted_tensor_constant(exp_prog, node)
49    elif is_get_attr_node(node):
50        # This is a hack to support both lifted and unlifted graph
51        try:
52            return getattr(node.graph.owning_module, node.target)
53        except AttributeError:
54            return getattr(exp_prog.graph_module, node.target)
55    raise RuntimeError(f"unsupported param type, {node.op}.")
56