xref: /aosp_15_r20/external/pytorch/torch/mps/event.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport torch
3*da0073e9SAndroid Build Coastguard Worker
4*da0073e9SAndroid Build Coastguard Worker
class Event:
    r"""Wrapper around an MPS event.

    MPS events are synchronization markers that can be used to monitor the
    device's progress, to accurately measure timing, and to synchronize MPS streams.

    Args:
        enable_timing (bool, optional): indicates if the event should measure time
            (default: ``False``)
    """

    def __init__(self, enable_timing: bool = False) -> None:
        # Opaque handle issued by the native MPS backend; it is passed back to
        # every torch._C._mps_* call below and released in __del__.
        self.__eventId = torch._C._mps_acquireEvent(enable_timing)

    def __del__(self) -> None:
        # If __init__ raised (e.g. acquiring the event failed), the handle was
        # never assigned. Default to 0 so the finalizer does not raise
        # AttributeError ("Exception ignored in __del__") at GC time.
        # The mangled name is spelled out because string arguments to getattr
        # are not name-mangled.
        event_id = getattr(self, "_Event__eventId", 0)
        # checks if torch._C is already destroyed (e.g. interpreter shutdown)
        if hasattr(torch._C, "_mps_releaseEvent") and event_id > 0:
            torch._C._mps_releaseEvent(event_id)

    def record(self) -> None:
        r"""Records the event in the default stream."""
        torch._C._mps_recordEvent(self.__eventId)

    def wait(self) -> None:
        r"""Makes all future work submitted to the default stream wait for this event."""
        torch._C._mps_waitForEvent(self.__eventId)

    def query(self) -> bool:
        r"""Returns True if all work currently captured by event has completed."""
        return torch._C._mps_queryEvent(self.__eventId)

    def synchronize(self) -> None:
        r"""Waits until the completion of all work currently captured in this event.

        This prevents the CPU thread from proceeding until the event completes.
        """
        torch._C._mps_synchronizeEvent(self.__eventId)

    def elapsed_time(self, end_event: "Event") -> float:
        r"""Returns the time elapsed in milliseconds after the event was
        recorded and before the end_event was recorded.

        Args:
            end_event (Event): the event marking the end of the measured span.
                Presumably both events must be created with
                ``enable_timing=True`` -- TODO confirm against the MPS backend.
        """
        # `end_event.__eventId` is name-mangled to `end_event._Event__eventId`
        # inside this class body, so it reads the other event's handle.
        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)
46*da0073e9SAndroid Build Coastguard Worker        return torch._C._mps_elapsedTimeOfEvents(self.__eventId, end_event.__eventId)
47