#pragma once

#include <ATen/core/Tensor.h>
#include <ATen/Context.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

#include <ATen/native/TransposeType.h>
#include <ATen/native/cuda/MiscUtils.h>

#if (defined(CUDART_VERSION) && defined(CUSOLVER_VERSION)) || defined(USE_ROCM)
#define USE_LINALG_SOLVER
#endif

// cusolverDn<T>potrfBatched may have numerical issues before the CUDA 11.3 release
// (which is cusolver version 11101 in the header), so we only use cusolver potrfBatched
// if the CUDA version is >= 11.3.
#if CUSOLVER_VERSION >= 11101
  constexpr bool use_cusolver_potrf_batched_ = true;
#else
  constexpr bool use_cusolver_potrf_batched_ = false;
#endif
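
// A minimal sketch (illustrative only; `batch_count` is a hypothetical variable,
// and the real dispatch heuristics live in the implementation files) of how this
// flag is meant to be consumed:
//
//   if (use_cusolver_potrf_batched_ && batch_count > 1) {
//     // take the cusolverDn<T>potrfBatched path
//   } else {
//     // fall back to the looped cusolverDn<T>potrf path
//   }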

// cusolverDn<T>syevjBatched may have numerical issues before the CUDA 11.3.1 release
// (which is cusolver version 11102 in the header), so we only use cusolver syevjBatched
// if the CUDA version is >= 11.3.1.
// See https://github.com/pytorch/pytorch/pull/53040#issuecomment-793626268 and https://github.com/cupy/cupy/issues/4847
#if CUSOLVER_VERSION >= 11102
  constexpr bool use_cusolver_syevj_batched_ = true;
#else
  constexpr bool use_cusolver_syevj_batched_ = false;
#endif

// From the cuSOLVER doc: the Jacobi method has quadratic convergence, so the accuracy is not proportional to the number of sweeps.
//   To guarantee a certain accuracy, the user should configure the tolerance only.
// The current PyTorch implementation sets the gesvdj tolerance to the epsilon of the C++ data type to target the best possible precision.
constexpr int cusolver_gesvdj_max_sweeps = 400;
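
// A minimal sketch (assumptions: a `params` handle and a real-valued `scalar_t`
// in scope; this is not the actual PyTorch call site) of how the tolerance and
// sweep cap are typically applied to a gesvdj parameter handle:
//
//   gesvdjInfo_t params;
//   TORCH_CUSOLVER_CHECK(cusolverDnCreateGesvdjInfo(&params));
//   TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetTolerance(
//       params, std::numeric_limits<scalar_t>::epsilon()));
//   TORCH_CUSOLVER_CHECK(cusolverDnXgesvdjSetMaxSweeps(
//       params, cusolver_gesvdj_max_sweeps));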

namespace at {
namespace native {

void geqrf_batched_cublas(const Tensor& input, const Tensor& tau);
void triangular_solve_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
void triangular_solve_batched_cublas(const Tensor& A, const Tensor& B, bool left, bool upper, TransposeType transpose, bool unitriangular);
void gels_batched_cublas(const Tensor& a, Tensor& b, Tensor& infos);
void ldl_factor_cusolver(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& info,
    bool upper,
    bool hermitian);
void ldl_solve_cusolver(
    const Tensor& LD,
    const Tensor& pivots,
    const Tensor& B,
    bool upper);
void lu_factor_batched_cublas(const Tensor& A, const Tensor& pivots, const Tensor& infos, bool get_pivots);
void lu_solve_batched_cublas(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);

#if defined(USE_LINALG_SOLVER)

// Entry point for `svd` computations using cusolver gesvdj and gesvdjBatched
void svd_cusolver(const Tensor& A, const bool full_matrices, const bool compute_uv,
  const std::optional<c10::string_view>& driver, const Tensor& U, const Tensor& S, const Tensor& V, const Tensor& info);

// Entry point for `cholesky` computations using cusolver potrf and potrfBatched
void cholesky_helper_cusolver(const Tensor& input, bool upper, const Tensor& info);
Tensor _cholesky_solve_helper_cuda_cusolver(const Tensor& self, const Tensor& A, bool upper);
Tensor& cholesky_inverse_kernel_impl_cusolver(Tensor &result, Tensor& infos, bool upper);

void geqrf_cusolver(const Tensor& input, const Tensor& tau);
void ormqr_cusolver(const Tensor& input, const Tensor& tau, const Tensor& other, bool left, bool transpose);
Tensor& orgqr_helper_cusolver(Tensor& result, const Tensor& tau);

void linalg_eigh_cusolver(const Tensor& eigenvalues, const Tensor& eigenvectors, const Tensor& infos, bool upper, bool compute_eigenvectors);
void lu_solve_looped_cusolver(const Tensor& LU, const Tensor& pivots, const Tensor& B, TransposeType transpose);

void lu_factor_looped_cusolver(const Tensor& self, const Tensor& pivots, const Tensor& infos, bool get_pivots);

#endif  // USE_LINALG_SOLVER

#if defined(BUILD_LAZY_CUDA_LINALG)
namespace cuda { namespace detail {
// This is only used for old-style dispatches
// Please do not add any new entries to it
struct LinalgDispatch {
   Tensor (*cholesky_solve_helper)(const Tensor& self, const Tensor& A, bool upper);
};
C10_EXPORT void registerLinalgDispatch(const LinalgDispatch&);
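
// Sketch of the intended registration pattern (illustrative only; the function
// name `cholesky_solve_helper_impl` is hypothetical):
//
//   Tensor cholesky_solve_helper_impl(const Tensor& self, const Tensor& A, bool upper);
//   ...
//   registerLinalgDispatch(LinalgDispatch{&cholesky_solve_helper_impl});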
}} // namespace cuda::detail
#endif

}}  // namespace at::native