xref: /aosp_15_r20/external/pytorch/torch/fx/experimental/unify_refinements.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2from torch.fx.experimental.graph_gradual_typechecker import Refine
3from torch.fx.tensor_type import TensorType
4from torch.fx.experimental.unification import Var, unify  # type: ignore[attr-defined]
5
6
def infer_symbolic_types_single_pass(traced):
    """
    Run one round of symbolic type inference on ``traced``.

    Equality constraints are gathered with ``Refine``, solved by
    unification, and the resulting solution is written back onto the
    node types of ``traced.graph``.
    """
    refiner = Refine(traced)
    refiner.refine()
    solution = unify_eq(refiner.constraints)
    substitute_all_types(traced.graph, solution)
15
def infer_symbolic_types(traced):
    """
    Run symbolic type inference on ``traced`` twice, then compute
    symbolic relations.

    One pass is not always enough to infer all the information, for
    example for broadcasting. The second pass starts from the node types
    that the first pass wrote back into ``traced.graph``.
    """
    for _ in range(2):
        # A fresh Refine each round so it sees the types substituted last round.
        refiner = Refine(traced)
        refiner.refine()
        solution = unify_eq(refiner.constraints)
        substitute_all_types(traced.graph, solution)

    # Symbolic relations are derived from the refiner of the final pass.
    refiner.symbolic_relations()
34
def convert_eq(list_of_eq):
    """
    Split equality constraints into the argument format used by ``unify``.

    Each element of ``list_of_eq`` must expose ``lhs`` and ``rhs``
    attributes. Returns a pair ``(lhs_tuple, rhs_tuple)`` in which the
    i-th entries of the two tuples are the two sides of the i-th
    constraint.
    """
    # Materialize once so a one-shot iterable is consumed only a single time.
    constraints = list(list_of_eq)
    return (
        tuple(c.lhs for c in constraints),
        tuple(c.rhs for c in constraints),
    )
46
47
def unify_eq(list_of_eq):
    """
    Solve a collection of equality constraints by unification.

    The constraints are split into left- and right-hand-side tuples by
    ``convert_eq`` and passed to ``unify``; its result is returned as-is
    (callers in this module treat it as the most general unifier mapping).
    """
    return unify(*convert_eq(list_of_eq))
55
56
def substitute_solution_one_type(mapping, t):
    """
    Apply the most general unifier ``mapping`` to a single type ``t``.

    - A ``Var`` is replaced by its binding in ``mapping``, if any.
    - A ``TensorType`` gets each entry of ``__args__`` replaced by its
      binding, if any (one level only, no recursion into the entries).
    - A list or tuple is rewritten element-wise by recursing, and the
      result keeps the container kind (plain ``list`` / plain ``tuple``).
    - Anything else is returned unchanged.
    """
    if isinstance(t, Var):
        return mapping.get(t, t)

    if isinstance(t, TensorType):
        return TensorType(tuple(mapping.get(dim, dim) for dim in t.__args__))

    if isinstance(t, (list, tuple)):
        substituted = [substitute_solution_one_type(mapping, item) for item in t]
        return substituted if isinstance(t, list) else tuple(substituted)

    return t
90
91
def substitute_all_types(graph, mapping):
    """
    Apply the most general unifier ``mapping`` to every node type in ``graph``.

    ``mapping`` is first normalized in place. Whenever a key is bound to a
    value that is itself a key of ``mapping``, the binding is replaced by
    that key's binding. This repeats until a full sweep changes nothing
    (a fixed point over the mapping), which collapses chains such as
    ``a -> b -> 3`` into ``a -> 3``. Each node's ``type`` is then rewritten
    with ``substitute_solution_one_type``.
    """
    changed = True
    while changed:
        changed = False
        for key in mapping:
            before = mapping[key]
            # Follow one link of a binding chain.
            if before in mapping:
                mapping[key] = mapping[before]
            if before != mapping[key]:
                changed = True

    for node in graph.nodes:
        node.type = substitute_solution_one_type(mapping, node.type)
111
def check_for_type_equality(g1, g2):
    """
    Return True if ``g1`` and ``g2`` have the same node types, position by
    position.

    Intended as the equality check of a fixed-point iteration: only the
    ``type`` attribute of corresponding nodes is compared, not full graph
    equality. Graphs with a different number of nodes are never equal.
    """
    nodes1 = list(g1.nodes)
    nodes2 = list(g2.nodes)
    # zip() alone stops at the shorter graph, so a graph whose node types
    # are a prefix of the other's would wrongly compare equal.
    if len(nodes1) != len(nodes2):
        return False
    return not any(n.type != m.type for n, m in zip(nodes1, nodes2))
122