xref: /aosp_15_r20/external/pytorch/torch/fx/traceback.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport traceback
3*da0073e9SAndroid Build Coastguard Workerfrom contextlib import contextmanager
4*da0073e9SAndroid Build Coastguard Workerfrom typing import List, Any, Dict
5*da0073e9SAndroid Build Coastguard Workerfrom ._compatibility import compatibility
6*da0073e9SAndroid Build Coastguard Worker
# Public API of this module (all of these names are defined in this file).
__all__ = [
    'preserve_node_meta',
    'has_preserved_node_meta',
    'set_stack_trace',
    'set_grad_fn_seq_nr',
    'reset_grad_fn_seq_nr',
    'format_stack',
    'set_current_meta',
    'get_current_meta',
]
10*da0073e9SAndroid Build Coastguard Worker
# Ambient node metadata. The set_* helpers in this module write into it while
# preservation is enabled; preserve_node_meta and set_current_meta rebind it
# and restore the previously bound dict on exit.
current_meta: Dict[str, Any] = {}
# Global switch turned on by preserve_node_meta(); the set_* helpers in this
# module only touch current_meta while it is True.
should_preserve_node_meta: bool = False
13*da0073e9SAndroid Build Coastguard Worker
14*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta(enable: bool = True):
    """Enable node-meta preservation for the duration of the ``with`` body.

    While active, ``should_preserve_node_meta`` is True, so the ``set_*``
    helpers and ``format_stack`` in this module read and write the ambient
    ``current_meta``. On exit, even if the body raises, both the flag and
    ``current_meta`` are restored to their values from entry.

    Args:
        enable: When False the context manager does nothing. This lets callers
            toggle preservation conditionally without switching to a
            different context manager. Defaults to True, which is the
            original unconditional behavior.
    """
    global should_preserve_node_meta
    global current_meta

    if not enable:
        # No-op: leave the flag and current_meta exactly as they are.
        yield
        return

    saved_should_preserve_node_meta = should_preserve_node_meta
    # A shallow copy is OK because fields of current_meta are never mutated in
    # place: list-valued entries ("grad_fn_seq_nr", "from_node") are always
    # rebuilt as new lists by the helpers in this module.
    saved_current_meta = current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = saved_should_preserve_node_meta
        current_meta = saved_current_meta
30*da0073e9SAndroid Build Coastguard Worker
31*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    """Record a stack trace in the ambient node meta.

    The entries of ``stack`` are concatenated verbatim and stored under
    ``current_meta["stack_trace"]``. Nothing is recorded when preservation is
    off or ``stack`` is empty.
    """
    if not should_preserve_node_meta or not stack:
        return
    # current_meta is mutated in place, never rebound here, so no ``global``
    # declaration is required.
    current_meta["stack_trace"] = "".join(stack)
38*da0073e9SAndroid Build Coastguard Worker
39*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """Push ``seq_nr`` onto the grad_fn sequence-number stack.

    Appends ``seq_nr`` to ``current_meta["grad_fn_seq_nr"]`` and increments
    the ``current_meta["in_grad_fn"]`` depth counter. ``reset_grad_fn_seq_nr``
    undoes one such call. Does nothing when preservation is off.
    """
    if not should_preserve_node_meta:
        return
    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # Build a new list instead of appending in place, so shallow copies of
    # current_meta saved elsewhere keep their own version of the stack.
    extended = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
    current_meta["grad_fn_seq_nr"] = extended
    current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1
48*da0073e9SAndroid Build Coastguard Worker
49*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """Pop one level pushed by ``set_grad_fn_seq_nr``.

    Decrements ``current_meta["in_grad_fn"]`` and drops the last entry of
    ``current_meta["grad_fn_seq_nr"]``. When the outermost level is popped,
    both keys are removed from ``current_meta``. Does nothing when
    preservation is off.
    """
    # NB: reset state properly, this would be helpful towards supporting
    #     reentrant autograd if we actually wanted to do that.
    if not should_preserve_node_meta:
        return
    depth = current_meta.get("in_grad_fn", 0)
    # Each reset must be paired with an earlier set_grad_fn_seq_nr call.
    assert depth > 0
    if depth == 1:
        # Leaving the outermost grad_fn: drop the bookkeeping keys entirely.
        del current_meta["in_grad_fn"]
        del current_meta["grad_fn_seq_nr"]
        return
    current_meta["in_grad_fn"] = depth - 1
    current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]
64*da0073e9SAndroid Build Coastguard Worker
65*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """Return the current stack trace as a list of strings.

    With preservation on, the result is a one-element list holding the
    pre-joined trace recorded by ``set_stack_trace`` (``""`` if none was
    recorded). Otherwise the live Python stack is formatted, excluding this
    function's own frame.
    """
    if not should_preserve_node_meta:
        # fallback to traceback.format_stack(); the last extracted entry is
        # this function's own frame, so it is dropped.
        caller_frames = traceback.extract_stack()[:-1]
        return traceback.format_list(caller_frames)
    return [current_meta.get("stack_trace", "")]
73*da0073e9SAndroid Build Coastguard Worker
74*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Report whether node-meta preservation is currently switched on.

    Returns the module-level ``should_preserve_node_meta`` flag, which
    ``preserve_node_meta`` sets to True for the duration of its body.
    """
    preserving = should_preserve_node_meta
    return preserving
78*da0073e9SAndroid Build Coastguard Worker
79*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node):
    """Temporarily make a copy of ``node.meta`` the ambient ``current_meta``.

    This only takes effect when preservation is on and ``node.meta`` is
    non-empty; otherwise the body runs with ``current_meta`` untouched.
    Inside the context, ``current_meta`` is a shallow copy of ``node.meta``
    whose ``"from_node"`` list ends with ``(node.name, node.target)``. The
    pair is not appended again if the last entry already names this node.
    ``node.meta`` itself is never modified. The previously bound
    ``current_meta`` is restored on exit, even if the body raises.

    Args:
        node: presumably a ``torch.fx.Node``; only its ``meta``, ``name``
            and ``target`` attributes are read.
    """
    global current_meta
    if not (should_preserve_node_meta and node.meta):
        yield
        return

    saved_meta = current_meta
    try:
        current_meta = node.meta.copy()

        # Append (node.name, node.target) onto "from_node" for provenance
        # tracking. A missing, None or empty "from_node" starts a new chain.
        # Previously an empty list raised IndexError on ``[-1]``.
        provenance = current_meta.get("from_node")
        if not provenance:
            current_meta["from_node"] = [(node.name, node.target)]
        elif provenance[-1][0] != node.name:
            # Build a new list rather than appending, so the list object that
            # the shallow copy still shares with node.meta is left untouched.
            current_meta["from_node"] = provenance + [(node.name, node.target)]

        yield
    finally:
        current_meta = saved_meta
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker
@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
    """Return the ambient node meta dict.

    The live module-level dict is returned, not a copy, so any mutation by
    the caller changes the shared state.
    """
    meta = current_meta
    return meta
105