# mypy: allow-untyped-defs
import os
import threading
from queue import Empty as EmptyQueue, Full as FullQueue, Queue

from torch._lazy.device_context import get_device_context


class ClosureHandler:
    """Runs step closures synchronously, one after another, on the calling thread."""

    def __init__(self) -> None:
        pass

    def run(self, closure):
        """Run closure function

        Args:
            closure: callable function to run
        """
        closure()

    def __call__(self, closures):
        """Run every closure in ``closures`` in iteration order via :meth:`run`."""
        for closure in closures:
            self.run(closure)


class AsyncClosureHandler(ClosureHandler):
    """Handler for asynchronous step closures.

    Closures are pushed onto a bounded queue and executed in FIFO order by a
    background thread (the "event loop"). The thread exits once no closure has
    arrived for ``idle_timeout`` seconds, and is restarted on demand by the
    next :meth:`run` call.

    If a closure raises, the event loop thread stops. The next call to
    :meth:`run` raises ``RuntimeError`` chained from the original exception,
    and the closure passed to that call is not enqueued. Later calls restart
    the loop, which then resumes processing any closures still in the queue.

    Args:
        max_queue_size: The maximum length of the closure queue after which
            the training loop will block until closures are evaluated.
            Defaults to 100. When the ``LTC_MAX_ASYNC_QUEUE`` environment
            variable is set, it takes precedence over this argument.
        idle_timeout: Seconds the event loop thread waits for a new closure
            before shutting down. Defaults to 3.
    """

    # Seconds a blocked ``put`` waits before re-checking that the event loop
    # is still alive. This prevents blocking forever (while holding the lock)
    # when the consumer thread dies with the queue full.
    _PUT_POLL_INTERVAL = 0.1

    def __init__(self, max_queue_size=100, idle_timeout=3):
        super().__init__()
        self._closure_queue: Queue = Queue(
            int(os.environ.get("LTC_MAX_ASYNC_QUEUE", max_queue_size))
        )
        self._idle_timeout = idle_timeout
        # Exceptions raised by closures on the event loop thread; handed back
        # to the next caller of ``run``.
        self._closure_exception: Queue = Queue()
        # Serializes enqueuing in ``run`` with the event loop's decision to exit.
        self._closure_lock = threading.Lock()
        # Set (under ``_closure_lock``) when the event loop exits because the
        # queue is empty. A thread in that state may still report
        # ``is_alive()`` briefly, so it must not be trusted as a consumer.
        self._closure_event_loop_finished = threading.Event()
        self._closure_event_loop = None

    def start_event_loop(self):
        """Start closure event loop if not started"""
        if self._closure_event_loop is not None:
            return

        def event_loop():
            # Run until the queue stays empty for ``idle_timeout`` seconds or
            # a closure raises.
            while True:
                try:
                    closure = self._closure_queue.get(
                        block=True, timeout=self._idle_timeout
                    )
                except EmptyQueue:
                    # Decide to exit under the lock so that ``run`` cannot
                    # enqueue between the emptiness check and the exit.
                    with self._closure_lock:
                        if self._closure_queue.empty():
                            self._closure_event_loop_finished.set()
                            return
                    continue
                try:
                    closure()
                except Exception as e:
                    self._closure_exception.put(e)
                    return
                finally:
                    # Balance every successful ``get``, even on failure.
                    self._closure_queue.task_done()

        self._closure_event_loop_finished.clear()
        self._closure_event_loop = threading.Thread(target=event_loop)
        self._closure_event_loop.start()

    def run(self, closure):
        """Enqueue ``closure`` to be run on the event loop thread.

        Blocks while the queue is full.

        Args:
            closure: callable function to run asynchronously

        Raises:
            RuntimeError: if a previously queued closure raised. The original
                exception is attached as ``__cause__``.
        """
        with self._closure_lock:
            self._ensure_event_loop()
            while True:
                try:
                    self._closure_queue.put(
                        closure, block=True, timeout=self._PUT_POLL_INTERVAL
                    )
                    return
                except FullQueue:
                    # Still full: confirm a live consumer exists before
                    # waiting again (raises if it died from an exception).
                    self._ensure_event_loop()

    def _ensure_event_loop(self):
        """Make sure a live event loop will consume newly queued closures.

        Must be called with ``self._closure_lock`` held.

        Raises:
            RuntimeError: if a closure raised on the event loop thread.
        """
        try:
            error = self._closure_exception.get(block=False)
        except EmptyQueue:
            error = None
        if error is not None:
            # The thread that raised is exiting; force a fresh one next time.
            self._closure_event_loop = None
            raise RuntimeError(
                "Cannot run asynchronous closure due to previously raised exception"
            ) from error
        loop = self._closure_event_loop
        if (
            loop is None
            or not loop.is_alive()
            or self._closure_event_loop_finished.is_set()
        ):
            self._closure_event_loop = None
            self.start_event_loop()


def add_step_closure(closure, args=(), run_async=False):
    """Adds a closure to the list of the ones to be run at the end of the step.

    Many times during model training there is the need to print/report (print
    to console, post to tensorboard, etc...) information which requires the
    content of intermediary tensors to be inspected.
    Inspecting different tensors' content at different points of the model
    code requires many executions and typically causes performance issues.
    Adding a step closure will ensure that it will be run after the barrier,
    when all the live tensors will already be materialized to device data.
    Live tensors include the ones captured by the closure arguments.
    So using `add_step_closure()` will ensure a single execution will be
    performed, even when multiple closures are queued, requiring multiple
    tensors to be inspected.
    Step closures will be run sequentially in the order they have been queued.
    Note that even though using this API the execution will be optimized, it
    is advised to throttle the printing/reporting events once every N steps.

    Args:
        closure (callable): The function to be called.
        args (tuple): The arguments to be passed to the closure.
        run_async: If True, run the closure asynchronously.
    """
    devctx = get_device_context()
    closures_type = "async_step_closures" if run_async else "step_closures"
    step_closures = getattr(devctx, closures_type, None)
    if step_closures is None:
        step_closures = []
        setattr(devctx, closures_type, step_closures)
    # Bind ``args`` as a default so the call uses the arguments given now.
    step_closures.append(lambda a=args: closure(*a))


def run_step_closures():
    """Drain and run the step closures queued on the current device context.

    Asynchronous closures are run first. They are handed to the context's
    ``AsyncClosureHandler``, which is created on first use. Synchronous
    closures are then run by the context's ``ClosureHandler``, also created on
    first use. Each closure list is reset to empty before it is run.

    Returns:
        The device context returned by ``get_device_context()``.
    """
    devctx = get_device_context()
    async_step_closures = getattr(devctx, "async_step_closures", None)
    if async_step_closures is not None:
        devctx.async_step_closures = []
        async_closure_handler = getattr(devctx, "async_closure_handler", None)
        if async_closure_handler is None:
            async_closure_handler = AsyncClosureHandler()
            devctx.async_closure_handler = async_closure_handler
        async_closure_handler(async_step_closures)

    step_closures = getattr(devctx, "step_closures", None)
    if step_closures is not None:
        devctx.step_closures = []
        closure_handler = getattr(devctx, "closure_handler", None)
        if closure_handler is None:
            closure_handler = ClosureHandler()
            devctx.closure_handler = closure_handler
        closure_handler(step_closures)
    return devctx