xref: /aosp_15_r20/external/pytorch/torch/cuda/gds.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 import os
2 import sys
3 from typing import Callable, List, Optional
4 
5 import torch
6 from torch.types import Storage
7 
8 
9 __all__: List[str] = []
10 
11 
def _dummy_fn(name: str) -> Callable:
    """Build a placeholder for a ``torch._C`` binding that does not exist.

    Args:
        name (str): Name of the missing ``torch._C`` attribute; it is only
            used to build the error message.

    Returns:
        Callable: A function that accepts any positional/keyword arguments
        and always raises ``RuntimeError`` naming ``torch._C.<name>``.
    """
    # ``name`` is a str, so formatting it up front yields the same text the
    # call-time formatting would.
    message = f"torch._C.{name} is not supported on this platform"

    def unsupported(*args, **kwargs):  # type: ignore[no-untyped-def]
        raise RuntimeError(message)

    return unsupported
17 
18 
if not hasattr(torch._C, "_gds_register_buffer"):
    # This PyTorch build lacks the GDS bindings. They are expected to be
    # present or absent as a group, so every other one must be missing too.
    assert not any(
        hasattr(torch._C, binding)
        for binding in (
            "_gds_deregister_buffer",
            "_gds_register_handle",
            "_gds_deregister_handle",
            "_gds_load_storage",
            "_gds_save_storage",
        )
    )
    # Install stubs that raise RuntimeError when called. The wrappers in this
    # module then fail with a descriptive message instead of AttributeError.
    # The comprehension keeps ``binding`` out of the module namespace.
    torch._C.__dict__.update(
        {
            binding: _dummy_fn(binding)
            for binding in (
                "_gds_register_buffer",
                "_gds_deregister_buffer",
                "_gds_register_handle",
                "_gds_deregister_handle",
                "_gds_load_storage",
                "_gds_save_storage",
            )
        }
    )
32 
33 
def _gds_register_buffer(s: Storage) -> None:
    """Register a storage buffer through the ``torch._C`` GDS binding.

    The storage is passed unchanged to ``torch._C._gds_register_buffer``.
    When PyTorch was built without GDS support, this module installs a stub
    under that name which raises ``RuntimeError``.

    Args:
        s (Storage): Buffer to register.
    """
    register = torch._C._gds_register_buffer
    register(s)
41 
42 
def _gds_deregister_buffer(s: Storage) -> None:
    """Deregisters a buffer.

    Counterpart of ``_gds_register_buffer``. The storage is passed unchanged
    to ``torch._C._gds_deregister_buffer``. When PyTorch was built without GDS
    support, this module installs a stub under that name which raises
    ``RuntimeError``.

    Args:
        s (Storage): Buffer to deregister. Presumably this should be a
            storage previously passed to ``_gds_register_buffer`` (TODO
            confirm against the native binding's contract).
    """
    torch._C._gds_deregister_buffer(s)
50 
51 
class _GdsFile:
    r"""Wrapper around cuFile.

    cuFile is a file-like interface to the GPUDirect Storage (GDS) API.

    Args:
        filename (str): Name of the file to open.
        flags (int): Flags to pass to ``os.open`` when opening the file. ``os.O_DIRECT`` will
            be added automatically.

    Raises:
        RuntimeError: If running on Windows (``sys.platform == "win32"``).
        OSError: Propagated from ``os.open`` if the file cannot be opened.

    .. _CUDA GPUDirect Storage Documentation:
        https://docs.nvidia.com/gpudirect-storage/api-reference-guide/index.html#cufile-io-api
    """

    def __init__(self, filename: str, flags: int):
        if sys.platform == "win32":
            raise RuntimeError("GdsFile is not supported on this platform.")
        self.filename = filename
        self.flags = flags
        self.fd = os.open(filename, flags | os.O_DIRECT)
        self.handle: Optional[int] = None
        self.register_handle()

    def __del__(self) -> None:
        """Deregister the cuFile handle (if any) and close the file descriptor.

        ``__init__`` may raise before ``fd`` or ``handle`` is assigned (an
        unsupported platform, or an ``os.open`` failure). A missing attribute
        is therefore treated as "nothing to release" rather than raising
        ``AttributeError``.
        """
        try:
            if getattr(self, "handle", None) is not None:
                self.deregister_handle()
        finally:
            # Close the descriptor even if deregistration raised, so it does
            # not leak.
            fd = getattr(self, "fd", None)
            if fd is not None:
                # Forget the descriptor before closing it. A second call then
                # cannot close an unrelated file that has reused this number.
                del self.fd
                os.close(fd)

    def register_handle(self) -> None:
        """Registers file descriptor to cuFile Driver.

        This is a wrapper around ``cuFileHandleRegister``.
        """
        assert (
            self.handle is None
        ), "Cannot register a handle that is already registered."
        self.handle = torch._C._gds_register_handle(self.fd)

    def deregister_handle(self) -> None:
        """Deregisters file descriptor from cuFile Driver.

        This is a wrapper around ``cuFileHandleDeregister``.
        """
        assert (
            self.handle is not None
        ), "Cannot deregister a handle that is not registered."
        torch._C._gds_deregister_handle(self.handle)
        self.handle = None

    def load_storage(self, storage: Storage, offset: int = 0) -> None:
        """Loads data from the file into the storage.

        This is a wrapper around ``cuFileRead``. ``storage.nbytes()`` of data
        will be loaded from the file at ``offset`` into the storage.

        Args:
            storage (Storage): Storage to load data into.
            offset (int, optional): Offset into the file to start loading from. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot load data from a file that is not registered."
        torch._C._gds_load_storage(self.handle, storage, offset)

    def save_storage(self, storage: Storage, offset: int = 0) -> None:
        """Saves data from the storage into the file.

        This is a wrapper around ``cuFileWrite``. All bytes of the storage
        will be written to the file at ``offset``.

        Args:
            storage (Storage): Storage to save data from.
            offset (int, optional): Offset into the file to start saving to. (Default: 0)
        """
        assert (
            self.handle is not None
        ), "Cannot save data to a file that is not registered."
        torch._C._gds_save_storage(self.handle, storage, offset)
130