xref: /aosp_15_r20/external/pytorch/torch/jit/_ir_utils.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 # mypy: allow-untyped-defs
2 from typing import Union
3 
4 import torch
5 
6 
class _InsertPoint:
    """Context manager that temporarily redirects a graph's insert point.

    Entering the context records the graph's current insert point and then
    switches the graph to the requested one; leaving the context switches
    the graph back to the recorded point.
    """

    def __init__(
        self,
        insert_point_graph: torch._C.Graph,
        insert_point: Union[torch._C.Node, torch._C.Block],
    ):
        """Store the graph and the insert point to apply on entry.

        Args:
            insert_point_graph: Graph whose insert point is switched.
            insert_point: Node or block installed as the insert point while
                the context is active.
        """
        self.g = insert_point_graph
        self.insert_point = insert_point
        # Initialised to None and not otherwise touched by this class.
        self.guard = None

    def __enter__(self):
        """Remember the current insert point, then install the new one."""
        graph = self.g
        # Capture the old point *before* overwriting it so __exit__ can
        # put it back.
        self.prev_insert_point = graph.insertPoint()
        graph.setInsertPoint(self.insert_point)

    def __exit__(self, *exc_info):
        """Put back the insert point recorded by ``__enter__``.

        Nothing is returned, so an exception raised inside the ``with``
        body is not suppressed.
        """
        previous = self.prev_insert_point
        self.g.setInsertPoint(previous)
23 
24 
def insert_point_guard(self, insert_point: Union[torch._C.Node, torch._C.Block]):
    """Build a context manager that temporarily changes an insert point.

    Args:
        self: Graph whose insert point will be switched. NOTE(review): the
            name suggests this function is attached to a graph class as a
            method elsewhere -- confirm at the binding site.
        insert_point: Node or block to use as the insert point while the
            returned context manager is active.

    Returns:
        An ``_InsertPoint`` wrapping ``self`` and ``insert_point``.
    """
    guard = _InsertPoint(insert_point_graph=self, insert_point=insert_point)
    return guard
27