# mypy: allow-untyped-defs
import torch


class Event:
    r"""Wrapper around an MPS event.

    MPS events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize MPS streams.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    # Class-level fallback for the (name-mangled) event id. If
    # ``torch._C._mps_acquireEvent`` raises inside ``__init__``, the instance
    # attribute is never assigned, yet ``__del__`` still runs when the
    # half-constructed object is collected. With this default, ``__del__`` sees
    # 0 (which it never releases) instead of raising AttributeError.
    __eventId: int = 0

    def __init__(self, enable_timing: bool = False) -> None:
        # Acquire an event id from the MPS backend; it is released in __del__.
        self.__eventId = torch._C._mps_acquireEvent(enable_timing)

    def __del__(self) -> None:
        # Checks if torch._C is already destroyed. Only positive ids are handed
        # back to the backend; the class-level default of 0 covers instances
        # whose __init__ never completed.
        if hasattr(torch._C, "_mps_releaseEvent") and self.__eventId > 0:
            torch._C._mps_releaseEvent(self.__eventId)

    def record(self) -> None:
        r"""Records the event in the default stream."""
        torch._C._mps_recordEvent(self.__eventId)

    def wait(self) -> None:
        r"""Makes all future work submitted to the default stream wait for this event."""
        torch._C._mps_waitForEvent(self.__eventId)

    def query(self) -> bool:
        r"""Returns True if all work currently captured by event has completed."""
        return torch._C._mps_queryEvent(self.__eventId)

    def synchronize(self) -> None:
        r"""Waits until the completion of all work currently captured in this event.

        This prevents the CPU thread from proceeding until the event completes.
        """
        torch._C._mps_synchronizeEvent(self.__eventId)

    def elapsed_time(self, end_event: "Event") -> float:
        r"""Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.

        Args:
            end_event (Event): the event marking the end of the measured
                interval. Timing presumably requires both events to have been
                created with ``enable_timing=True`` (TODO confirm against the
                MPS backend).
        """
        # ``end_event.__eventId`` is name-mangled to ``end_event._Event__eventId``,
        # so ``end_event`` must be an ``Event`` (or a subclass) instance.
        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)