# mypy: ignore-errors

r"""This file is allowed to initialize CUDA context when imported."""

import functools
import torch
import torch.cuda
from torch.testing._internal.common_utils import LazyVal, TEST_NUMBA, TEST_WITH_ROCM, TEST_CUDA, IS_WINDOWS
import inspect
import contextlib
import os


CUDA_ALREADY_INITIALIZED_ON_IMPORT = torch.cuda.is_initialized()


TEST_MULTIGPU = TEST_CUDA and torch.cuda.device_count() >= 2
CUDA_DEVICE = torch.device("cuda:0") if TEST_CUDA else None
# note: if ROCm is targeted, TEST_CUDNN is code for TEST_MIOPEN
if TEST_WITH_ROCM:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA)
else:
    TEST_CUDNN = LazyVal(lambda: TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=CUDA_DEVICE)))

TEST_CUDNN_VERSION = LazyVal(lambda: torch.backends.cudnn.version() if TEST_CUDNN else 0)

SM53OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (5, 3))
SM60OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (6, 0))
SM70OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 0))
SM75OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (7, 5))
SM80OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 0))
SM90OrLater = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0))
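# These capability flags are typically consumed by skip decorators in tests, e.g.
# (illustrative):
#     @unittest.skipIf(not SM80OrLater, "requires SM80 or newer")
#     def test_needs_sm80(self): ...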

IS_JETSON = LazyVal(lambda: torch.cuda.is_available() and torch.cuda.get_device_capability() in [(7, 2), (8, 7)])

def evaluate_gfx_arch_exact(matching_arch):
    if not torch.cuda.is_available():
        return False
    gcn_arch_name = torch.cuda.get_device_properties('cuda').gcnArchName
    arch = os.environ.get('PYTORCH_DEBUG_FLASH_ATTENTION_GCN_ARCH_OVERRIDE', gcn_arch_name)
    return arch == matching_arch

GFX90A_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-'))
GFX942_Exact = LazyVal(lambda: evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-'))

def evaluate_platform_supports_flash_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return not IS_WINDOWS and SM80OrLater
    return False

def evaluate_platform_supports_efficient_attention():
    if TEST_WITH_ROCM:
        return evaluate_gfx_arch_exact('gfx90a:sramecc+:xnack-') or evaluate_gfx_arch_exact('gfx942:sramecc+:xnack-')
    if TEST_CUDA:
        return True
    return False

PLATFORM_SUPPORTS_FLASH_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_flash_attention())
PLATFORM_SUPPORTS_MEM_EFF_ATTENTION: bool = LazyVal(lambda: evaluate_platform_supports_efficient_attention())
# TODO(eqy): gate this against a cuDNN version
PLATFORM_SUPPORTS_CUDNN_ATTENTION: bool = LazyVal(lambda: TEST_CUDA and not TEST_WITH_ROCM and
                                                  torch.backends.cuda.cudnn_sdp_enabled())
# This condition always evaluates to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION but for logical clarity we keep it separate
PLATFORM_SUPPORTS_FUSED_ATTENTION: bool = LazyVal(lambda: PLATFORM_SUPPORTS_FLASH_ATTENTION or PLATFORM_SUPPORTS_MEM_EFF_ATTENTION)

PLATFORM_SUPPORTS_FUSED_SDPA: bool = TEST_CUDA and not TEST_WITH_ROCM

PLATFORM_SUPPORTS_BF16: bool = LazyVal(lambda: TEST_CUDA and SM80OrLater)

if TEST_NUMBA:
    try:
        import numba.cuda
        TEST_NUMBA_CUDA = numba.cuda.is_available()
    except Exception as e:
        TEST_NUMBA_CUDA = False
        TEST_NUMBA = False
else:
    TEST_NUMBA_CUDA = False

# Used below in `initialize_cuda_context_rng` to ensure that CUDA context and
# RNG have been initialized.
__cuda_ctx_rng_initialized = False


# after this call, CUDA context and RNG must have been initialized on each GPU
def initialize_cuda_context_rng():
    global __cuda_ctx_rng_initialized
    assert TEST_CUDA, 'CUDA must be available when calling initialize_cuda_context_rng'
    if not __cuda_ctx_rng_initialized:
        # initialize cuda context and rng for memory tests
        for i in range(torch.cuda.device_count()):
            torch.randn(1, device=f"cuda:{i}")
        __cuda_ctx_rng_initialized = True

# Test whether hardware TF32 math mode is enabled. It is enabled only on:
# - CUDA >= 11
# - arch >= Ampere
def tf32_is_not_fp32():
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split('.')[0]) < 11:
        return False
    return True

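# Context manager that turns TF32 off for both cuBLAS matmuls and cuDNN while the
# block runs, then restores the previous settings. Illustrative use:
#     with tf32_off():
#         c = torch.matmul(a, b)  # computed in full FP32 precision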
@contextlib.contextmanager
def tf32_off():
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    try:
        torch.backends.cuda.matmul.allow_tf32 = False
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=False):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul

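# Context manager that turns TF32 on and relaxes the comparison precision of the
# given test case for the duration of the block. `self` is assumed to be a test
# instance that exposes a `precision` attribute (as common_utils.TestCase does).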
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
    old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
    old_precision = self.precision
    try:
        torch.backends.cuda.matmul.allow_tf32 = True
        self.precision = tf32_precision
        with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
            yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
        self.precision = old_precision

# This is a decorator that runs the wrapped test twice: once with
# allow_tf32=True and once with allow_tf32=False. When running with
# allow_tf32=True, it uses the reduced precision specified by the
# argument. For example:
#    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
#    @tf32_on_and_off(0.005)
#    def test_matmul(self, device, dtype):
#        a = ...; b = ...;
#        c = torch.matmul(a, b)
#        self.assertEqual(c, expected)
# In the above example, when testing torch.float32 and torch.complex64 on CUDA,
# on a CUDA >= 11 build, and on an >= Ampere architecture, the matmul runs both
# with TF32 mode on and with TF32 mode off; when TF32 mode is on, assertEqual
# uses the reduced precision to check values.
#
# This decorator can be used for functions with or without device/dtype arguments, such as:
# @tf32_on_and_off(0.005)
# def test_my_op(self)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device)
# @tf32_on_and_off(0.005)
# def test_my_op(self, device, dtype)
# @tf32_on_and_off(0.005)
# def test_my_op(self, dtype)
# If neither device nor dtype is specified, it checks whether the system has an Ampere
# (or newer) device; if device is specified, it checks whether the device is CUDA;
# if dtype is specified, it checks whether the dtype is float32 or complex64.
# TF32 and FP32 differ only when all three checks pass.
def tf32_on_and_off(tf32_precision=1e-5):
    def with_tf32_disabled(self, function_call):
        with tf32_off():
            function_call()

    def with_tf32_enabled(self, function_call):
        with tf32_on(self, tf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            cond = tf32_is_not_fp32()
            if 'device' in kwargs:
                cond = cond and (torch.device(kwargs['device']).type == 'cuda')
            if 'dtype' in kwargs:
                cond = cond and (kwargs['dtype'] in {torch.float32, torch.complex64})
            if cond:
                with_tf32_disabled(kwargs['self'], lambda: f(**kwargs))
                with_tf32_enabled(kwargs['self'], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped
    return wrapper

# This is a wrapper that wraps a test to run it with TF32 turned off.
# It is designed for tests that use matmul or convolutions, but whose purpose
# is not to test matmul or convolutions themselves.
# Disabling TF32 ensures that torch.float tensors are always computed
# at full precision.
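# For example (illustrative):
#    @with_tf32_off
#    def test_uses_conv_incidentally(self, device):
#        ...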
def with_tf32_off(f):
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        with tf32_off():
            return f(*args, **kwargs)

    return wrapped

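# The version helpers below return (major, minor[, patch]) tuples so callers can
# compare them lexicographically, e.g. (illustrative):
#     if _get_torch_cuda_version() >= (11, 6):
#         ...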
def _get_magma_version():
    if 'Magma' not in torch.__config__.show():
        return (0, 0)
    position = torch.__config__.show().find('Magma ')
    version_str = torch.__config__.show()[position + len('Magma '):].split('\n')[0]
    return tuple(int(x) for x in version_str.split("."))

def _get_torch_cuda_version():
    if torch.version.cuda is None:
        return (0, 0)
    cuda_version = str(torch.version.cuda)
    return tuple(int(x) for x in cuda_version.split("."))

def _get_torch_rocm_version():
    if not TEST_WITH_ROCM:
        return (0, 0)
    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    return tuple(int(x) for x in rocm_version.split("."))

def _check_cusparse_generic_available():
    return not TEST_WITH_ROCM

def _check_hipsparse_generic_available():
    if not TEST_WITH_ROCM:
        return False

    rocm_version = str(torch.version.hip)
    rocm_version = rocm_version.split("-")[0]    # ignore git sha
    rocm_version_tuple = tuple(int(x) for x in rocm_version.split("."))
    return not (rocm_version_tuple is None or rocm_version_tuple < (5, 1))


TEST_CUSPARSE_GENERIC = _check_cusparse_generic_available()
TEST_HIPSPARSE_GENERIC = _check_hipsparse_generic_available()

# Shared by test_torch.py and test_multigpu.py
def _create_scaling_models_optimizers(device="cuda", optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    # Create a module+optimizer that will use scaling, and a control module+optimizer
    # that will not use scaling, against which the scaling-enabled module+optimizer can be compared.
    mod_control = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    mod_scaling = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 8)).to(device=device)
    with torch.no_grad():
        for c, s in zip(mod_control.parameters(), mod_scaling.parameters()):
            s.copy_(c)

    kwargs = {"lr": 1.0}
    if optimizer_kwargs is not None:
        kwargs.update(optimizer_kwargs)
    opt_control = optimizer_ctor(mod_control.parameters(), **kwargs)
    opt_scaling = optimizer_ctor(mod_scaling.parameters(), **kwargs)

    return mod_control, mod_scaling, opt_control, opt_scaling

# Shared by test_torch.py, test_cuda.py and test_multigpu.py
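# Callers typically unpack the returned tuple and drive a GradScaler-based loop over
# `data`, comparing the scaled model/optimizer against the control pair, e.g.
# (illustrative):
#     mod_control, mod_scaling, opt_control, opt_scaling, data, loss_fn, skip_iter = \
#         _create_scaling_case()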
def _create_scaling_case(device="cuda", dtype=torch.float, optimizer_ctor=torch.optim.SGD, optimizer_kwargs=None):
    data = [(torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device)),
            (torch.randn((8, 8), dtype=dtype, device=device), torch.randn((8, 8), dtype=dtype, device=device))]

    loss_fn = torch.nn.MSELoss().to(device)

    skip_iter = 2

    return _create_scaling_models_optimizers(
        device=device, optimizer_ctor=optimizer_ctor, optimizer_kwargs=optimizer_kwargs,
    ) + (data, loss_fn, skip_iter)


# Importing this module should NOT eagerly initialize CUDA
if not CUDA_ALREADY_INITIALIZED_ON_IMPORT:
    assert not torch.cuda.is_initialized()