xref: /aosp_15_r20/external/pytorch/torch/csrc/api/src/cuda.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#include <torch/cuda.h>

#include <ATen/Context.h>
#include <c10/core/DeviceGuard.h>
#include <c10/util/irange.h>

#include <cstddef>

namespace torch {
namespace cuda {
size_t device_count() {
  return at::detail::getCUDAHooks().getNumGPUs();
}
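
// Usage sketch (illustrative; not part of the original file):
//   const size_t n = torch::cuda::device_count();
//   // n is 0 when there is no CUDA build or no visible CUDA device.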

bool is_available() {
  // NB: the semantics of this are different from at::globalContext().hasCUDA();
  // ATen's function tells you if you have a working driver and CUDA build,
  // whereas this function also tells you if you actually have any GPUs.
  // This function matches the semantics of at::cuda::is_available()
  return cuda::device_count() > 0;
}
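
// Usage sketch (illustrative; not part of the original file): guard CUDA work
// on this check before allocating tensors on the GPU.
//   if (torch::cuda::is_available()) {
//     auto t = torch::ones({2, 2}, torch::kCUDA);
//   }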

bool cudnn_is_available() {
  return is_available() && at::detail::getCUDAHooks().hasCuDNN();
}
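
// Usage sketch (illustrative; not part of the original file): gate
// cuDNN-dependent paths on this check.
//   if (torch::cuda::cudnn_is_available()) {
//     // enable code paths that require cuDNN (e.g. cuDNN convolutions)
//   }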

/// Sets the seed for the current GPU.
void manual_seed(uint64_t seed) {
  if (is_available()) {
    auto index = at::detail::getCUDAHooks().current_device();
    auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(index);
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(seed);
    }
  }
}
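
// Usage sketch (illustrative; not part of the original file): seed only the
// current device's default generator; a silent no-op when CUDA is unavailable.
//   torch::cuda::manual_seed(42);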

/// Sets the seed for all available GPUs.
void manual_seed_all(uint64_t seed) {
  auto num_gpu = device_count();
  for (const auto i : c10::irange(num_gpu)) {
    auto gen = at::detail::getCUDAHooks().getDefaultCUDAGenerator(i);
    {
      // See Note [Acquire lock when using random generators]
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(seed);
    }
  }
}
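
// Usage sketch (illustrative; not part of the original file): make multi-GPU
// runs reproducible by giving every device's default generator the same seed.
//   torch::cuda::manual_seed_all(0);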

void synchronize(int64_t device_index) {
  TORCH_CHECK(is_available(), "No CUDA GPUs are available");
  int64_t num_gpus = cuda::device_count();
  TORCH_CHECK(
      device_index == -1 || device_index < num_gpus,
      "Device index out of range: ",
      device_index);
  at::detail::getCUDAHooks().deviceSynchronize(device_index);
}
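
// Usage sketch (illustrative; not part of the original file): block the host
// until all queued kernels on device 0 have finished, e.g. before timing.
//   torch::cuda::synchronize(0);
// An index of -1 also passes the range check above and is treated as the
// current device.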

} // namespace cuda
} // namespace torch