yat/classifier/Perceptron.cc

Code
Comments
Other
Rev Date Author Line
3709 08 Nov 17 peter 1 // $Id$
3709 08 Nov 17 peter 2
3709 08 Nov 17 peter 3 /*
4207 26 Aug 22 peter 4   Copyright (C) 2017, 2021, 2022 Peter Johansson
3709 08 Nov 17 peter 5
3709 08 Nov 17 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
3709 08 Nov 17 peter 7
3709 08 Nov 17 peter 8   The yat library is free software; you can redistribute it and/or
3709 08 Nov 17 peter 9   modify it under the terms of the GNU General Public License as
3709 08 Nov 17 peter 10   published by the Free Software Foundation; either version 3 of the
3709 08 Nov 17 peter 11   License, or (at your option) any later version.
3709 08 Nov 17 peter 12
3709 08 Nov 17 peter 13   The yat library is distributed in the hope that it will be useful,
3709 08 Nov 17 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
3709 08 Nov 17 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
3709 08 Nov 17 peter 16   General Public License for more details.
3709 08 Nov 17 peter 17
3709 08 Nov 17 peter 18   You should have received a copy of the GNU General Public License
3709 08 Nov 17 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
3709 08 Nov 17 peter 20 */
3709 08 Nov 17 peter 21
3709 08 Nov 17 peter 22 #include <config.h>
3709 08 Nov 17 peter 23
3709 08 Nov 17 peter 24 #include "Perceptron.h"
3709 08 Nov 17 peter 25
3709 08 Nov 17 peter 26 #include "Target.h"
3709 08 Nov 17 peter 27
3709 08 Nov 17 peter 28 #include "yat/utility/DiagonalMatrix.h"
3709 08 Nov 17 peter 29 #include "yat/utility/Matrix.h"
3709 08 Nov 17 peter 30 #include "yat/utility/Vector.h"
3709 08 Nov 17 peter 31
3709 08 Nov 17 peter 32 #include <gsl/gsl_cdf.h>
3709 08 Nov 17 peter 33
3709 08 Nov 17 peter 34 #include <cassert>
3709 08 Nov 17 peter 35 #include <cmath>
3709 08 Nov 17 peter 36 #include <cmath>
3709 08 Nov 17 peter 37
3709 08 Nov 17 peter 38 namespace theplu {
3709 08 Nov 17 peter 39 namespace yat {
3709 08 Nov 17 peter 40 namespace classifier {
3709 08 Nov 17 peter 41
3709 08 Nov 17 peter 42   const utility::Matrix& Perceptron::covariance(void) const
3709 08 Nov 17 peter 43   {
3709 08 Nov 17 peter 44     return covariance_;
3709 08 Nov 17 peter 45   }
3709 08 Nov 17 peter 46
3709 08 Nov 17 peter 47
3709 08 Nov 17 peter 48   double Perceptron::margin(size_t i, double alpha) const
3709 08 Nov 17 peter 49   {
3709 08 Nov 17 peter 50     return gsl_cdf_ugaussian_Qinv(alpha/2) * std::sqrt(covariance_(i, i));
3709 08 Nov 17 peter 51   }
3709 08 Nov 17 peter 52
3709 08 Nov 17 peter 53
3709 08 Nov 17 peter 54   double Perceptron::oddsratio(size_t i) const
3709 08 Nov 17 peter 55   {
3709 08 Nov 17 peter 56     return std::exp(weight_(i));
3709 08 Nov 17 peter 57   }
3709 08 Nov 17 peter 58
3709 08 Nov 17 peter 59
3709 08 Nov 17 peter 60   double Perceptron::oddsratio_lower_CI(size_t i, double alpha) const
3709 08 Nov 17 peter 61   {
3709 08 Nov 17 peter 62     return std::exp(weight_(i) - margin(i, alpha));
3709 08 Nov 17 peter 63   }
3709 08 Nov 17 peter 64
3709 08 Nov 17 peter 65
3709 08 Nov 17 peter 66   double Perceptron::oddsratio_upper_CI(size_t i, double alpha) const
3709 08 Nov 17 peter 67   {
3709 08 Nov 17 peter 68     return std::exp(weight_(i) + margin(i, alpha));
3709 08 Nov 17 peter 69   }
3709 08 Nov 17 peter 70
3709 08 Nov 17 peter 71
3709 08 Nov 17 peter 72   double Perceptron::p_value(size_t i) const
3709 08 Nov 17 peter 73   {
3709 08 Nov 17 peter 74     double z = weight_(i) / std::sqrt(covariance_(i, i));
3709 08 Nov 17 peter 75     return 2*gsl_cdf_ugaussian_Q(std::abs(z));
3709 08 Nov 17 peter 76   }
3709 08 Nov 17 peter 77
3709 08 Nov 17 peter 78
3709 08 Nov 17 peter 79   double Perceptron::predict(const utility::VectorBase& x) const
3709 08 Nov 17 peter 80   {
3709 08 Nov 17 peter 81     assert(x.size() == weight_.size());
3709 08 Nov 17 peter 82     const double f = weight_ * x;
3709 08 Nov 17 peter 83     return 1.0 / (1 + std::exp(-f));
3709 08 Nov 17 peter 84   }
3709 08 Nov 17 peter 85
3709 08 Nov 17 peter 86
4125 14 Jan 22 peter 87   void Perceptron::train(const utility::MatrixBase& X, const Target& target)
3709 08 Nov 17 peter 88   {
3709 08 Nov 17 peter 89     size_t n = X.rows();
3709 08 Nov 17 peter 90     size_t p = X.columns();
3709 08 Nov 17 peter 91
3709 08 Nov 17 peter 92     assert(target.size() == n);
3709 08 Nov 17 peter 93     weight_.resize(p);
3709 08 Nov 17 peter 94     covariance_.resize(p, p);
3709 08 Nov 17 peter 95
3709 08 Nov 17 peter 96     // weight vector is updated as
3709 08 Nov 17 peter 97     // w = (X'SX)^-1 X' (SXw + y - mu)
3709 08 Nov 17 peter 98     // X is n x p
3709 08 Nov 17 peter 99     // mu is vector of (trained) expected values (see predict(1))
3709 08 Nov 17 peter 100     utility::Vector mu(n);
3709 08 Nov 17 peter 101     // S is diagonal n x n with S_ii = mu_i (1 - mu_i)
3709 08 Nov 17 peter 102     utility::DiagonalMatrix S(n, n);
3709 08 Nov 17 peter 103     // y is binary vector
3709 08 Nov 17 peter 104     utility::Vector y(n);
3709 08 Nov 17 peter 105     for (size_t i=0; i<n; ++i)
3709 08 Nov 17 peter 106       if (target.binary(i))
3709 08 Nov 17 peter 107         y(i) = 1.0;
3709 08 Nov 17 peter 108
4052 26 Mar 21 peter 109     // We use the Iteratively Rewighted Least Square algorithm as described
4052 26 Mar 21 peter 110     // https://en.wikipedia.org/wiki/Logistic_regression
4052 26 Mar 21 peter 111
3709 08 Nov 17 peter 112     size_t max_epochs = 100;
3709 08 Nov 17 peter 113     double sum_squared = 1.0; // some (relatively) large number
3709 08 Nov 17 peter 114     for (size_t epoch=0; sum_squared > 1e-8 && epoch < max_epochs; ++epoch) {
3709 08 Nov 17 peter 115       for (size_t i=0; i<mu.size(); ++i) {
3709 08 Nov 17 peter 116         mu(i) = predict(X.row_const_view(i));
3709 08 Nov 17 peter 117         S(i) = mu(i) * (1.0 - mu(i));
3709 08 Nov 17 peter 118       }
3709 08 Nov 17 peter 119
3709 08 Nov 17 peter 120       // w = (X'SX)^-1 X' (SXw + y - mu)
3709 08 Nov 17 peter 121       assert(X.rows() == S.rows());
3709 08 Nov 17 peter 122       assert(S.columns() == X.rows());
4140 29 Jan 22 peter 123       utility::inverse_svd(utility::Matrix(transpose(X)*S*X), covariance_);
3709 08 Nov 17 peter 124
3709 08 Nov 17 peter 125       assert(y.size() == mu.size());
3709 08 Nov 17 peter 126       utility::Vector delta = covariance_ * (transpose(X) * (y - mu));
3709 08 Nov 17 peter 127       weight_ += delta;
3709 08 Nov 17 peter 128       sum_squared = delta * delta;
3709 08 Nov 17 peter 129     }
3709 08 Nov 17 peter 130   }
3709 08 Nov 17 peter 131
3709 08 Nov 17 peter 132
3709 08 Nov 17 peter 133   const utility::Vector& Perceptron::weight(void) const
3709 08 Nov 17 peter 134   {
3709 08 Nov 17 peter 135     return weight_;
3709 08 Nov 17 peter 136   }
3709 08 Nov 17 peter 137
3709 08 Nov 17 peter 138
3709 08 Nov 17 peter 139 }}}