yat/classifier/NBC.cc

Code
Comments
Other
Rev Date Author Line
662 27 Sep 06 peter 1 // $Id$
662 27 Sep 06 peter 2
675 10 Oct 06 jari 3 /*
4359 23 Aug 23 peter 4   Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 5   Copyright (C) 2007 Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 6   Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 7   Copyright (C) 2012 Peter Johansson
662 27 Sep 06 peter 8
1437 25 Aug 08 peter 9   This file is part of the yat library, http://dev.thep.lu.se/yat
662 27 Sep 06 peter 10
675 10 Oct 06 jari 11   The yat library is free software; you can redistribute it and/or
675 10 Oct 06 jari 12   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 13   published by the Free Software Foundation; either version 3 of the
675 10 Oct 06 jari 14   License, or (at your option) any later version.
675 10 Oct 06 jari 15
675 10 Oct 06 jari 16   The yat library is distributed in the hope that it will be useful,
675 10 Oct 06 jari 17   but WITHOUT ANY WARRANTY; without even the implied warranty of
675 10 Oct 06 jari 18   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
675 10 Oct 06 jari 19   General Public License for more details.
675 10 Oct 06 jari 20
675 10 Oct 06 jari 21   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 22   along with yat. If not, see <http://www.gnu.org/licenses/>.
675 10 Oct 06 jari 23 */
675 10 Oct 06 jari 24
2881 18 Nov 12 peter 25 #include <config.h>
2881 18 Nov 12 peter 26
680 11 Oct 06 jari 27 #include "NBC.h"
680 11 Oct 06 jari 28 #include "MatrixLookup.h"
680 11 Oct 06 jari 29 #include "MatrixLookupWeighted.h"
680 11 Oct 06 jari 30 #include "Target.h"
1160 26 Feb 08 markus 31 #include "yat/statistics/Averager.h"
675 10 Oct 06 jari 32 #include "yat/statistics/AveragerWeighted.h"
1121 22 Feb 08 peter 33 #include "yat/utility/Matrix.h"
2909 16 Dec 12 peter 34 #include "yat/utility/WeightedIterator.h"
675 10 Oct 06 jari 35
767 22 Feb 07 peter 36 #include <cassert>
812 16 Mar 07 peter 37 #include <cmath>
1437 25 Aug 08 peter 38 #include <limits>
963 10 Oct 07 peter 39 #include <stdexcept>
662 27 Sep 06 peter 40 #include <vector>
662 27 Sep 06 peter 41
662 27 Sep 06 peter 42 namespace theplu {
680 11 Oct 06 jari 43 namespace yat {
662 27 Sep 06 peter 44 namespace classifier {
662 27 Sep 06 peter 45
4200 19 Aug 22 peter 46   NBC::NBC()
1157 26 Feb 08 markus 47     : SupervisedClassifier()
662 27 Sep 06 peter 48   {
662 27 Sep 06 peter 49   }
662 27 Sep 06 peter 50
662 27 Sep 06 peter 51
4200 19 Aug 22 peter 52   NBC::~NBC()
662 27 Sep 06 peter 53   {
662 27 Sep 06 peter 54   }
662 27 Sep 06 peter 55
662 27 Sep 06 peter 56
4200 19 Aug 22 peter 57   NBC* NBC::make_classifier() const
4200 19 Aug 22 peter 58   {
1157 26 Feb 08 markus 59     return new NBC();
662 27 Sep 06 peter 60   }
662 27 Sep 06 peter 61
662 27 Sep 06 peter 62
1157 26 Feb 08 markus 63   void NBC::train(const MatrixLookup& data, const Target& target)
4200 19 Aug 22 peter 64   {
1157 26 Feb 08 markus 65     sigma2_.resize(data.rows(), target.nof_classes());
1157 26 Feb 08 markus 66     centroids_.resize(data.rows(), target.nof_classes());
4200 19 Aug 22 peter 67
1157 26 Feb 08 markus 68     for(size_t i=0; i<data.rows(); ++i) {
1157 26 Feb 08 markus 69       std::vector<statistics::Averager> aver(target.nof_classes());
4200 19 Aug 22 peter 70       for(size_t j=0; j<data.columns(); ++j)
1157 26 Feb 08 markus 71         aver[target(j)].add(data(i,j));
4200 19 Aug 22 peter 72
1157 26 Feb 08 markus 73       assert(centroids_.columns()==target.nof_classes());
1157 26 Feb 08 markus 74       for (size_t j=0; j<target.nof_classes(); ++j){
1157 26 Feb 08 markus 75         assert(i<centroids_.rows());
1157 26 Feb 08 markus 76         assert(j<centroids_.columns());
1157 26 Feb 08 markus 77         assert(i<sigma2_.rows());
1157 26 Feb 08 markus 78         assert(j<sigma2_.columns());
1157 26 Feb 08 markus 79         if (aver[j].n()>1){
1157 26 Feb 08 markus 80           sigma2_(i,j) = aver[j].variance();
1157 26 Feb 08 markus 81           centroids_(i,j) = aver[j].mean();
1157 26 Feb 08 markus 82         }
1184 28 Feb 08 peter 83         else {
960 10 Oct 07 peter 84             sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
1120 21 Feb 08 peter 85             centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
1184 28 Feb 08 peter 86         }
662 27 Sep 06 peter 87       }
960 10 Oct 07 peter 88     }
4200 19 Aug 22 peter 89   }
960 10 Oct 07 peter 90
1157 26 Feb 08 markus 91
1157 26 Feb 08 markus 92   void NBC::train(const MatrixLookupWeighted& data, const Target& target)
4200 19 Aug 22 peter 93   {
1157 26 Feb 08 markus 94     sigma2_.resize(data.rows(), target.nof_classes());
1157 26 Feb 08 markus 95     centroids_.resize(data.rows(), target.nof_classes());
1157 26 Feb 08 markus 96
1157 26 Feb 08 markus 97     for(size_t i=0; i<data.rows(); ++i) {
1157 26 Feb 08 markus 98       std::vector<statistics::AveragerWeighted> aver(target.nof_classes());
4200 19 Aug 22 peter 99       for(size_t j=0; j<data.columns(); ++j)
1157 26 Feb 08 markus 100         aver[target(j)].add(data.data(i,j), data.weight(i,j));
4200 19 Aug 22 peter 101
1157 26 Feb 08 markus 102       assert(centroids_.columns()==target.nof_classes());
1157 26 Feb 08 markus 103       for (size_t j=0; j<target.nof_classes(); ++j) {
1157 26 Feb 08 markus 104         assert(i<centroids_.rows());
1157 26 Feb 08 markus 105         assert(j<centroids_.columns());
1157 26 Feb 08 markus 106         assert(i<sigma2_.rows());
1157 26 Feb 08 markus 107         assert(j<sigma2_.columns());
1157 26 Feb 08 markus 108         if (aver[j].n()>1){
1157 26 Feb 08 markus 109           sigma2_(i,j) = aver[j].variance();
960 10 Oct 07 peter 110           centroids_(i,j) = aver[j].mean();
960 10 Oct 07 peter 111         }
1157 26 Feb 08 markus 112         else {
1157 26 Feb 08 markus 113           sigma2_(i,j) = std::numeric_limits<double>::quiet_NaN();
1157 26 Feb 08 markus 114           centroids_(i,j) = std::numeric_limits<double>::quiet_NaN();
1157 26 Feb 08 markus 115         }
662 27 Sep 06 peter 116       }
1157 26 Feb 08 markus 117     }
662 27 Sep 06 peter 118   }
662 27 Sep 06 peter 119
662 27 Sep 06 peter 120
4200 19 Aug 22 peter 121   void NBC::predict(const MatrixLookup& ml,
1121 22 Feb 08 peter 122                     utility::Matrix& prediction) const
4200 19 Aug 22 peter 123   {
1160 26 Feb 08 markus 124     assert(ml.rows()==sigma2_.rows());
1160 26 Feb 08 markus 125     assert(ml.rows()==centroids_.rows());
812 16 Mar 07 peter 126     // each row in prediction corresponds to a sample label (class)
1160 26 Feb 08 markus 127     prediction.resize(centroids_.columns(), ml.columns(), 0);
1160 26 Feb 08 markus 128
1160 26 Feb 08 markus 129     // first calculate -lnP = sum sigma_i + (x_i-m_i)^2/2sigma_i^2
1160 26 Feb 08 markus 130     for (size_t label=0; label<centroids_.columns(); ++label) {
1160 26 Feb 08 markus 131       double sum_log_sigma = sum_logsigma(label);
1160 26 Feb 08 markus 132       for (size_t sample=0; sample<prediction.rows(); ++sample) {
1160 26 Feb 08 markus 133         prediction(label,sample) = sum_log_sigma;
4200 19 Aug 22 peter 134         for (size_t i=0; i<ml.rows(); ++i)
4200 19 Aug 22 peter 135           prediction(label, sample) +=
1184 28 Feb 08 peter 136             std::pow(ml(i, label)-centroids_(i, label),2)/
1184 28 Feb 08 peter 137             sigma2_(i, label);
959 10 Oct 07 peter 138       }
959 10 Oct 07 peter 139     }
1160 26 Feb 08 markus 140     standardize_lnP(prediction);
1160 26 Feb 08 markus 141   }
1160 26 Feb 08 markus 142
4200 19 Aug 22 peter 143
4200 19 Aug 22 peter 144   void NBC::predict(const MatrixLookupWeighted& mlw,
1160 26 Feb 08 markus 145                     utility::Matrix& prediction) const
4200 19 Aug 22 peter 146   {
1160 26 Feb 08 markus 147     assert(mlw.rows()==sigma2_.rows());
1160 26 Feb 08 markus 148     assert(mlw.rows()==centroids_.rows());
4200 19 Aug 22 peter 149
1160 26 Feb 08 markus 150     // each row in prediction corresponds to a sample label (class)
1160 26 Feb 08 markus 151     prediction.resize(centroids_.columns(), mlw.columns(), 0);
1160 26 Feb 08 markus 152
4200 19 Aug 22 peter 153     // first calculate -lnP = sum (sigma_i) +
1182 28 Feb 08 peter 154     // N sum w_i(x_i-m_i)^2/2sigma_i^2 / sum w_i
1160 26 Feb 08 markus 155     for (size_t label=0; label<centroids_.columns(); ++label) {
1160 26 Feb 08 markus 156       double sum_log_sigma = sum_logsigma(label);
1160 26 Feb 08 markus 157       for (size_t sample=0; sample<prediction.rows(); ++sample) {
1182 28 Feb 08 peter 158         statistics::AveragerWeighted aw;
4200 19 Aug 22 peter 159         for (size_t i=0; i<mlw.rows(); ++i)
1184 28 Feb 08 peter 160           aw.add(std::pow(mlw.data(i, label)-centroids_(i, label),2)/
1184 28 Feb 08 peter 161                  sigma2_(i, label), mlw.weight(i, label));
1182 28 Feb 08 peter 162         prediction(label,sample) = sum_log_sigma + mlw.rows()*aw.mean()/2;
767 22 Feb 07 peter 163       }
767 22 Feb 07 peter 164     }
1160 26 Feb 08 markus 165     standardize_lnP(prediction);
1160 26 Feb 08 markus 166   }
812 16 Mar 07 peter 167
1160 26 Feb 08 markus 168   void NBC::standardize_lnP(utility::Matrix& prediction) const
1160 26 Feb 08 markus 169   {
1184 28 Feb 08 peter 170     /// -lnP might be a large number, in order to avoid out of bound
1184 28 Feb 08 peter 171     /// problems when calculating P = exp(- -lnP), we centralize matrix
1184 28 Feb 08 peter 172     /// by adding a constant.
2909 16 Dec 12 peter 173     utility::Matrix weights;
2909 16 Dec 12 peter 174     // create zero/unity weight matrix (w=0 if NaN)
2909 16 Dec 12 peter 175     nan(prediction, weights);
2909 16 Dec 12 peter 176     using utility::weighted_iterator;
1184 28 Feb 08 peter 177     statistics::AveragerWeighted a;
2909 16 Dec 12 peter 178     add(a, weighted_iterator(prediction.begin(), weights.begin()),
2909 16 Dec 12 peter 179         weighted_iterator(prediction.end(), weights.end()));
1120 21 Feb 08 peter 180     prediction -= a.mean();
2909 16 Dec 12 peter 181
812 16 Mar 07 peter 182     // exponentiate
812 16 Mar 07 peter 183     for (size_t i=0; i<prediction.rows(); ++i)
812 16 Mar 07 peter 184       for (size_t j=0; j<prediction.columns(); ++j)
812 16 Mar 07 peter 185         prediction(i,j) = std::exp(prediction(i,j));
2909 16 Dec 12 peter 186
812 16 Mar 07 peter 187     // normalize each row (label) to sum up to unity (probability)
1009 01 Feb 08 peter 188     for (size_t i=0; i<prediction.rows(); ++i){
1184 28 Feb 08 peter 189       // calculate sum of row ignoring NaNs
1184 28 Feb 08 peter 190       statistics::AveragerWeighted a;
2909 16 Dec 12 peter 191       add(a, weighted_iterator(prediction.begin_row(i), weights.begin_row(i)),
2909 16 Dec 12 peter 192           weighted_iterator(prediction.end_row(i), weights.end_row(i)));
1184 28 Feb 08 peter 193       prediction.row_view(i) *= 1.0/a.sum_wx();
1009 01 Feb 08 peter 194     }
662 27 Sep 06 peter 195   }
662 27 Sep 06 peter 196
662 27 Sep 06 peter 197
959 10 Oct 07 peter 198   double NBC::sum_logsigma(size_t label) const
959 10 Oct 07 peter 199   {
959 10 Oct 07 peter 200     double sum_log_sigma=0;
959 10 Oct 07 peter 201     assert(label<sigma2_.columns());
959 10 Oct 07 peter 202     for (size_t i=0; i<sigma2_.rows(); ++i) {
1184 28 Feb 08 peter 203       sum_log_sigma += std::log(sigma2_(i, label));
959 10 Oct 07 peter 204     }
959 10 Oct 07 peter 205     return sum_log_sigma / 2; // taking sum of log(sigma) not sigma2
959 10 Oct 07 peter 206   }
959 10 Oct 07 peter 207
680 11 Oct 06 jari 208 }}} // of namespace classifier, yat, and theplu