xref: /aosp_15_r20/external/pytorch/torch/_lazy/closure.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# mypy: allow-untyped-defs
2import os
3import threading
4from queue import Empty as EmptyQueue, Queue
5
6from torch._lazy.device_context import get_device_context
7
8
class ClosureHandler:
    """Synchronous step-closure handler: executes closures inline, in order."""

    def __init__(self) -> None:
        pass

    def run(self, closure):
        """Execute a single closure.

        Args:
        closure: callable function to run
        """
        closure()

    def __call__(self, closures):
        """Run each closure in *closures* sequentially via :meth:`run`."""
        for item in closures:
            self.run(item)
24
25
class AsyncClosureHandler(ClosureHandler):
    """Handler for Asynchronous Step Closures

    Closures are queued and executed on a background thread. The worker
    thread exits after the queue has stayed empty for ~3 seconds and is
    restarted lazily on the next `run` call.

    Args:
        max_queue_size: The maximum length of the closure queue after which
        the training loop will block until closures are evaluated.
        By default, a reasonable limit of a maximum of 100 on the queue.
        This value can be set using the `LTC_MAX_ASYNC_QUEUE` environment
        variable.
    """

    def __init__(self, max_queue_size=100):
        super().__init__()
        # Bounded queue of pending closures; `put` blocks when full, which
        # throttles the producer (training loop) to closure-evaluation speed.
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        # Carries the first exception raised by a closure from the worker
        # thread back to the caller's thread (surfaced in `run`).
        self._closure_exception: Queue = Queue()
        self._closure_lock = threading.Lock()
        self._closure_event_loop_finished = threading.Event()
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start closure event loop if not started"""
        if self._closure_event_loop is None:

            def event_loop():
                # Run loop until closure event is set and closure queue is empty
                while True:
                    try:
                        closure = self._closure_queue.get(block=True, timeout=3)
                    except EmptyQueue:
                        # Queue idle for 3s: exit if still empty; `run` will
                        # restart the loop when new work arrives.
                        with self._closure_lock:
                            if self._closure_queue.empty():
                                self._closure_event_loop_finished.set()
                                return
                        continue
                    try:
                        closure()
                    except Exception as e:
                        # Record the failure and stop the worker; `run`
                        # re-raises this as a RuntimeError on the caller.
                        self._closure_exception.put(e)
                        return
                    finally:
                        # Mark the item done even on failure so a
                        # Queue.join() on this queue cannot deadlock.
                        self._closure_queue.task_done()

            self._closure_event_loop = threading.Thread(target=event_loop)
            self._closure_event_loop.start()

    def run(self, closure):
        """Queue *closure* for asynchronous execution.

        Blocks while the queue is full. Raises RuntimeError (chained to the
        original exception) if a previously queued closure failed; otherwise
        (re)starts the worker thread if it is not running.
        """
        with self._closure_lock:
            self._closure_queue.put(closure, block=True)
            if (
                self._closure_event_loop is None
                or not self._closure_event_loop.is_alive()
            ):
                try:
                    e = self._closure_exception.get(block=False)
                    raise RuntimeError(
                        "Cannot run asynchronous closure due to previously raised exception"
                    ) from e
                except EmptyQueue:
                    self._closure_event_loop = None
                    self.start_event_loop()
84
85
def add_step_closure(closure, args=(), run_async=False):
    """Adds a closure to the list of the ones to be run at the end of the step.
    Many times during model training there is the need to print/report (print to
    console, post to tensorboard, etc...) information which require the content of
    intermediary tensors to be inspected.
    Inspecting different tensors content in different points of the model code
    requires many executions and typically causes performance issues.
    Adding a step closure will ensure that it will be run after the barrier, when
    all the live tensors will be already materialized to device data.
    Live tensors which will include the ones captured by the closure arguments.
    So using `add_step_closure()` will ensure a single execution will be
    performed, even when multiple closures are queued, requiring multiple tensors
    to be inspected.
    Step closures will be run sequentially in the order they have been queued.
    Note that even though using this API the execution will be optimized, it is
    advised to throttle the printing/reporting events once every N steps.
    Args:
      closure (callable): The function to be called.
      args (tuple): The arguments to be passed to the closure.
      run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    # Async and sync closures live on separate device-context attributes.
    attr_name = "async_step_closures" if run_async else "step_closures"
    pending = getattr(devctx, attr_name, None)
    if pending is None:
        pending = []
        setattr(devctx, attr_name, pending)
    # Bind `args` as a default argument so the lambda captures the value of
    # `args` at queue time rather than at call time.
    pending.append(lambda a=args: closure(*a))
114
115
def run_step_closures():
    """Run all queued step closures (async first, then sync) and return the
    device context. Each handler is created once and cached on the context."""
    devctx = get_device_context()

    def _drain(list_attr, handler_attr, handler_factory):
        # Take ownership of the pending list, reset the context slot, and
        # dispatch the closures through the (lazily created) handler.
        pending = getattr(devctx, list_attr, None)
        if pending is None:
            return
        setattr(devctx, list_attr, [])
        handler = getattr(devctx, handler_attr, None)
        if handler is None:
            handler = handler_factory()
            setattr(devctx, handler_attr, handler)
        handler(pending)

    _drain("async_step_closures", "async_closure_handler", AsyncClosureHandler)
    _drain("step_closures", "closure_handler", ClosureHandler)
    return devctx
136