1 #ifndef _theplu_yat_utility_blas_level3 2 #define _theplu_yat_utility_blas_level3 25 #include "BasicMatrix.h" 26 #include "BLAS_utility.h" 27 #include "MatrixExpression.h" 29 #include <gsl/gsl_matrix.h> 39 namespace expression {
// NOTE(review): this listing appears to be a rendered copy of the header in
// which many original lines are missing. The numbers fused at line starts look
// like original line numbers. Missing pieces include the `class MatrixBinary`
// head, the braces and the declaration of the `op_` member. The comments below
// describe only what is visible.
//
// MatrixBinary<LHS, RHS, OP>: expression node that combines two BasicMatrix
// operands with the operation tag OP. The visible overloads handle:
//   - Plus and Minus: element-wise;
//   - Multiplies: matrix product via BLAS.
40 template<
typename LHS,
typename RHS,
class OP>
42 :
public MatrixExpression<MatrixBinary<LHS, RHS, OP> >
// Stores references to both operands; no copy is made. lhs and rhs must
// therefore outlive this expression object.
45 MatrixBinary(
const BasicMatrix<LHS>& lhs,
const BasicMatrix<RHS>& rhs)
46 : lhs_(lhs), rhs_(rhs)
// Result shape: lhs_.rows() by rhs_.columns().
// NOTE(review): for Plus/Minus this assumes both operands have the same
// shape, but no such check is visible here. Confirm it is enforced elsewhere.
50 size_t rows(
void)
const {
return lhs_.rows(); }
51 size_t columns(
void)
const {
return rhs_.columns(); }
// Element (row, column) of the result. Forwards to the get() overload
// selected by the operation tag op_ (tag dispatch).
53 double operator()(
size_t row,
size_t column)
const 54 {
return get(row, column, op_); }
// Evaluates the whole expression into *result in two steps:
//   1. (re)allocates result to rows() x columns() via detail::reallocate
//      (presumably a no-op when the size already matches; TODO confirm);
//   2. dispatches on op_ to the matching calculate_matrix overload.
56 void calculate_matrix(gsl_matrix*& result)
const 58 detail::reallocate(result, this->rows(), this->columns());
59 calculate_matrix(result, op_);
// Operand references held by the expression (non-owning).
63 const BasicMatrix<LHS>& lhs_;
64 const BasicMatrix<RHS>& rhs_;
// Matrix product: result = 1.0 * lhs * rhs + 0.0 * result, via gsl_blas_dgemm
// with no transposition of either operand. The YAT_ASSERT lines state two
// preconditions: result already has the product's shape, and the inner
// dimensions agree.
// NOTE(review): BLAS does not support the output overlapping an input. If
// result can refer to lhs_ or rhs_ storage (e.g. `A = A * B`), confirm the
// assignment path evaluates into a temporary first.
67 void calculate_matrix(gsl_matrix*& result, Multiplies)
const 69 YAT_ASSERT(detail::rows(result) == this->rows());
70 YAT_ASSERT(detail::columns(result) == this->columns());
71 YAT_ASSERT(lhs_.columns() == rhs_.rows());
72 gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0,
73 lhs_.gsl_matrix_p(), rhs_.gsl_matrix_p(),0.0, result);
// Generic fallback for tags without a dedicated overload (in the visible code,
// Plus and Minus). It fills result element by element through operator().
// The template header that declares T is not shown in this listing.
78 void calculate_matrix(gsl_matrix*& result, T)
const 80 YAT_ASSERT(detail::rows(result) == this->rows());
81 YAT_ASSERT(detail::columns(result) == this->columns());
82 for (
size_t i=0; i<rows(); ++i)
83 for (
size_t j=0; j<columns(); ++j)
84 gsl_matrix_set(result, i, j, (*
this)(i, j));
// Element-wise sum lhs + rhs at (row, column).
88 double get(
size_t row,
size_t column, Plus)
const 90 return lhs_(row, column) + rhs_(row, column);
// Element-wise difference lhs - rhs at (row, column).
94 double get(
size_t row,
size_t column, Minus)
const 96 return lhs_(row, column) - rhs_(row, column);
// Product element: reads entry (row, column) from this->gsl_matrix_p(). That
// function is not declared in the visible code; presumably the MatrixExpression
// base provides it.
// NOTE(review): if gsl_matrix_p() evaluates the full product on every call,
// element-wise access is very expensive. Confirm that it caches the result.
100 double get(
size_t row,
size_t column, Multiplies)
const 102 return gsl_matrix_get(this->gsl_matrix_p(), row, column);
// ScaledMatrix<T>: expression node representing factor * A. This listing does
// not show the `template<typename T>` header or the declaration of the
// factor_ member.
108 class ScaledMatrix :
public MatrixExpression<ScaledMatrix<T> >
// Keeps a non-owning reference to A in A_ and stores factor in factor_.
111 ScaledMatrix(
double factor,
const BasicMatrix<T>& A)
112 : A_(A), factor_(factor) {}
// The result has the same shape as A.
114 size_t rows(
void)
const {
return A_.rows(); }
115 size_t columns(
void)
const {
return A_.columns(); }
// Element (i, j) of factor * A, computed on demand.
118 double operator()(
size_t i,
size_t j)
const 120 return factor_ * A_(i, j);
// Evaluates into *result in two steps:
//   1. copies A via detail::copy (presumably allocating or resizing result as
//      needed; confirm, since no reallocate call is visible here);
//   2. scales result in place with gsl_matrix_scale.
124 void calculate_matrix(gsl_matrix*& result)
const 126 detail::copy(result, A_.gsl_matrix_p());
127 gsl_matrix_scale(result, factor_);
// Referenced operand. It must outlive this expression.
131 const BasicMatrix<T>& A_;
// TransposedMatrix<T>: expression node representing the transpose of A. This
// listing does not show the template header, the constructor's initializer or
// the body of operator().
137 class TransposedMatrix :
public MatrixExpression<TransposedMatrix<T> >
// NOTE(review): this single-argument constructor is not marked `explicit` in
// this listing. Consider adding it to prevent implicit conversions from
// BasicMatrix.
140 TransposedMatrix(
const BasicMatrix<T>& A)
// The shape is A's shape swapped: rows() == A.columns(), columns() == A.rows().
143 size_t rows(
void)
const {
return A_.columns(); }
144 size_t columns(
void)
const {
return A_.rows(); }
// Element (i, j). The body is not visible here; presumably it returns A_(j, i).
146 double operator()(
size_t i,
size_t j)
// The leading `const` on the next line belongs to operator() above.
// calculate_matrix evaluates into *result in two steps:
//   1. reallocates result to rows() x columns();
//   2. copies the transpose of A into it with gsl_matrix_transpose_memcpy.
// NOTE(review): confirm result can never share storage with A
// (e.g. `A = transpose(A)`). transpose_memcpy expects a distinct destination.
const 151 void calculate_matrix(gsl_matrix*& result)
const 153 detail::reallocate(result, rows(), columns());
154 gsl_matrix_transpose_memcpy(result, A_.gsl_matrix_p());
// Referenced operand. It must outlive this expression.
158 const BasicMatrix<T>& A_;
// Matrix addition. Per the documentation index, the signature is
//   operator+(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs)
// (the signature lines are not shown in this listing).
// Returns an unevaluated MatrixBinary<..., Plus> expression whose element
// (i, j) is lhs(i, j) + rhs(i, j). MatrixBinary holds references to lhs and
// rhs, so the returned expression must not outlive either operand.
178 template<
class Derived1,
class Derived2>
179 expression::MatrixBinary<Derived1, Derived2, expression::Plus>
182 return expression::MatrixBinary<Derived1, Derived2,
// Matrix subtraction. Per the documentation index, the signature is
//   operator-(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs)
// (the signature lines are not shown in this listing).
// Returns an unevaluated MatrixBinary<..., Minus> expression whose element
// (i, j) is lhs(i, j) - rhs(i, j). MatrixBinary holds references to lhs and
// rhs, so the returned expression must not outlive either operand.
199 template<
class Derived1,
class Derived2>
200 expression::MatrixBinary<Derived1, Derived2, expression::Minus>
203 return expression::MatrixBinary<Derived1, Derived2,
// Matrix multiplication. Per the documentation index, the signature is
//   operator*(const BasicMatrix<Derived1>& lhs, const BasicMatrix<Derived2>& rhs)
// (the signature lines are not shown in this listing).
// Returns an unevaluated MatrixBinary<..., Multiplies> expression. When it is
// evaluated via calculate_matrix, the product is computed with gsl_blas_dgemm,
// and the inner dimensions (lhs.columns() == rhs.rows()) are asserted there.
// MatrixBinary holds references to lhs and rhs, so the expression must not
// outlive either operand.
217 template<
class Derived1,
class Derived2>
218 expression::MatrixBinary<Derived1, Derived2, expression::Multiplies>
222 return expression::MatrixBinary<Derived1, Derived2,
// Scalar multiplication with the matrix on the left:
//   operator*(const BasicMatrix<T>& A, double k)
// (the template and signature lines are not shown in this listing).
// Returns a ScaledMatrix expression representing k * A. It references A.
238 expression::ScaledMatrix<T>
241 return expression::ScaledMatrix<T>(k, A);
// Scalar multiplication with the scalar on the left:
//   operator*(double k, const BasicMatrix<T>& A)
// (the template and signature lines are not shown in this listing).
// Returns a ScaledMatrix expression representing k * A. It references A.
256 expression::ScaledMatrix<T>
259 return expression::ScaledMatrix<T>(k, A);
// Unary negation: operator-(const BasicMatrix<T>& m)
// (the template and signature lines are not shown in this listing).
// Implemented as a ScaledMatrix with factor -1.0. It references m.
271 expression::ScaledMatrix<T>
274 return expression::ScaledMatrix<T>(-1.0, m);
// transpose(const BasicMatrix<T>& A): returns a TransposedMatrix expression
// that references A (the template and signature lines are not shown in this
// listing). A must outlive the returned expression.
288 return expression::TransposedMatrix<T>(A);
expression::ScaledMatrix< T > operator-(const BasicMatrix< T > &m)
Negation operator.
Definition: BLAS_level3.h:272
The Department of Theoretical Physics namespace as we define it.
size_t rows(void) const
Definition: BasicMatrix.h:61
Definition: BLAS_utility.h:32
expression::MatrixBinary< Derived1, Derived2, expression::Minus > operator-(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix subtraction operator.
Definition: BLAS_level3.h:201
size_t columns(void) const
Definition: BasicMatrix.h:67
Definition: BLAS_utility.h:31
Definition: BLAS_utility.h:30
expression::ScaledMatrix< T > operator*(double k, const BasicMatrix< T > &A)
Definition: BLAS_level3.h:257
expression::ScaledMatrix< T > operator*(const BasicMatrix< T > &A, double k)
Definition: BLAS_level3.h:239
Definition: BasicMatrix.h:38
expression::MatrixBinary< Derived1, Derived2, expression::Multiplies > operator*(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix multiplication operator.
Definition: BLAS_level3.h:219
expression::MatrixBinary< Derived1, Derived2, expression::Plus > operator+(const BasicMatrix< Derived1 > &lhs, const BasicMatrix< Derived2 > &rhs)
Matrix addition operator.
Definition: BLAS_level3.h:180
expression::TransposedMatrix< T > transpose(const BasicMatrix< T > &A)
Transpose function.
Definition: BLAS_level3.h:286