1 #pragma once 2 3 #include <ATen/cuda/CUDAContext.h> 4 #if defined(USE_ROCM) 5 #include <hipsparse/hipsparse-version.h> 6 #define HIPSPARSE_VERSION ((hipsparseVersionMajor*100000) + (hipsparseVersionMinor*100) + hipsparseVersionPatch) 7 #endif 8 9 // cuSparse Generic API added in CUDA 10.1 10 // Windows support added in CUDA 11.0 11 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && ((CUSPARSE_VERSION >= 10300) || (CUSPARSE_VERSION >= 11000 && defined(_WIN32))) 12 #define AT_USE_CUSPARSE_GENERIC_API() 1 13 #else 14 #define AT_USE_CUSPARSE_GENERIC_API() 0 15 #endif 16 17 // cuSparse Generic API descriptor pointers were changed to const in CUDA 12.0 18 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ 19 (CUSPARSE_VERSION < 12000) 20 #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 1 21 #else 22 #define AT_USE_CUSPARSE_NON_CONST_DESCRIPTORS() 0 23 #endif 24 25 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && \ 26 (CUSPARSE_VERSION >= 12000) 27 #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 1 28 #else 29 #define AT_USE_CUSPARSE_CONST_DESCRIPTORS() 0 30 #endif 31 32 #if defined(USE_ROCM) 33 // hipSparse const API added in v2.4.0 34 #if HIPSPARSE_VERSION >= 200400 35 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 1 36 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 37 #define AT_USE_HIPSPARSE_GENERIC_API() 1 38 #else 39 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 40 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 1 41 #define AT_USE_HIPSPARSE_GENERIC_API() 1 42 #endif 43 #else // USE_ROCM 44 #define AT_USE_HIPSPARSE_CONST_DESCRIPTORS() 0 45 #define AT_USE_HIPSPARSE_NON_CONST_DESCRIPTORS() 0 46 #define AT_USE_HIPSPARSE_GENERIC_API() 0 47 #endif // USE_ROCM 48 49 // cuSparse Generic API spsv function was added in CUDA 11.3.0 50 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11500) 51 #define AT_USE_CUSPARSE_GENERIC_SPSV() 1 52 #else 53 #define AT_USE_CUSPARSE_GENERIC_SPSV() 0 54 #endif 55 56 // cuSparse Generic API spsm function was added in CUDA 11.3.1 57 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11600) 58 #define AT_USE_CUSPARSE_GENERIC_SPSM() 1 59 #else 60 #define AT_USE_CUSPARSE_GENERIC_SPSM() 0 61 #endif 62 63 // cuSparse Generic API sddmm function was added in CUDA 11.2.1 (cuSparse version 11400) 64 #if defined(CUDART_VERSION) && defined(CUSPARSE_VERSION) && (CUSPARSE_VERSION >= 11400) 65 #define AT_USE_CUSPARSE_GENERIC_SDDMM() 1 66 #else 67 #define AT_USE_CUSPARSE_GENERIC_SDDMM() 0 68 #endif 69 70 // BSR triangular solve functions were added in hipSPARSE 1.11.2 (ROCm 4.5.0) 71 #if defined(CUDART_VERSION) || defined(USE_ROCM) 72 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 1 73 #else 74 #define AT_USE_HIPSPARSE_TRIANGULAR_SOLVE() 0 75 #endif 76