1 # mypy: allow-untyped-defs 2 from typing import Union 3 4 import torch 5 6 7 class _InsertPoint: 8 def __init__( 9 self, 10 insert_point_graph: torch._C.Graph, 11 insert_point: Union[torch._C.Node, torch._C.Block], 12 ): 13 self.insert_point = insert_point 14 self.g = insert_point_graph 15 self.guard = None 16 17 def __enter__(self): 18 self.prev_insert_point = self.g.insertPoint() 19 self.g.setInsertPoint(self.insert_point) 20 21 def __exit__(self, *args): 22 self.g.setInsertPoint(self.prev_insert_point) 23 24 25 def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]): 26 return _InsertPoint(self, insert_point) 27