# mypy: allow-untyped-defs
import torch


class Event:
    r"""Wrapper around an MPS event.

    MPS events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize MPS streams.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __init__(self, enable_timing: bool = False) -> None:
        # The id returned by the C extension identifies this event in every
        # subsequent ``torch._C._mps_*`` call made by the methods below.
        self.__eventId = torch._C._mps_acquireEvent(enable_timing)

    def __del__(self) -> None:
        # If ``__init__`` raised before assigning the id, the attribute does
        # not exist; default to 0 so the finalizer neither raises nor releases.
        # The attribute is stored under its name-mangled form.
        event_id = getattr(self, "_Event__eventId", 0)
        # Only positive ids are released; also checks if torch._C is already
        # destroyed before calling into it.
        if event_id > 0 and hasattr(torch._C, "_mps_releaseEvent"):
            torch._C._mps_releaseEvent(event_id)

    def record(self) -> None:
        r"""Records the event in the default stream."""
        torch._C._mps_recordEvent(self.__eventId)

    def wait(self) -> None:
        r"""Makes all future work submitted to the default stream wait for this event."""
        torch._C._mps_waitForEvent(self.__eventId)

    def query(self) -> bool:
        r"""Returns True if all work currently captured by event has completed."""
        return torch._C._mps_queryEvent(self.__eventId)

    def synchronize(self) -> None:
        r"""Waits until the completion of all work currently captured in this event.

        This prevents the CPU thread from proceeding until the event completes.
        """
        torch._C._mps_synchronizeEvent(self.__eventId)

    def elapsed_time(self, end_event: "Event") -> float:
        r"""Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.

        Args:
            end_event (Event): the event marking the end of the measured span.
                It must be an ``Event`` instance, since its private id is read
                directly.

        Returns:
            The elapsed time in milliseconds, as reported by
            ``torch._C._mps_elapsedTimeOfEvents``.
        """
        # ``end_event.__eventId`` is name-mangled to ``end_event._Event__eventId``.
        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)