xref: /aosp_15_r20/external/pytorch/c10/cuda/CUDAFunctions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/CUDAFunctions.h>
2*da0073e9SAndroid Build Coastguard Worker #include <c10/macros/Macros.h>
3*da0073e9SAndroid Build Coastguard Worker 
4*da0073e9SAndroid Build Coastguard Worker #include <limits>
5*da0073e9SAndroid Build Coastguard Worker 
6*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
7*da0073e9SAndroid Build Coastguard Worker 
8*da0073e9SAndroid Build Coastguard Worker namespace {
9*da0073e9SAndroid Build Coastguard Worker // returns -1 on failure
driver_version()10*da0073e9SAndroid Build Coastguard Worker int32_t driver_version() {
11*da0073e9SAndroid Build Coastguard Worker   int driver_version = -1;
12*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_IGNORE_ERROR(cudaDriverGetVersion(&driver_version));
13*da0073e9SAndroid Build Coastguard Worker   return driver_version;
14*da0073e9SAndroid Build Coastguard Worker }
15*da0073e9SAndroid Build Coastguard Worker 
device_count_impl(bool fail_if_no_driver)16*da0073e9SAndroid Build Coastguard Worker int device_count_impl(bool fail_if_no_driver) {
17*da0073e9SAndroid Build Coastguard Worker   int count = 0;
18*da0073e9SAndroid Build Coastguard Worker   auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
19*da0073e9SAndroid Build Coastguard Worker   if (err == cudaSuccess) {
20*da0073e9SAndroid Build Coastguard Worker     return count;
21*da0073e9SAndroid Build Coastguard Worker   }
22*da0073e9SAndroid Build Coastguard Worker   // Clear out the error state, so we don't spuriously trigger someone else.
23*da0073e9SAndroid Build Coastguard Worker   // (This shouldn't really matter, since we won't be running very much CUDA
24*da0073e9SAndroid Build Coastguard Worker   // code in this regime.)
25*da0073e9SAndroid Build Coastguard Worker   cudaError_t last_err C10_UNUSED = cudaGetLastError();
26*da0073e9SAndroid Build Coastguard Worker   switch (err) {
27*da0073e9SAndroid Build Coastguard Worker     case cudaErrorNoDevice:
28*da0073e9SAndroid Build Coastguard Worker       // Zero devices is ok here
29*da0073e9SAndroid Build Coastguard Worker       count = 0;
30*da0073e9SAndroid Build Coastguard Worker       break;
31*da0073e9SAndroid Build Coastguard Worker     case cudaErrorInsufficientDriver: {
32*da0073e9SAndroid Build Coastguard Worker       auto version = driver_version();
33*da0073e9SAndroid Build Coastguard Worker       if (version <= 0) {
34*da0073e9SAndroid Build Coastguard Worker         if (!fail_if_no_driver) {
35*da0073e9SAndroid Build Coastguard Worker           // No CUDA driver means no devices
36*da0073e9SAndroid Build Coastguard Worker           count = 0;
37*da0073e9SAndroid Build Coastguard Worker           break;
38*da0073e9SAndroid Build Coastguard Worker         }
39*da0073e9SAndroid Build Coastguard Worker         TORCH_CHECK(
40*da0073e9SAndroid Build Coastguard Worker             false,
41*da0073e9SAndroid Build Coastguard Worker             "Found no NVIDIA driver on your system. Please check that you "
42*da0073e9SAndroid Build Coastguard Worker             "have an NVIDIA GPU and installed a driver from "
43*da0073e9SAndroid Build Coastguard Worker             "http://www.nvidia.com/Download/index.aspx");
44*da0073e9SAndroid Build Coastguard Worker       } else {
45*da0073e9SAndroid Build Coastguard Worker         TORCH_CHECK(
46*da0073e9SAndroid Build Coastguard Worker             false,
47*da0073e9SAndroid Build Coastguard Worker             "The NVIDIA driver on your system is too old (found version ",
48*da0073e9SAndroid Build Coastguard Worker             version,
49*da0073e9SAndroid Build Coastguard Worker             "). Please update your GPU driver by downloading and installing "
50*da0073e9SAndroid Build Coastguard Worker             "a new version from the URL: "
51*da0073e9SAndroid Build Coastguard Worker             "http://www.nvidia.com/Download/index.aspx Alternatively, go to: "
52*da0073e9SAndroid Build Coastguard Worker             "https://pytorch.org to install a PyTorch version that has been "
53*da0073e9SAndroid Build Coastguard Worker             "compiled with your version of the CUDA driver.");
54*da0073e9SAndroid Build Coastguard Worker       }
55*da0073e9SAndroid Build Coastguard Worker     } break;
56*da0073e9SAndroid Build Coastguard Worker     case cudaErrorInitializationError:
57*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
58*da0073e9SAndroid Build Coastguard Worker           false,
59*da0073e9SAndroid Build Coastguard Worker           "CUDA driver initialization failed, you might not "
60*da0073e9SAndroid Build Coastguard Worker           "have a CUDA gpu.");
61*da0073e9SAndroid Build Coastguard Worker       break;
62*da0073e9SAndroid Build Coastguard Worker     case cudaErrorUnknown:
63*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
64*da0073e9SAndroid Build Coastguard Worker           false,
65*da0073e9SAndroid Build Coastguard Worker           "CUDA unknown error - this may be due to an "
66*da0073e9SAndroid Build Coastguard Worker           "incorrectly set up environment, e.g. changing env "
67*da0073e9SAndroid Build Coastguard Worker           "variable CUDA_VISIBLE_DEVICES after program start. "
68*da0073e9SAndroid Build Coastguard Worker           "Setting the available devices to be zero.");
69*da0073e9SAndroid Build Coastguard Worker       break;
70*da0073e9SAndroid Build Coastguard Worker #if C10_ASAN_ENABLED
71*da0073e9SAndroid Build Coastguard Worker     case cudaErrorMemoryAllocation:
72*da0073e9SAndroid Build Coastguard Worker       // In ASAN mode, we know that a cudaErrorMemoryAllocation error will
73*da0073e9SAndroid Build Coastguard Worker       // pop up if compiled with NVCC (clang-cuda is fine)
74*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
75*da0073e9SAndroid Build Coastguard Worker           false,
76*da0073e9SAndroid Build Coastguard Worker           "Got 'out of memory' error while trying to initialize CUDA. "
77*da0073e9SAndroid Build Coastguard Worker           "CUDA with nvcc does not work well with ASAN and it's probably "
78*da0073e9SAndroid Build Coastguard Worker           "the reason. We will simply shut down CUDA support. If you "
79*da0073e9SAndroid Build Coastguard Worker           "would like to use GPUs, turn off ASAN.");
80*da0073e9SAndroid Build Coastguard Worker       break;
81*da0073e9SAndroid Build Coastguard Worker #endif // C10_ASAN_ENABLED
82*da0073e9SAndroid Build Coastguard Worker     default:
83*da0073e9SAndroid Build Coastguard Worker       TORCH_CHECK(
84*da0073e9SAndroid Build Coastguard Worker           false,
85*da0073e9SAndroid Build Coastguard Worker           "Unexpected error from cudaGetDeviceCount(). Did you run "
86*da0073e9SAndroid Build Coastguard Worker           "some cuda functions before calling NumCudaDevices() "
87*da0073e9SAndroid Build Coastguard Worker           "that might have already set an error? Error ",
88*da0073e9SAndroid Build Coastguard Worker           err,
89*da0073e9SAndroid Build Coastguard Worker           ": ",
90*da0073e9SAndroid Build Coastguard Worker           cudaGetErrorString(err));
91*da0073e9SAndroid Build Coastguard Worker   }
92*da0073e9SAndroid Build Coastguard Worker   return count;
93*da0073e9SAndroid Build Coastguard Worker }
94*da0073e9SAndroid Build Coastguard Worker } // namespace
95*da0073e9SAndroid Build Coastguard Worker 
device_count()96*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count() noexcept {
97*da0073e9SAndroid Build Coastguard Worker   // initialize number of devices only once
98*da0073e9SAndroid Build Coastguard Worker   static int count = []() {
99*da0073e9SAndroid Build Coastguard Worker     try {
100*da0073e9SAndroid Build Coastguard Worker       auto result = device_count_impl(/*fail_if_no_driver=*/false);
101*da0073e9SAndroid Build Coastguard Worker       TORCH_INTERNAL_ASSERT(
102*da0073e9SAndroid Build Coastguard Worker           result <= std::numeric_limits<DeviceIndex>::max(),
103*da0073e9SAndroid Build Coastguard Worker           "Too many CUDA devices, DeviceIndex overflowed");
104*da0073e9SAndroid Build Coastguard Worker       return result;
105*da0073e9SAndroid Build Coastguard Worker     } catch (const c10::Error& ex) {
106*da0073e9SAndroid Build Coastguard Worker       // We don't want to fail, but still log the warning
107*da0073e9SAndroid Build Coastguard Worker       // msg() returns the message without the stack trace
108*da0073e9SAndroid Build Coastguard Worker       TORCH_WARN("CUDA initialization: ", ex.msg());
109*da0073e9SAndroid Build Coastguard Worker       return 0;
110*da0073e9SAndroid Build Coastguard Worker     }
111*da0073e9SAndroid Build Coastguard Worker   }();
112*da0073e9SAndroid Build Coastguard Worker   return static_cast<DeviceIndex>(count);
113*da0073e9SAndroid Build Coastguard Worker }
114*da0073e9SAndroid Build Coastguard Worker 
device_count_ensure_non_zero()115*da0073e9SAndroid Build Coastguard Worker DeviceIndex device_count_ensure_non_zero() {
116*da0073e9SAndroid Build Coastguard Worker   // Call the implementation every time to throw the exception
117*da0073e9SAndroid Build Coastguard Worker   int count = device_count_impl(/*fail_if_no_driver=*/true);
118*da0073e9SAndroid Build Coastguard Worker   // Zero gpus doesn't produce a warning in `device_count` but we fail here
119*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(count, "No CUDA GPUs are available");
120*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
121*da0073e9SAndroid Build Coastguard Worker       count <= std::numeric_limits<DeviceIndex>::max(),
122*da0073e9SAndroid Build Coastguard Worker       "Too many CUDA devices, DeviceIndex overflowed");
123*da0073e9SAndroid Build Coastguard Worker   return static_cast<DeviceIndex>(count);
124*da0073e9SAndroid Build Coastguard Worker }
125*da0073e9SAndroid Build Coastguard Worker 
current_device()126*da0073e9SAndroid Build Coastguard Worker DeviceIndex current_device() {
127*da0073e9SAndroid Build Coastguard Worker   DeviceIndex cur_device = -1;
128*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
129*da0073e9SAndroid Build Coastguard Worker   return cur_device;
130*da0073e9SAndroid Build Coastguard Worker }
131*da0073e9SAndroid Build Coastguard Worker 
set_device(DeviceIndex device)132*da0073e9SAndroid Build Coastguard Worker void set_device(DeviceIndex device) {
133*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(c10::cuda::SetDevice(device));
134*da0073e9SAndroid Build Coastguard Worker }
135*da0073e9SAndroid Build Coastguard Worker 
device_synchronize()136*da0073e9SAndroid Build Coastguard Worker void device_synchronize() {
137*da0073e9SAndroid Build Coastguard Worker   const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
138*da0073e9SAndroid Build Coastguard Worker   if (C10_UNLIKELY(interp)) {
139*da0073e9SAndroid Build Coastguard Worker     (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
140*da0073e9SAndroid Build Coastguard Worker   }
141*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaDeviceSynchronize());
142*da0073e9SAndroid Build Coastguard Worker }
143*da0073e9SAndroid Build Coastguard Worker 
144*da0073e9SAndroid Build Coastguard Worker // this function has to be called from callers performing cuda synchronizing
145*da0073e9SAndroid Build Coastguard Worker // operations, to raise proper error or warning
warn_or_error_on_sync()146*da0073e9SAndroid Build Coastguard Worker void warn_or_error_on_sync() {
147*da0073e9SAndroid Build Coastguard Worker   if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_ERROR) {
148*da0073e9SAndroid Build Coastguard Worker     TORCH_CHECK(false, "called a synchronizing CUDA operation");
149*da0073e9SAndroid Build Coastguard Worker   } else if (warning_state().get_sync_debug_mode() == SyncDebugMode::L_WARN) {
150*da0073e9SAndroid Build Coastguard Worker     TORCH_WARN("called a synchronizing CUDA operation");
151*da0073e9SAndroid Build Coastguard Worker   }
152*da0073e9SAndroid Build Coastguard Worker }
153*da0073e9SAndroid Build Coastguard Worker 
getDeviceIndexWithPrimaryContext()154*da0073e9SAndroid Build Coastguard Worker std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext() {
155*da0073e9SAndroid Build Coastguard Worker   // check current device first
156*da0073e9SAndroid Build Coastguard Worker   auto current_device_index = current_device();
157*da0073e9SAndroid Build Coastguard Worker   if (current_device_index >= 0) {
158*da0073e9SAndroid Build Coastguard Worker     if (hasPrimaryContext(current_device_index)) {
159*da0073e9SAndroid Build Coastguard Worker       return current_device_index;
160*da0073e9SAndroid Build Coastguard Worker     }
161*da0073e9SAndroid Build Coastguard Worker   }
162*da0073e9SAndroid Build Coastguard Worker   for (const auto device_index : c10::irange(at::cuda::device_count())) {
163*da0073e9SAndroid Build Coastguard Worker     if (device_index == current_device_index)
164*da0073e9SAndroid Build Coastguard Worker       continue;
165*da0073e9SAndroid Build Coastguard Worker     if (hasPrimaryContext(device_index)) {
166*da0073e9SAndroid Build Coastguard Worker       return device_index;
167*da0073e9SAndroid Build Coastguard Worker     }
168*da0073e9SAndroid Build Coastguard Worker   }
169*da0073e9SAndroid Build Coastguard Worker   return std::nullopt;
170*da0073e9SAndroid Build Coastguard Worker }
171*da0073e9SAndroid Build Coastguard Worker 
172*da0073e9SAndroid Build Coastguard Worker namespace _internal {
dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index)173*da0073e9SAndroid Build Coastguard Worker bool dummyHasPrimaryContext(C10_UNUSED DeviceIndex device_index) {
174*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(false, "Should never been called");
175*da0073e9SAndroid Build Coastguard Worker }
176*da0073e9SAndroid Build Coastguard Worker bool (*hasPrimaryContext)(DeviceIndex) = dummyHasPrimaryContext;
177*da0073e9SAndroid Build Coastguard Worker 
178*da0073e9SAndroid Build Coastguard Worker // Private api to be called from CUDAHooks.cpp
setHasPrimaryContext(bool (* func)(DeviceIndex))179*da0073e9SAndroid Build Coastguard Worker C10_CUDA_API void setHasPrimaryContext(bool (*func)(DeviceIndex)) {
180*da0073e9SAndroid Build Coastguard Worker   hasPrimaryContext = func ? func : dummyHasPrimaryContext;
181*da0073e9SAndroid Build Coastguard Worker }
182*da0073e9SAndroid Build Coastguard Worker } // namespace _internal
183*da0073e9SAndroid Build Coastguard Worker 
hasPrimaryContext(DeviceIndex device_index)184*da0073e9SAndroid Build Coastguard Worker bool hasPrimaryContext(DeviceIndex device_index) {
185*da0073e9SAndroid Build Coastguard Worker   return _internal::hasPrimaryContext(device_index);
186*da0073e9SAndroid Build Coastguard Worker }
187*da0073e9SAndroid Build Coastguard Worker 
188*da0073e9SAndroid Build Coastguard Worker // Wrappers for raw CUDA device management functions
GetDeviceCount(int * dev_count)189*da0073e9SAndroid Build Coastguard Worker cudaError_t GetDeviceCount(int* dev_count) {
190*da0073e9SAndroid Build Coastguard Worker   return cudaGetDeviceCount(dev_count);
191*da0073e9SAndroid Build Coastguard Worker }
192*da0073e9SAndroid Build Coastguard Worker 
// This is a codepath for CUDA 12, which comes with a critical change in the
// behavior of `cudaSetDevice`. Unlike previous CUDA versions, which allocate
// the context lazily, CUDA 12.x eagerly allocates the primary context the
// moment `cudaSetDevice` is called. This can have dramatic consequences and
// pollute device memory in distributed runs. To avoid unnecessary context
// creation, a new function called `MaybeSetDevice` was introduced. This
// function is to be called in the device guard destructor and at the exit of
// the torch.cuda.device context manager. The behavior of `MaybeSetDevice` is
// quite simple: it calls `cudaSetDevice` if a context already exists on the
// targeted device; if no context has been allocated there, it simply saves
// the device index. This way we can keep PyTorch backward compatible for
// applications like this:
//
// ```
// import torch
// x = torch.empty(1, device="cuda:1") # no CUDA context on cuda:0 after this call
// y = torch.empty(1, device="cuda") # CUDA context is created on cuda:0
// ```
210*da0073e9SAndroid Build Coastguard Worker #if CUDA_VERSION >= 12000
211*da0073e9SAndroid Build Coastguard Worker thread_local DeviceIndex targetDeviceIndex = -1;
212*da0073e9SAndroid Build Coastguard Worker 
GetDevice(DeviceIndex * device)213*da0073e9SAndroid Build Coastguard Worker cudaError_t GetDevice(DeviceIndex* device) {
214*da0073e9SAndroid Build Coastguard Worker   if (targetDeviceIndex >= 0) {
215*da0073e9SAndroid Build Coastguard Worker     *device = targetDeviceIndex;
216*da0073e9SAndroid Build Coastguard Worker     return cudaSuccess;
217*da0073e9SAndroid Build Coastguard Worker   }
218*da0073e9SAndroid Build Coastguard Worker   int tmp_device = -1;
219*da0073e9SAndroid Build Coastguard Worker   auto err = cudaGetDevice(&tmp_device);
220*da0073e9SAndroid Build Coastguard Worker   if (err == cudaSuccess) {
221*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
222*da0073e9SAndroid Build Coastguard Worker         tmp_device >= 0 &&
223*da0073e9SAndroid Build Coastguard Worker             tmp_device <= std::numeric_limits<DeviceIndex>::max(),
224*da0073e9SAndroid Build Coastguard Worker         "cudaGetDevice returns invalid device ",
225*da0073e9SAndroid Build Coastguard Worker         tmp_device);
226*da0073e9SAndroid Build Coastguard Worker     *device = static_cast<DeviceIndex>(tmp_device);
227*da0073e9SAndroid Build Coastguard Worker   }
228*da0073e9SAndroid Build Coastguard Worker   return err;
229*da0073e9SAndroid Build Coastguard Worker }
230*da0073e9SAndroid Build Coastguard Worker 
SetDevice(DeviceIndex device)231*da0073e9SAndroid Build Coastguard Worker cudaError_t SetDevice(DeviceIndex device) {
232*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(device >= 0, "device id must be positive!", device);
233*da0073e9SAndroid Build Coastguard Worker   targetDeviceIndex = -1;
234*da0073e9SAndroid Build Coastguard Worker   int cur_device = -1;
235*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
236*da0073e9SAndroid Build Coastguard Worker   if (device == cur_device) {
237*da0073e9SAndroid Build Coastguard Worker     return cudaSuccess;
238*da0073e9SAndroid Build Coastguard Worker   }
239*da0073e9SAndroid Build Coastguard Worker   return cudaSetDevice(device);
240*da0073e9SAndroid Build Coastguard Worker }
241*da0073e9SAndroid Build Coastguard Worker 
MaybeSetDevice(DeviceIndex device)242*da0073e9SAndroid Build Coastguard Worker cudaError_t MaybeSetDevice(DeviceIndex device) {
243*da0073e9SAndroid Build Coastguard Worker   if (hasPrimaryContext(device)) {
244*da0073e9SAndroid Build Coastguard Worker     return c10::cuda::SetDevice(device);
245*da0073e9SAndroid Build Coastguard Worker   }
246*da0073e9SAndroid Build Coastguard Worker   targetDeviceIndex = device;
247*da0073e9SAndroid Build Coastguard Worker   return cudaSuccess;
248*da0073e9SAndroid Build Coastguard Worker }
249*da0073e9SAndroid Build Coastguard Worker 
250*da0073e9SAndroid Build Coastguard Worker // This function always initializes the CUDA context
251*da0073e9SAndroid Build Coastguard Worker // on to_device
ExchangeDevice(DeviceIndex to_device)252*da0073e9SAndroid Build Coastguard Worker DeviceIndex ExchangeDevice(DeviceIndex to_device) {
253*da0073e9SAndroid Build Coastguard Worker   auto cur_device = targetDeviceIndex;
254*da0073e9SAndroid Build Coastguard Worker   targetDeviceIndex = -1;
255*da0073e9SAndroid Build Coastguard Worker   if (cur_device < 0) {
256*da0073e9SAndroid Build Coastguard Worker     int tmp_device = -1;
257*da0073e9SAndroid Build Coastguard Worker     C10_CUDA_CHECK(cudaGetDevice(&tmp_device));
258*da0073e9SAndroid Build Coastguard Worker     cur_device = static_cast<DeviceIndex>(tmp_device);
259*da0073e9SAndroid Build Coastguard Worker     if (to_device == cur_device) {
260*da0073e9SAndroid Build Coastguard Worker       return cur_device;
261*da0073e9SAndroid Build Coastguard Worker     }
262*da0073e9SAndroid Build Coastguard Worker   }
263*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaSetDevice(to_device));
264*da0073e9SAndroid Build Coastguard Worker   return cur_device;
265*da0073e9SAndroid Build Coastguard Worker }
266*da0073e9SAndroid Build Coastguard Worker 
267*da0073e9SAndroid Build Coastguard Worker // This function does not initialize the CUDA context
268*da0073e9SAndroid Build Coastguard Worker // on to_device if it does not already exist
MaybeExchangeDevice(DeviceIndex to_device)269*da0073e9SAndroid Build Coastguard Worker DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
270*da0073e9SAndroid Build Coastguard Worker   int tmp_cur_device = -1;
271*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaGetDevice(&tmp_cur_device));
272*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(
273*da0073e9SAndroid Build Coastguard Worker       tmp_cur_device >= 0 &&
274*da0073e9SAndroid Build Coastguard Worker           tmp_cur_device <= std::numeric_limits<DeviceIndex>::max(),
275*da0073e9SAndroid Build Coastguard Worker       "cudaGetDevice returns invalid device ",
276*da0073e9SAndroid Build Coastguard Worker       tmp_cur_device);
277*da0073e9SAndroid Build Coastguard Worker   auto cur_device = static_cast<DeviceIndex>(tmp_cur_device);
278*da0073e9SAndroid Build Coastguard Worker   if (to_device == tmp_cur_device) {
279*da0073e9SAndroid Build Coastguard Worker     return cur_device;
280*da0073e9SAndroid Build Coastguard Worker   }
281*da0073e9SAndroid Build Coastguard Worker   if (hasPrimaryContext(to_device)) {
282*da0073e9SAndroid Build Coastguard Worker     C10_CUDA_CHECK(cudaSetDevice(to_device));
283*da0073e9SAndroid Build Coastguard Worker   } else {
284*da0073e9SAndroid Build Coastguard Worker     targetDeviceIndex = to_device;
285*da0073e9SAndroid Build Coastguard Worker   }
286*da0073e9SAndroid Build Coastguard Worker   return cur_device;
287*da0073e9SAndroid Build Coastguard Worker }
288*da0073e9SAndroid Build Coastguard Worker 
SetTargetDevice()289*da0073e9SAndroid Build Coastguard Worker void SetTargetDevice() {
290*da0073e9SAndroid Build Coastguard Worker   if (targetDeviceIndex >= 0) {
291*da0073e9SAndroid Build Coastguard Worker     C10_CUDA_CHECK(c10::cuda::SetDevice(targetDeviceIndex));
292*da0073e9SAndroid Build Coastguard Worker   }
293*da0073e9SAndroid Build Coastguard Worker }
294*da0073e9SAndroid Build Coastguard Worker #else
GetDevice(DeviceIndex * device)295*da0073e9SAndroid Build Coastguard Worker cudaError_t GetDevice(DeviceIndex* device) {
296*da0073e9SAndroid Build Coastguard Worker   int tmp_device = -1;
297*da0073e9SAndroid Build Coastguard Worker   auto err = cudaGetDevice(&tmp_device);
298*da0073e9SAndroid Build Coastguard Worker   if (err == cudaSuccess) {
299*da0073e9SAndroid Build Coastguard Worker     TORCH_INTERNAL_ASSERT(
300*da0073e9SAndroid Build Coastguard Worker         tmp_device >= 0 &&
301*da0073e9SAndroid Build Coastguard Worker             tmp_device <= std::numeric_limits<DeviceIndex>::max(),
302*da0073e9SAndroid Build Coastguard Worker         "cudaGetDevice returns invalid device ",
303*da0073e9SAndroid Build Coastguard Worker         tmp_device);
304*da0073e9SAndroid Build Coastguard Worker     *device = static_cast<DeviceIndex>(tmp_device);
305*da0073e9SAndroid Build Coastguard Worker   }
306*da0073e9SAndroid Build Coastguard Worker   return err;
307*da0073e9SAndroid Build Coastguard Worker }
308*da0073e9SAndroid Build Coastguard Worker 
SetDevice(DeviceIndex device)309*da0073e9SAndroid Build Coastguard Worker cudaError_t SetDevice(DeviceIndex device) {
310*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(device >= 0, "device id must be positive!", device);
311*da0073e9SAndroid Build Coastguard Worker   int cur_device = -1;
312*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaGetDevice(&cur_device));
313*da0073e9SAndroid Build Coastguard Worker   if (device == cur_device) {
314*da0073e9SAndroid Build Coastguard Worker     return cudaSuccess;
315*da0073e9SAndroid Build Coastguard Worker   }
316*da0073e9SAndroid Build Coastguard Worker   return cudaSetDevice(device);
317*da0073e9SAndroid Build Coastguard Worker }
318*da0073e9SAndroid Build Coastguard Worker 
MaybeSetDevice(DeviceIndex device)319*da0073e9SAndroid Build Coastguard Worker cudaError_t MaybeSetDevice(DeviceIndex device) {
320*da0073e9SAndroid Build Coastguard Worker   return c10::cuda::SetDevice(device);
321*da0073e9SAndroid Build Coastguard Worker }
322*da0073e9SAndroid Build Coastguard Worker 
ExchangeDevice(DeviceIndex to_device)323*da0073e9SAndroid Build Coastguard Worker DeviceIndex ExchangeDevice(DeviceIndex to_device) {
324*da0073e9SAndroid Build Coastguard Worker   DeviceIndex cur_device = -1;
325*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
326*da0073e9SAndroid Build Coastguard Worker   if (to_device == cur_device) {
327*da0073e9SAndroid Build Coastguard Worker     return cur_device;
328*da0073e9SAndroid Build Coastguard Worker   }
329*da0073e9SAndroid Build Coastguard Worker   C10_CUDA_CHECK(cudaSetDevice(to_device));
330*da0073e9SAndroid Build Coastguard Worker   return cur_device;
331*da0073e9SAndroid Build Coastguard Worker }
332*da0073e9SAndroid Build Coastguard Worker 
MaybeExchangeDevice(DeviceIndex to_device)333*da0073e9SAndroid Build Coastguard Worker DeviceIndex MaybeExchangeDevice(DeviceIndex to_device) {
334*da0073e9SAndroid Build Coastguard Worker   return c10::cuda::ExchangeDevice(to_device);
335*da0073e9SAndroid Build Coastguard Worker }
336*da0073e9SAndroid Build Coastguard Worker 
// Counterpart of the CUDA >= 12 SetTargetDevice. This branch keeps no
// deferred target, so there is nothing to apply.
void SetTargetDevice() {
  // no-op on CUDA version < 12.x
}
340*da0073e9SAndroid Build Coastguard Worker #endif
341*da0073e9SAndroid Build Coastguard Worker 
342*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
343