yat  0.16.4pre
DiagonalMatrix.h
1 #ifndef theplu_yat_utility_diagonal_matrix
2 #define theplu_yat_utility_diagonal_matrix
3 
4 // $Id: DiagonalMatrix.h 3655 2017-07-13 00:34:18Z peter $
5 
6 /*
7  Copyright (C) 2017 Peter Johansson
8 
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10 
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 3 of the
14  License, or (at your option) any later version.
15 
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20 
21  You should have received a copy of the GNU General Public License
22  along with yat. If not, see <http://www.gnu.org/licenses/>.
23 */
24 
25 #include "BasicMatrix.h"
26 #include "BasicVector.h"
27 #include "BLAS_utility.h"
28 
29 #include "MatrixExpression.h"
30 #include "VectorExpression.h"
31 
32 #include "Vector.h"
33 #include "yat_assert.h"
34 
35 namespace theplu {
36 namespace yat {
37 namespace utility {
38 
39  class Matrix;
40 
50  {
51  public:
55  DiagonalMatrix(void);
56 
60  DiagonalMatrix(size_t row, size_t column, double x=0.0);
61 
65  explicit DiagonalMatrix(const VectorBase& diagonal);
66 
71  explicit DiagonalMatrix(const Matrix& m);
72 
76  size_t rows(void) const;
77 
81  size_t columns(void) const;
82 
86  const double operator()(size_t row, size_t col) const;
87 
93  double& operator()(size_t i);
94  private:
95  Vector data_;
96  size_t row_;
97  size_t col_;
98  };
99 
101 
102  // expression classes
103  namespace expression
104  {
105  template<typename T>
106  class DiagonalMatrixMatrix :
107  public MatrixExpression<DiagonalMatrixMatrix<T> >
108  {
109  public:
110  DiagonalMatrixMatrix(const DiagonalMatrix& lhs,const BasicMatrix<T>& rhs)
111  : lhs_(lhs), rhs_(rhs) {}
112  size_t rows(void) const { return lhs_.rows(); }
113  size_t columns(void) const { return rhs_.columns(); }
114  double operator()(size_t i, size_t j) const
115  { return lhs_(i, i) * rhs_(i, j); }
116 
117  void calculate_matrix(gsl_matrix*& result) const
118  {
119  detail::reallocate(result, rows(), columns());
120  for (size_t i=0; i<rows(); ++i)
121  for (size_t j=0; j<columns(); ++j)
122  gsl_matrix_set(result, i, j, (*this)(i, j));
123  }
124 
125  private:
126  const DiagonalMatrix& lhs_;
127  const BasicMatrix<T>& rhs_;
128  };
129 
130 
131  template<class T>
132  class DiagonalMatrixVector
133  : public VectorExpression<DiagonalMatrixVector<T> >
134  {
135  public:
136  DiagonalMatrixVector(const DiagonalMatrix& lhs, const BasicVector<T>& rhs)
137  : m_(lhs), vec_(rhs), size_(lhs.rows())
138  {
139  }
140 
141  DiagonalMatrixVector(const BasicVector<T>& lhs, const DiagonalMatrix& rhs)
142  : m_(rhs), vec_(lhs), size_(rhs.columns())
143  {
144  }
145 
146  double operator()(size_t i) const
147  {
148  YAT_ASSERT(i < size());
149  return m_(i, i) * vec_(i);
150  }
151 
152 
153  size_t size(void) const
154  {
155  return size_;
156  }
157 
158 
159  void calculate_gsl_vector_p(void) const
160  {
161  this->allocate_memory(size_);
162  for (size_t i=0; i<size_; ++i)
163  gsl_vector_set(this->v_, i, (*this)(i));
164  }
165  private:
166  const DiagonalMatrix& m_;
167  const BasicVector<T>& vec_;
168  size_t size_;
169  };
170 
171 
172 
173  template<typename T>
174  class MatrixDiagonalMatrix :
175  public MatrixExpression<MatrixDiagonalMatrix<T> >
176  {
177  public:
178  MatrixDiagonalMatrix(const BasicMatrix<T>& lhs, const DiagonalMatrix& rhs)
179  : lhs_(lhs), rhs_(rhs) {}
180  size_t rows(void) const { return lhs_.rows(); }
181  size_t columns(void) const { return rhs_.columns(); }
182  double operator()(size_t i, size_t j) const
183  { return lhs_(i, j) * rhs_(j, j); }
184 
185  void calculate_matrix(gsl_matrix*& result) const
186  {
187  detail::reallocate(result, rows(), columns());
188  for (size_t i=0; i<rows(); ++i)
189  for (size_t j=0; j<columns(); ++j)
190  gsl_matrix_set(result, i, j, (*this)(i, j));
191  }
192 
193  private:
194  const BasicMatrix<T>& lhs_;
195  const DiagonalMatrix& rhs_;
196  };
197 
198 
199  template<typename T>
200  class VectorDiagonalMatrix :
201  public BasicVector<VectorDiagonalMatrix<T> >
202  {
203  public:
204  VectorDiagonalMatrix(const BasicVector<T>& rhs, const DiagonalMatrix& lhs)
205  : lhs_(lhs), rhs_(rhs) {}
206 
207  double operator()(size_t i) const
208  {
209  YAT_ASSERT(i < size());
210  return lhs_(i) * rhs_(i, i);
211  }
212 
213  size_t size(void) const
214  {
215  return rhs_.columns();
216  }
217 
218 
219  const gsl_vector* gsl_vector_p(void) const
220  {
221  size_t n = size();
222  gsl_vector* vec = detail::create_gsl_vector(n);
223  for (size_t i=0; i<n; ++i)
224  gsl_vector_set(vec, i, (*this)(i));
225  return vec;
226  }
227 
228  private:
229  const BasicVector<T>& lhs_;
230  const DiagonalMatrix& rhs_;
231  };
232 
233 
234  template<class M1, class M2, class OP>
235  class DiagonalMatrix
236  : public MatrixExpression<DiagonalMatrix<M1, M2, OP> >
237  {
238  public:
239  DiagonalMatrix(const M1& lhs, const M2& rhs)
240  : lhs_(lhs), rhs_(rhs)
241  {
242  YAT_ASSERT(lhs.rows() == rhs.rows());
243  YAT_ASSERT(lhs.columns() == rhs.columns());
244  }
245  double operator()(size_t i, size_t j) const
246  {
247  return get(i, j, OP());
248  }
249  size_t rows(void) const { return lhs_.rows(); }
250  size_t columns(void) const { return rhs_.columns(); }
251  void calculate_matrix(gsl_matrix*& m) const
252  {
253  detail::reallocate(m, rows(), columns());
254  for (size_t i=0; i<rows(); ++i)
255  for (size_t j=0; j<columns(); ++j)
256  gsl_matrix_set(m, i, j, (*this)(i,j));
257  }
258 
259  private:
260  const M1& lhs_;
261  const M2& rhs_;
262 
263  double get(size_t i, size_t j, Plus) const
264  { return lhs_(i, j) + rhs_(i, j); }
265 
266  double get(size_t i, size_t j, Minus) const
267  { return lhs_(i, j) - rhs_(i, j); }
268  };
269 
270  } // end of namespace expression
271 
273 
274 
285  const DiagonalMatrix& rhs);
286 
294  template<class T>
295  expression::DiagonalMatrixMatrix<T>
296  operator*(const DiagonalMatrix& lhs, const BasicMatrix<T>& rhs)
297  {
298  YAT_ASSERT(lhs.columns() == rhs.rows());
299  return expression::DiagonalMatrixMatrix<T>(lhs, rhs);
300  }
301 
302 
310  template<class T>
311  expression::MatrixDiagonalMatrix<T>
312  operator*(const BasicMatrix<T>& lhs, const DiagonalMatrix& rhs)
313  {
314  YAT_ASSERT(lhs.columns() == rhs.rows());
315  return expression::MatrixDiagonalMatrix<T>(lhs, rhs);
316  }
317 
318 
327  operator+(const DiagonalMatrix& lhs, const DiagonalMatrix& rhs);
328 
329 
337  template<class T>
338  expression::DiagonalMatrix<DiagonalMatrix, BasicMatrix<T>, expression::Plus>
339  operator+(const DiagonalMatrix& lhs, const BasicMatrix<T>& rhs)
340  {
341  YAT_ASSERT(lhs.rows() == rhs.rows());
342  YAT_ASSERT(lhs.columns() == rhs.columns());
343  return expression::DiagonalMatrix<DiagonalMatrix, BasicMatrix<T>,
344  expression::Plus>(lhs, rhs);
345  }
346 
347 
355  template<class T>
356  expression::DiagonalMatrix<BasicMatrix<T>, DiagonalMatrix, expression::Plus>
357  operator+(const BasicMatrix<T>& lhs, const DiagonalMatrix& rhs)
358  {
359  YAT_ASSERT(lhs.rows() == rhs.rows());
360  YAT_ASSERT(lhs.columns() == rhs.columns());
361  return expression::DiagonalMatrix<BasicMatrix<T>, DiagonalMatrix,
362  expression::Plus>(lhs, rhs);
363  }
364 
365 
373  DiagonalMatrix
374  operator-(const DiagonalMatrix& lhs, const DiagonalMatrix& rhs);
375 
376 
384  template<class T>
385  expression::DiagonalMatrix<DiagonalMatrix, BasicMatrix<T>, expression::Minus>
386  operator-(const DiagonalMatrix& lhs, const BasicMatrix<T>& rhs)
387  {
388  YAT_ASSERT(lhs.rows() == rhs.rows());
389  YAT_ASSERT(lhs.columns() == rhs.columns());
390  return expression::DiagonalMatrix<DiagonalMatrix, BasicMatrix<T>,
391  expression::Minus>(lhs, rhs);
392  }
393 
394 
402  template<class T>
403  expression::DiagonalMatrix<BasicMatrix<T>, DiagonalMatrix, expression::Minus>
404  operator-(const BasicMatrix<T>& lhs, const DiagonalMatrix& rhs)
405  {
406  YAT_ASSERT(lhs.rows() == rhs.rows());
407  YAT_ASSERT(lhs.columns() == rhs.columns());
408  return expression::DiagonalMatrix<BasicMatrix<T>, DiagonalMatrix,
409  expression::Minus>(lhs, rhs);
410  }
411 
412 
420  template<class T>
421  expression::DiagonalMatrixVector<T>
422  operator*(const DiagonalMatrix& lhs, const BasicVector<T>& rhs)
423  {
424  YAT_ASSERT(lhs.columns() == rhs.size());
425  return expression::DiagonalMatrixVector<T>(lhs, rhs);
426  }
427 
428 
436  template<class T>
437  expression::DiagonalMatrixVector<T>
438  operator*(const BasicVector<T>& lhs, const DiagonalMatrix& rhs)
439  {
440  YAT_ASSERT(lhs.size() == rhs.rows());
441  return expression::DiagonalMatrixVector<T>(lhs, rhs);
442  }
443 
444 
445 }}}
446 #endif
const double operator()(size_t row, size_t col) const
expression::MatrixDiagonalMatrix< T > operator*(const BasicMatrix< T > &lhs, const DiagonalMatrix &rhs)
matrix-matrix multiplication
Definition: DiagonalMatrix.h:312
size_t size(void) const
Definition: BasicVector.h:71
The Department of Theoretical Physics namespace as we define it.
size_t rows(void) const
Definition: BasicMatrix.h:61
size_t columns(void) const
Definition: BasicMatrix.h:67
DiagonalMatrix operator-(const DiagonalMatrix &lhs, const DiagonalMatrix &rhs)
matrix-matrix subtraction
DiagonalMatrix operator+(const DiagonalMatrix &lhs, const DiagonalMatrix &rhs)
matrix-matrix addition
expression::DiagonalMatrixVector< T > operator*(const BasicVector< T > &lhs, const DiagonalMatrix &rhs)
vector-matrix multiplication
Definition: DiagonalMatrix.h:438
expression::DiagonalMatrix< DiagonalMatrix, BasicMatrix< T >, expression::Minus > operator-(const DiagonalMatrix &lhs, const BasicMatrix< T > &rhs)
matrix-matrix subtraction
Definition: DiagonalMatrix.h:386
expression::DiagonalMatrix< BasicMatrix< T >, DiagonalMatrix, expression::Minus > operator-(const BasicMatrix< T > &lhs, const DiagonalMatrix &rhs)
matrix-matrix subtraction
Definition: DiagonalMatrix.h:404
This is the yat interface to GSL vector.
Definition: Vector.h:59
Definition: BLAS_utility.h:31
This is the yat interface to GSL vector.
Definition: VectorBase.h:55
DiagonalMatrix operator*(const DiagonalMatrix &lhs, const DiagonalMatrix &rhs)
An expression that can be converted to a Vector.
Definition: VectorExpression.h:67
An expression that can be converted to a Matrix.
Definition: MatrixExpression.h:46
Definition: BasicVector.h:48
Definition: BLAS_utility.h:30
expression::DiagonalMatrixMatrix< T > operator*(const DiagonalMatrix &lhs, const BasicMatrix< T > &rhs)
matrix-matrix multiplication
Definition: DiagonalMatrix.h:296
Interface to GSL matrix.
Definition: Matrix.h:74
Definition: BasicMatrix.h:38
DiagonalMatrix(void)
Default constructor.
expression::DiagonalMatrixVector< T > operator*(const DiagonalMatrix &lhs, const BasicVector< T > &rhs)
matrix-vector multiplication
Definition: DiagonalMatrix.h:422
expression::DiagonalMatrix< DiagonalMatrix, BasicMatrix< T >, expression::Plus > operator+(const DiagonalMatrix &lhs, const BasicMatrix< T > &rhs)
matrix-matrix addition
Definition: DiagonalMatrix.h:339
Diagonal Matrix.
Definition: DiagonalMatrix.h:49
expression::DiagonalMatrix< BasicMatrix< T >, DiagonalMatrix, expression::Plus > operator+(const BasicMatrix< T > &lhs, const DiagonalMatrix &rhs)
matrix-matrix addition
Definition: DiagonalMatrix.h:357

Generated on Thu Dec 12 2019 03:12:08 for yat by  doxygen 1.8.11