xref: /aosp_15_r20/external/pytorch/c10/cuda/driver_api.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker #if !defined(USE_ROCM) && defined(PYTORCH_C10_DRIVER_API_SUPPORTED)
2*da0073e9SAndroid Build Coastguard Worker #include <c10/cuda/driver_api.h>
3*da0073e9SAndroid Build Coastguard Worker #include <c10/util/CallOnce.h>
4*da0073e9SAndroid Build Coastguard Worker #include <c10/util/Exception.h>
5*da0073e9SAndroid Build Coastguard Worker #include <dlfcn.h>
6*da0073e9SAndroid Build Coastguard Worker 
7*da0073e9SAndroid Build Coastguard Worker namespace c10::cuda {
8*da0073e9SAndroid Build Coastguard Worker 
9*da0073e9SAndroid Build Coastguard Worker namespace {
10*da0073e9SAndroid Build Coastguard Worker 
create_driver_api()11*da0073e9SAndroid Build Coastguard Worker DriverAPI create_driver_api() {
12*da0073e9SAndroid Build Coastguard Worker   void* handle_0 = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_NOLOAD);
13*da0073e9SAndroid Build Coastguard Worker   TORCH_CHECK(handle_0, "Can't open libcuda.so.1: ", dlerror());
14*da0073e9SAndroid Build Coastguard Worker   void* handle_1 = DriverAPI::get_nvml_handle();
15*da0073e9SAndroid Build Coastguard Worker   DriverAPI r{};
16*da0073e9SAndroid Build Coastguard Worker 
17*da0073e9SAndroid Build Coastguard Worker #define LOOKUP_LIBCUDA_ENTRY(name)                       \
18*da0073e9SAndroid Build Coastguard Worker   r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
19*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror())
20*da0073e9SAndroid Build Coastguard Worker   C10_LIBCUDA_DRIVER_API(LOOKUP_LIBCUDA_ENTRY)
21*da0073e9SAndroid Build Coastguard Worker #undef LOOKUP_LIBCUDA_ENTRY
22*da0073e9SAndroid Build Coastguard Worker 
23*da0073e9SAndroid Build Coastguard Worker #define LOOKUP_LIBCUDA_ENTRY(name)                       \
24*da0073e9SAndroid Build Coastguard Worker   r.name##_ = ((decltype(&name))dlsym(handle_0, #name)); \
25*da0073e9SAndroid Build Coastguard Worker   dlerror();
26*da0073e9SAndroid Build Coastguard Worker   C10_LIBCUDA_DRIVER_API_12030(LOOKUP_LIBCUDA_ENTRY)
27*da0073e9SAndroid Build Coastguard Worker #undef LOOKUP_LIBCUDA_ENTRY
28*da0073e9SAndroid Build Coastguard Worker 
29*da0073e9SAndroid Build Coastguard Worker   if (handle_1) {
30*da0073e9SAndroid Build Coastguard Worker #define LOOKUP_NVML_ENTRY(name)                          \
31*da0073e9SAndroid Build Coastguard Worker   r.name##_ = ((decltype(&name))dlsym(handle_1, #name)); \
32*da0073e9SAndroid Build Coastguard Worker   TORCH_INTERNAL_ASSERT(r.name##_, "Can't find ", #name, ": ", dlerror())
33*da0073e9SAndroid Build Coastguard Worker     C10_NVML_DRIVER_API(LOOKUP_NVML_ENTRY)
34*da0073e9SAndroid Build Coastguard Worker #undef LOOKUP_NVML_ENTRY
35*da0073e9SAndroid Build Coastguard Worker   }
36*da0073e9SAndroid Build Coastguard Worker   return r;
37*da0073e9SAndroid Build Coastguard Worker }
38*da0073e9SAndroid Build Coastguard Worker } // namespace
39*da0073e9SAndroid Build Coastguard Worker 
get_nvml_handle()40*da0073e9SAndroid Build Coastguard Worker void* DriverAPI::get_nvml_handle() {
41*da0073e9SAndroid Build Coastguard Worker   static void* nvml_hanle = dlopen("libnvidia-ml.so.1", RTLD_LAZY);
42*da0073e9SAndroid Build Coastguard Worker   return nvml_hanle;
43*da0073e9SAndroid Build Coastguard Worker }
44*da0073e9SAndroid Build Coastguard Worker 
get()45*da0073e9SAndroid Build Coastguard Worker C10_EXPORT DriverAPI* DriverAPI::get() {
46*da0073e9SAndroid Build Coastguard Worker   static DriverAPI singleton = create_driver_api();
47*da0073e9SAndroid Build Coastguard Worker   return &singleton;
48*da0073e9SAndroid Build Coastguard Worker }
49*da0073e9SAndroid Build Coastguard Worker 
50*da0073e9SAndroid Build Coastguard Worker } // namespace c10::cuda
51*da0073e9SAndroid Build Coastguard Worker 
52*da0073e9SAndroid Build Coastguard Worker #endif
53