# mypy: allow-untyped-defs
import traceback
from contextlib import contextmanager
from typing import List, Any, Dict
from ._compatibility import compatibility

__all__ = ['preserve_node_meta', 'has_preserved_node_meta',
           'set_stack_trace', 'set_grad_fn_seq_nr', 'reset_grad_fn_seq_nr',
           'format_stack', 'set_current_meta', 'get_current_meta']

# Module-global metadata dict that the helpers below read and update while
# preservation is enabled. set_current_meta() temporarily swaps it for a copy
# of a node's meta; preserve_node_meta() restores a snapshot of it on exit.
current_meta: Dict[str, Any] = {}
# Global switch toggled by preserve_node_meta(). While False, the setters in
# this module are no-ops and format_stack() falls back to the real Python stack.
should_preserve_node_meta = False


@compatibility(is_backward_compatible=False)
@contextmanager
def preserve_node_meta():
    """Enable node-metadata preservation for the duration of the ``with`` block.

    On exit (normal or exceptional) both the previous value of
    ``should_preserve_node_meta`` and a snapshot of ``current_meta`` taken on
    entry are restored, so nested uses restore the outer state.
    """
    global should_preserve_node_meta
    global current_meta

    saved_should_preserve_node_meta = should_preserve_node_meta
    # Shallow copy is OK since fields of current_meta are not mutated
    saved_current_meta = current_meta.copy()
    try:
        should_preserve_node_meta = True
        yield
    finally:
        should_preserve_node_meta = saved_should_preserve_node_meta
        current_meta = saved_current_meta


@compatibility(is_backward_compatible=False)
def set_stack_trace(stack: List[str]):
    """Record ``stack`` as the current stack trace.

    When preservation is enabled and ``stack`` is non-empty, its entries are
    concatenated (no separator is inserted) into ``current_meta["stack_trace"]``.
    Otherwise this is a no-op.
    """
    global current_meta

    if should_preserve_node_meta and stack:
        current_meta["stack_trace"] = "".join(stack)


@compatibility(is_backward_compatible=False)
def set_grad_fn_seq_nr(seq_nr):
    """Enter a grad_fn scope identified by ``seq_nr``.

    When preservation is enabled, pushes ``seq_nr`` onto
    ``current_meta["grad_fn_seq_nr"]`` and increments the nesting depth kept in
    ``current_meta["in_grad_fn"]``. Each call is expected to be paired with a
    later ``reset_grad_fn_seq_nr()``.
    """
    global current_meta

    if should_preserve_node_meta:
        # The seq_nr is captured by eager mode in the grad_fn during forward.
        # A new list is built (instead of appending in place) so that shallow
        # snapshots of current_meta, e.g. the one in preserve_node_meta(),
        # never observe the change.
        current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [seq_nr]
        current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1


@compatibility(is_backward_compatible=False)
def reset_grad_fn_seq_nr():
    """Leave the innermost grad_fn scope entered by ``set_grad_fn_seq_nr()``.

    When preservation is enabled, drops the last recorded sequence number and
    decrements the nesting depth; when leaving the outermost scope, both the
    ``"in_grad_fn"`` and ``"grad_fn_seq_nr"`` keys are removed entirely.

    Raises:
        AssertionError: if no grad_fn scope is open (only while Python
            assertions are enabled).
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    global current_meta
    if should_preserve_node_meta:
        current_level = current_meta.get("in_grad_fn", 0)
        assert current_level > 0, (
            "reset_grad_fn_seq_nr() called without a matching set_grad_fn_seq_nr()"
        )
        if current_level == 1:
            del current_meta["in_grad_fn"]
            del current_meta["grad_fn_seq_nr"]
        else:
            current_meta["in_grad_fn"] = current_level - 1
            # Slice builds a new list rather than popping the shared one.
            current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1]


@compatibility(is_backward_compatible=False)
def format_stack() -> List[str]:
    """Return the current stack trace as a list of strings.

    With preservation enabled, returns a one-element list holding the recorded
    ``current_meta["stack_trace"]`` (``""`` if none was recorded). Otherwise
    returns the caller's live Python stack formatted by ``traceback``, with this
    function's own frame dropped.
    """
    if should_preserve_node_meta:
        return [current_meta.get("stack_trace", "")]
    else:
        # fallback to traceback.format_stack(); [:-1] drops this frame
        return traceback.format_list(traceback.extract_stack()[:-1])


@compatibility(is_backward_compatible=False)
def has_preserved_node_meta() -> bool:
    """Return whether node-metadata preservation is currently enabled."""
    return should_preserve_node_meta


@compatibility(is_backward_compatible=False)
@contextmanager
def set_current_meta(node):
    """Make a copy of ``node.meta`` the current metadata for the ``with`` block.

    Only active when preservation is enabled and ``node.meta`` is non-empty;
    otherwise the block runs with ``current_meta`` untouched. The copy gets a
    ``(node.name, node.target)`` entry appended to ``"from_node"`` (unless the
    last entry already names this node), and the previous ``current_meta`` is
    restored on exit, even if the block raises.

    NOTE(review): assumes ``node`` exposes ``meta`` (a dict), ``name`` and
    ``target`` -- presumably a ``torch.fx.Node``; confirm against callers.
    """
    global current_meta
    if should_preserve_node_meta and node.meta:
        saved_meta = current_meta
        try:
            current_meta = node.meta.copy()

            # Append (node.name, node.target) onto "from_node" for provenance
            # tracking. A missing, empty or None "from_node" starts a fresh
            # list; indexing [-1] on an empty list would raise IndexError.
            # A new list is built because the copy above is shallow and the
            # existing list is still shared with node.meta.
            provenance = current_meta.get("from_node")
            if not provenance:
                current_meta["from_node"] = [(node.name, node.target)]
            elif provenance[-1][0] != node.name:
                current_meta["from_node"] = provenance + [(node.name, node.target)]

            yield
        finally:
            current_meta = saved_meta
    else:
        yield


@compatibility(is_backward_compatible=False)
def get_current_meta() -> Dict[str, Any]:
    """Return the live ``current_meta`` dict (not a copy).

    Mutating the returned dict mutates this module's current metadata.
    """
    return current_meta