yat  0.16.4pre
BLAS_level3.h
1 #ifndef _theplu_yat_utility_blas_level3
2 #define _theplu_yat_utility_blas_level3
3 
4 // $Id: BLAS_level3.h 3654 2017-07-10 05:36:55Z peter $
5 
6 /*
7  Copyright (C) 2017 Peter Johansson
8 
9  This file is part of the yat library, http://dev.thep.lu.se/yat
10 
11  The yat library is free software; you can redistribute it and/or
12  modify it under the terms of the GNU General Public License as
13  published by the Free Software Foundation; either version 3 of the
14  License, or (at your option) any later version.
15 
16  The yat library is distributed in the hope that it will be useful,
17  but WITHOUT ANY WARRANTY; without even the implied warranty of
18  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
19  General Public License for more details.
20 
21  You should have received a copy of the GNU General Public License
22  along with yat. If not, see <http://www.gnu.org/licenses/>.
23 */
24 
25 #include "BasicMatrix.h"
26 #include "BLAS_utility.h"
27 #include "MatrixExpression.h"
28 
29 #include <gsl/gsl_matrix.h>
30 
31 namespace theplu {
32 namespace yat {
33 namespace utility {
34 
35  // This file defines operations in which every operand is a Matrix (no Vector operands)
36 
38 
39  namespace expression {
40  template<typename LHS, typename RHS, class OP>
41  class MatrixBinary
42  : public MatrixExpression<MatrixBinary<LHS, RHS, OP> >
43  {
44  public:
45  MatrixBinary(const BasicMatrix<LHS>& lhs, const BasicMatrix<RHS>& rhs)
46  : lhs_(lhs), rhs_(rhs)
47  {
48  }
49 
50  size_t rows(void) const { return lhs_.rows(); }
51  size_t columns(void) const { return rhs_.columns(); }
52 
53  double operator()(size_t row, size_t column) const
54  { return get(row, column, op_); }
55 
56  void calculate_matrix(gsl_matrix*& result) const
57  {
58  detail::reallocate(result, this->rows(), this->columns());
59  calculate_matrix(result, op_);
60  }
61 
62  private:
63  const BasicMatrix<LHS>& lhs_;
64  const BasicMatrix<RHS>& rhs_;
65  OP op_;
66 
67  void calculate_matrix(gsl_matrix*& result, Multiplies) const
68  {
69  YAT_ASSERT(detail::rows(result) == this->rows());
70  YAT_ASSERT(detail::columns(result) == this->columns());
71  YAT_ASSERT(lhs_.columns() == rhs_.rows());
72  gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0,
73  lhs_.gsl_matrix_p(), rhs_.gsl_matrix_p(),0.0, result);
74  YAT_ASSERT(result);
75  }
76 
77  template<class T>
78  void calculate_matrix(gsl_matrix*& result, T) const
79  {
80  YAT_ASSERT(detail::rows(result) == this->rows());
81  YAT_ASSERT(detail::columns(result) == this->columns());
82  for (size_t i=0; i<rows(); ++i)
83  for (size_t j=0; j<columns(); ++j)
84  gsl_matrix_set(result, i, j, (*this)(i, j));
85  }
86 
87 
88  double get(size_t row, size_t column, Plus) const
89  {
90  return lhs_(row, column) + rhs_(row, column);
91  }
92 
93 
94  double get(size_t row, size_t column, Minus) const
95  {
96  return lhs_(row, column) - rhs_(row, column);
97  }
98 
99 
100  double get(size_t row, size_t column, Multiplies) const
101  {
102  return gsl_matrix_get(this->gsl_matrix_p(), row, column);
103  }
104  };
105 
106 
107  template<class T>
108  class ScaledMatrix : public MatrixExpression<ScaledMatrix<T> >
109  {
110  public:
111  ScaledMatrix(double factor, const BasicMatrix<T>& A)
112  : A_(A), factor_(factor) {}
113 
114  size_t rows(void) const { return A_.rows(); }
115  size_t columns(void) const { return A_.columns(); }
116 
117 
118  double operator()(size_t i, size_t j) const
119  {
120  return factor_ * A_(i, j);
121  }
122 
123 
124  void calculate_matrix(gsl_matrix*& result) const
125  {
126  detail::copy(result, A_.gsl_matrix_p());
127  gsl_matrix_scale(result, factor_);
128  }
129 
130  private:
131  const BasicMatrix<T>& A_;
132  double factor_;
133  };
134 
135 
136  template<class T>
137  class TransposedMatrix : public MatrixExpression<TransposedMatrix<T> >
138  {
139  public:
140  TransposedMatrix(const BasicMatrix<T>& A)
141  : A_(A) {}
142 
143  size_t rows(void) const { return A_.columns(); }
144  size_t columns(void) const { return A_.rows(); }
145 
146  double operator()(size_t i, size_t j) const
147  {
148  return A_(j, i);
149  }
150 
151  void calculate_matrix(gsl_matrix*& result) const
152  {
153  detail::reallocate(result, rows(), columns());
154  gsl_matrix_transpose_memcpy(result, A_.gsl_matrix_p());
155  }
156 
157  private:
158  const BasicMatrix<T>& A_;
159  };
160 
161 
162  } // end namespace expression
163 
165 
178  template<class Derived1, class Derived2>
179  expression::MatrixBinary<Derived1, Derived2, expression::Plus>
181  {
182  return expression::MatrixBinary<Derived1, Derived2,
183  expression::Plus>(lhs, rhs);
184  }
185 
186 
199  template<class Derived1, class Derived2>
200  expression::MatrixBinary<Derived1, Derived2, expression::Minus>
202  {
203  return expression::MatrixBinary<Derived1, Derived2,
204  expression::Minus>(lhs, rhs);
205  }
206 
207 
217  template<class Derived1, class Derived2>
218  expression::MatrixBinary<Derived1, Derived2, expression::Multiplies>
220  {
221  YAT_ASSERT(lhs.columns() == rhs.rows());
222  return expression::MatrixBinary<Derived1, Derived2,
223  expression::Multiplies>(lhs, rhs);
224  }
225 
226 
237  template<class T>
238  expression::ScaledMatrix<T>
239  operator*(const BasicMatrix<T>& A, double k)
240  {
241  return expression::ScaledMatrix<T>(k, A);
242  }
243 
244 
255  template<class T>
256  expression::ScaledMatrix<T>
257  operator*(double k, const BasicMatrix<T>& A)
258  {
259  return expression::ScaledMatrix<T>(k, A);
260  }
261 
262 
270  template<typename T>
271  expression::ScaledMatrix<T>
273  {
274  return expression::ScaledMatrix<T>(-1.0, m);
275  }
276 
277 
285  template<typename T>
286  expression::TransposedMatrix<T> transpose(const BasicMatrix<T>& A)
287  {
288  return expression::TransposedMatrix<T>(A);
289  }
290 
291 }}} // of namespace utility, yat, and theplu
292 
293 #endif
expression::ScaledMatrix< T > operator-(const BasicMatrix< T > &m)
negation operator
Definition: BLAS_level3.h:272
The Department of Theoretical Physics namespace as we define it.
size_t rows(void) const
Definition: BasicMatrix.h:61
expression::MatrixBinary< Derived1, Derived2, expression::Minus > operator-(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix subtraction operator.
Definition: BLAS_level3.h:201
size_t columns(void) const
Definition: BasicMatrix.h:67
Definition: BLAS_utility.h:31
Definition: BLAS_utility.h:30
expression::ScaledMatrix< T > operator*(double k, const BasicMatrix< T > &A)
Definition: BLAS_level3.h:257
expression::ScaledMatrix< T > operator*(const BasicMatrix< T > &A, double k)
Definition: BLAS_level3.h:239
Definition: BasicMatrix.h:38
expression::MatrixBinary< Derived1, Derived2, expression::Multiplies > operator*(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix multiplication operator.
Definition: BLAS_level3.h:219
expression::MatrixBinary< Derived1, Derived2, expression::Plus > operator+(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix addition operator.
Definition: BLAS_level3.h:180
expression::TransposedMatrix< T > transpose(const BasicMatrix< T > &A)
transpose function
Definition: BLAS_level3.h:286

Generated on Thu Dec 12 2019 03:12:08 for yat by  doxygen 1.8.11