xref: /aosp_15_r20/external/eigen/doc/examples/matrixfree_cg.cpp (revision bf2c37156dfe67e5dfebd6d394bad8b2ab5804d4)
1*bf2c3715SXin Li #include <iostream>
2*bf2c3715SXin Li #include <Eigen/Core>
3*bf2c3715SXin Li #include <Eigen/Dense>
4*bf2c3715SXin Li #include <Eigen/IterativeLinearSolvers>
5*bf2c3715SXin Li #include <unsupported/Eigen/IterativeSolvers>
6*bf2c3715SXin Li 
7*bf2c3715SXin Li class MatrixReplacement;
8*bf2c3715SXin Li using Eigen::SparseMatrix;
9*bf2c3715SXin Li 
10*bf2c3715SXin Li namespace Eigen {
11*bf2c3715SXin Li namespace internal {
12*bf2c3715SXin Li   // MatrixReplacement looks-like a SparseMatrix, so let's inherits its traits:
13*bf2c3715SXin Li   template<>
14*bf2c3715SXin Li   struct traits<MatrixReplacement> :  public Eigen::internal::traits<Eigen::SparseMatrix<double> >
15*bf2c3715SXin Li   {};
16*bf2c3715SXin Li }
17*bf2c3715SXin Li }
18*bf2c3715SXin Li 
19*bf2c3715SXin Li // Example of a matrix-free wrapper from a user type to Eigen's compatible type
20*bf2c3715SXin Li // For the sake of simplicity, this example simply wrap a Eigen::SparseMatrix.
21*bf2c3715SXin Li class MatrixReplacement : public Eigen::EigenBase<MatrixReplacement> {
22*bf2c3715SXin Li public:
23*bf2c3715SXin Li   // Required typedefs, constants, and method:
24*bf2c3715SXin Li   typedef double Scalar;
25*bf2c3715SXin Li   typedef double RealScalar;
26*bf2c3715SXin Li   typedef int StorageIndex;
27*bf2c3715SXin Li   enum {
28*bf2c3715SXin Li     ColsAtCompileTime = Eigen::Dynamic,
29*bf2c3715SXin Li     MaxColsAtCompileTime = Eigen::Dynamic,
30*bf2c3715SXin Li     IsRowMajor = false
31*bf2c3715SXin Li   };
32*bf2c3715SXin Li 
rows() const33*bf2c3715SXin Li   Index rows() const { return mp_mat->rows(); }
cols() const34*bf2c3715SXin Li   Index cols() const { return mp_mat->cols(); }
35*bf2c3715SXin Li 
36*bf2c3715SXin Li   template<typename Rhs>
operator *(const Eigen::MatrixBase<Rhs> & x) const37*bf2c3715SXin Li   Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct> operator*(const Eigen::MatrixBase<Rhs>& x) const {
38*bf2c3715SXin Li     return Eigen::Product<MatrixReplacement,Rhs,Eigen::AliasFreeProduct>(*this, x.derived());
39*bf2c3715SXin Li   }
40*bf2c3715SXin Li 
41*bf2c3715SXin Li   // Custom API:
MatrixReplacement()42*bf2c3715SXin Li   MatrixReplacement() : mp_mat(0) {}
43*bf2c3715SXin Li 
attachMyMatrix(const SparseMatrix<double> & mat)44*bf2c3715SXin Li   void attachMyMatrix(const SparseMatrix<double> &mat) {
45*bf2c3715SXin Li     mp_mat = &mat;
46*bf2c3715SXin Li   }
my_matrix() const47*bf2c3715SXin Li   const SparseMatrix<double> my_matrix() const { return *mp_mat; }
48*bf2c3715SXin Li 
49*bf2c3715SXin Li private:
50*bf2c3715SXin Li   const SparseMatrix<double> *mp_mat;
51*bf2c3715SXin Li };
52*bf2c3715SXin Li 
53*bf2c3715SXin Li 
54*bf2c3715SXin Li // Implementation of MatrixReplacement * Eigen::DenseVector though a specialization of internal::generic_product_impl:
55*bf2c3715SXin Li namespace Eigen {
56*bf2c3715SXin Li namespace internal {
57*bf2c3715SXin Li 
58*bf2c3715SXin Li   template<typename Rhs>
59*bf2c3715SXin Li   struct generic_product_impl<MatrixReplacement, Rhs, SparseShape, DenseShape, GemvProduct> // GEMV stands for matrix-vector
60*bf2c3715SXin Li   : generic_product_impl_base<MatrixReplacement,Rhs,generic_product_impl<MatrixReplacement,Rhs> >
61*bf2c3715SXin Li   {
62*bf2c3715SXin Li     typedef typename Product<MatrixReplacement,Rhs>::Scalar Scalar;
63*bf2c3715SXin Li 
64*bf2c3715SXin Li     template<typename Dest>
scaleAndAddToEigen::internal::generic_product_impl65*bf2c3715SXin Li     static void scaleAndAddTo(Dest& dst, const MatrixReplacement& lhs, const Rhs& rhs, const Scalar& alpha)
66*bf2c3715SXin Li     {
67*bf2c3715SXin Li       // This method should implement "dst += alpha * lhs * rhs" inplace,
68*bf2c3715SXin Li       // however, for iterative solvers, alpha is always equal to 1, so let's not bother about it.
69*bf2c3715SXin Li       assert(alpha==Scalar(1) && "scaling is not implemented");
70*bf2c3715SXin Li       EIGEN_ONLY_USED_FOR_DEBUG(alpha);
71*bf2c3715SXin Li 
72*bf2c3715SXin Li       // Here we could simply call dst.noalias() += lhs.my_matrix() * rhs,
73*bf2c3715SXin Li       // but let's do something fancier (and less efficient):
74*bf2c3715SXin Li       for(Index i=0; i<lhs.cols(); ++i)
75*bf2c3715SXin Li         dst += rhs(i) * lhs.my_matrix().col(i);
76*bf2c3715SXin Li     }
77*bf2c3715SXin Li   };
78*bf2c3715SXin Li 
79*bf2c3715SXin Li }
80*bf2c3715SXin Li }
81*bf2c3715SXin Li 
main()82*bf2c3715SXin Li int main()
83*bf2c3715SXin Li {
84*bf2c3715SXin Li   int n = 10;
85*bf2c3715SXin Li   Eigen::SparseMatrix<double> S = Eigen::MatrixXd::Random(n,n).sparseView(0.5,1);
86*bf2c3715SXin Li   S = S.transpose()*S;
87*bf2c3715SXin Li 
88*bf2c3715SXin Li   MatrixReplacement A;
89*bf2c3715SXin Li   A.attachMyMatrix(S);
90*bf2c3715SXin Li 
91*bf2c3715SXin Li   Eigen::VectorXd b(n), x;
92*bf2c3715SXin Li   b.setRandom();
93*bf2c3715SXin Li 
94*bf2c3715SXin Li   // Solve Ax = b using various iterative solver with matrix-free version:
95*bf2c3715SXin Li   {
96*bf2c3715SXin Li     Eigen::ConjugateGradient<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> cg;
97*bf2c3715SXin Li     cg.compute(A);
98*bf2c3715SXin Li     x = cg.solve(b);
99*bf2c3715SXin Li     std::cout << "CG:       #iterations: " << cg.iterations() << ", estimated error: " << cg.error() << std::endl;
100*bf2c3715SXin Li   }
101*bf2c3715SXin Li 
102*bf2c3715SXin Li   {
103*bf2c3715SXin Li     Eigen::BiCGSTAB<MatrixReplacement, Eigen::IdentityPreconditioner> bicg;
104*bf2c3715SXin Li     bicg.compute(A);
105*bf2c3715SXin Li     x = bicg.solve(b);
106*bf2c3715SXin Li     std::cout << "BiCGSTAB: #iterations: " << bicg.iterations() << ", estimated error: " << bicg.error() << std::endl;
107*bf2c3715SXin Li   }
108*bf2c3715SXin Li 
109*bf2c3715SXin Li   {
110*bf2c3715SXin Li     Eigen::GMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
111*bf2c3715SXin Li     gmres.compute(A);
112*bf2c3715SXin Li     x = gmres.solve(b);
113*bf2c3715SXin Li     std::cout << "GMRES:    #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
114*bf2c3715SXin Li   }
115*bf2c3715SXin Li 
116*bf2c3715SXin Li   {
117*bf2c3715SXin Li     Eigen::DGMRES<MatrixReplacement, Eigen::IdentityPreconditioner> gmres;
118*bf2c3715SXin Li     gmres.compute(A);
119*bf2c3715SXin Li     x = gmres.solve(b);
120*bf2c3715SXin Li     std::cout << "DGMRES:   #iterations: " << gmres.iterations() << ", estimated error: " << gmres.error() << std::endl;
121*bf2c3715SXin Li   }
122*bf2c3715SXin Li 
123*bf2c3715SXin Li   {
124*bf2c3715SXin Li     Eigen::MINRES<MatrixReplacement, Eigen::Lower|Eigen::Upper, Eigen::IdentityPreconditioner> minres;
125*bf2c3715SXin Li     minres.compute(A);
126*bf2c3715SXin Li     x = minres.solve(b);
127*bf2c3715SXin Li     std::cout << "MINRES:   #iterations: " << minres.iterations() << ", estimated error: " << minres.error() << std::endl;
128*bf2c3715SXin Li   }
129*bf2c3715SXin Li }
130