xref: /aosp_15_r20/external/pytorch/aten/src/ATen/cuda/CUDASparse.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/cuda/CUDAContext.h>
4 #if defined(USE_ROCM)
5 #include <hipsparse/hipsparse-version.h>
6 #define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch)
7 #endif
8 
9 // cuSparse Generic API added in CUDA 10.1
10 // Windows support added in CUDA 11.0
11 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32)))
12 #define AT_USE_CUSPARSE_GENERIC_API() 1
13 #else
14 #define AT_USE_CUSPARSE_GENERIC_API() 0
15 #endif
16 
17 // cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0
18 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
19     (CUSPARSE_VERSION < 12000)
20 #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1
21 #else
22 #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0
23 #endif
24 
25 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \
26     (CUSPARSE_VERSION >= 12000)
27 #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1
28 #else
29 #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0
30 #endif
31 
32 #if defined(USE_ROCM)
33 // hipSparse const API added in v2.4.0
34 #if HIPSPARSE_VERSION >= 200400
35 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1
36 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
37 #define AT_USE_HIPSPARSE_GENERIC_API() 1
38 #else
39 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
40 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1
41 #define AT_USE_HIPSPARSE_GENERIC_API() 1
42 #endif
43 #else // USE_ROCM
44 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0
45 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0
46 #define AT_USE_HIPSPARSE_GENERIC_API() 0
47 #endif // USE_ROCM
48 
49 // cuSparse Generic API spsv function was added in CUDA 11.3.0
50 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500)
51 #define AT_USE_CUSPARSE_GENERIC_SPSV() 1
52 #else
53 #define AT_USE_CUSPARSE_GENERIC_SPSV() 0
54 #endif
55 
56 // cuSparse Generic API spsm function was added in CUDA 11.3.1
57 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600)
58 #define AT_USE_CUSPARSE_GENERIC_SPSM() 1
59 #else
60 #define AT_USE_CUSPARSE_GENERIC_SPSM() 0
61 #endif
62 
63 // cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400)
64 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400)
65 #define AT_USE_CUSPARSE_GENERIC_SDDMM() 1
66 #else
67 #define AT_USE_CUSPARSE_GENERIC_SDDMM() 0
68 #endif
69 
70 // BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0)
71 #if defined(CUDART_VERSION) || defined(USE_ROCM)
72 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1
73 #else
74 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0
75 #endif
76