xref: /aosp_15_r20/external/eigen/doc/examples/matrixfree_cg.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1 #include <iostream>
2 #include <Eigen/Core>
3 #include <Eigen/Dense>
4 #include <Eigen/IterativeLinearSolvers>
5 #include <unsupported/Eigen/IterativeSolvers>
6 
7 class MatrixReplacement;
8 using Eigen::SparseMatrix;
9 
10 namespace Eigen {
11 namespace internal {
12   // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits:
13   template<>
14   struct traits<MatrixReplacement> :  public Eigen::internal::traits<Eigen::SparseMatrix<double> >
15   {};
16 }
17 }
18 
19 // Example of a matrix-free wrapper from a user type to Eigen's compatible type
20 // For the sake of simplicity, this example simply wrap a Eigen::SparseMatrix.
21 class MatrixReplacement : public Eigen::EigenBase<MatrixReplacement> {
22 public:
23   // Required typedefs, constants, and method:
24   typedef double Scalar;
25   typedef double RealScalar;
26   typedef int StorageIndex;
27   enum {
28     ColsAtCompileTime = Eigen::Dynamic,
29     MaxColsAtCompileTime = Eigen::Dynamic,
30     IsRowMajor = false
31   };
32 
rows() const33   Index rows() const { return mp_mat->rows(); }
cols() const34   Index cols() const { return mp_mat->cols(); }
35 
36   template<typename Rhs>
operator *(const Eigen::MatrixBase<Rhs> & x) const37   Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const {
38     return Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct>(*this, x.derived());
39   }
40 
41   // Custom API:
MatrixReplacement()42   MatrixReplacement() : mp_mat(0) {}
43 
attachMyMatrix(const SparseMatrix<double> & mat)44   void attachMyMatrix(const SparseMatrix<double> &mat) {
45     mp_mat = &mat;
46   }
my_matrix() const47   const SparseMatrix<double> my_matrix() const { return *mp_mat; }
48 
49 private:
50   const SparseMatrix<double> *mp_mat;
51 };
52 
53 
54 // Implementation of MatrixReplacement * Eigen::DenseVector though a specialization of internal::generic_product_impl:
55 namespace Eigen {
56 namespace internal {
57 
58   template<typename Rhs>
59   struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector
60   : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> >
61   {
62     typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar;
63 
64     template<typename Dest>
scaleAndAddToEigen::internal::generic_product_impl65     static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha)
66     {
67       // This method should implement "dst += alpha * lhs * rhs" inplace,
68       // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it.
69       assert(alpha==Scalar(1) && "scaling is not implemented");
70       EIGEN_ONLY_USED_FOR_DEBUG(alpha);
71 
72       // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs,
73       // but let's do something fancier (and less efficient):
74       for(Index i=0; i<lhs.cols(); ++i)
75         dst += rhs(i) * lhs.my_matrix().col(i);
76     }
77   };
78 
79 }
80 }
81 
main()82 int main()
83 {
84   int n = 10;
85   Eigen::SparseMatrix<double> S = Eigen::MatrixXd::Random(n,n).sparseView(0.5,1);
86   S = S.transpose()*S;
87 
88   MatrixReplacement A;
89   A.attachMyMatrix(S);
90 
91   Eigen::VectorXd b(n), x;
92   b.setRandom();
93 
94   // Solve Ax = b using various iterative solver with matrix-free version:
95   {
96     Eigen::ConjugateGradient<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg;
97     cg.compute(A);
98     x = cg.solve(b);
99     std::cout << "CG:       #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl;
100   }
101 
102   {
103     Eigen::BiCGSTAB<MatrixReplacement, Eigen::IdentityPreconditioner> bicg;
104     bicg.compute(A);
105     x = bicg.solve(b);
106     std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl;
107   }
108 
109   {
110     Eigen::GMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
111     gmres.compute(A);
112     x = gmres.solve(b);
113     std::cout << "GMRES:    #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
114   }
115 
116   {
117     Eigen::DGMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
118     gmres.compute(A);
119     x = gmres.solve(b);
120     std::cout << "DGMRES:   #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
121   }
122 
123   {
124     Eigen::MINRES<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres;
125     minres.compute(A);
126     x = minres.solve(b);
127     std::cout << "MINRES:   #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl;
128   }
129 }
130