yat/classifier/EnsembleBuilder.h

Code
Comments
Other
Rev Date Author Line
4200 19 Aug 22 peter 1 #ifndef _theplu_yat_classifier_ensemblebuilder_
4200 19 Aug 22 peter 2 #define _theplu_yat_classifier_ensemblebuilder_
481 22 Dec 05 markus 3
675 10 Oct 06 jari 4 // $Id$
481 22 Dec 05 markus 5
675 10 Oct 06 jari 6 /*
2119 12 Dec 09 peter 7   Copyright (C) 2005 Markus Ringnér
2119 12 Dec 09 peter 8   Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 9   Copyright (C) 2007 Peter Johansson
2119 12 Dec 09 peter 10   Copyright (C) 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
2121 13 Dec 09 peter 11   Copyright (C) 2009 Jari Häkkinen, Peter Johansson
4359 23 Aug 23 peter 12   Copyright (C) 2010 Peter Johansson
675 10 Oct 06 jari 13
1437 25 Aug 08 peter 14   This file is part of the yat library, http://dev.thep.lu.se/yat
675 10 Oct 06 jari 15
675 10 Oct 06 jari 16   The yat library is free software; you can redistribute it and/or
675 10 Oct 06 jari 17   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 18   published by the Free Software Foundation; either version 3 of the
675 10 Oct 06 jari 19   License, or (at your option) any later version.
675 10 Oct 06 jari 20
675 10 Oct 06 jari 21   The yat library is distributed in the hope that it will be useful,
675 10 Oct 06 jari 22   but WITHOUT ANY WARRANTY; without even the implied warranty of
675 10 Oct 06 jari 23   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
675 10 Oct 06 jari 24   General Public License for more details.
675 10 Oct 06 jari 25
675 10 Oct 06 jari 26   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 27   along with yat. If not, see <http://www.gnu.org/licenses/>.
675 10 Oct 06 jari 28 */
675 10 Oct 06 jari 29
1088 14 Feb 08 peter 30 #include "FeatureSelector.h"
1079 13 Feb 08 peter 31 #include "Sampler.h"
1072 12 Feb 08 peter 32 #include "SubsetGenerator.h"
677 10 Oct 06 jari 33 #include "yat/statistics/Averager.h"
1121 22 Feb 08 peter 34 #include "yat/utility/Matrix.h"
2077 07 Oct 09 peter 35 #include "yat/utility/yat_assert.h"
675 10 Oct 06 jari 36
485 04 Jan 06 markus 37 #include <vector>
485 04 Jan 06 markus 38
481 22 Dec 05 markus 39 namespace theplu {
680 11 Oct 06 jari 40 namespace yat {
4200 19 Aug 22 peter 41 namespace classifier {
481 22 Dec 05 markus 42
481 22 Dec 05 markus 43   ///
767 22 Feb 07 peter 44   /// @brief Class for ensembles of supervised classifiers
481 22 Dec 05 markus 45   ///
1079 13 Feb 08 peter 46   template <class Classifier, class Data>
481 22 Dec 05 markus 47   class EnsembleBuilder
481 22 Dec 05 markus 48   {
481 22 Dec 05 markus 49   public:
1125 22 Feb 08 peter 50     /**
1125 22 Feb 08 peter 51        \brief Type of classifier that ensemble is built on.
1125 22 Feb 08 peter 52      */
1079 13 Feb 08 peter 53     typedef Classifier classifier_type;
1125 22 Feb 08 peter 54
1125 22 Feb 08 peter 55     /**
1221 11 Mar 08 peter 56        Type of container used for storing data. Must be MatrixLookup,
1221 11 Mar 08 peter 57        MatrixLookupWeighted, or KernelLookup
1125 22 Feb 08 peter 58      */
1079 13 Feb 08 peter 59     typedef Data data_type;
1079 13 Feb 08 peter 60
481 22 Dec 05 markus 61     ///
481 22 Dec 05 markus 62     /// Constructor.
481 22 Dec 05 markus 63     ///
1087 14 Feb 08 peter 64     EnsembleBuilder(const Classifier&, const Data&, const Sampler&);
481 22 Dec 05 markus 65
736 06 Jan 07 peter 66     ///
736 06 Jan 07 peter 67     /// Constructor.
736 06 Jan 07 peter 68     ///
4200 19 Aug 22 peter 69     EnsembleBuilder(const Classifier&, const Data&, const Sampler&,
736 06 Jan 07 peter 70                     FeatureSelector&);
736 06 Jan 07 peter 71
481 22 Dec 05 markus 72     ///
485 04 Jan 06 markus 73     /// Destructor.
485 04 Jan 06 markus 74     ///
505 02 Feb 06 markus 75     virtual ~EnsembleBuilder(void);
485 04 Jan 06 markus 76
1227 13 Mar 08 peter 77     /**
4200 19 Aug 22 peter 78        \brief Generate ensemble.
4200 19 Aug 22 peter 79
1227 13 Mar 08 peter 80        Function trains each member of the Ensemble.
1227 13 Mar 08 peter 81     */
481 22 Dec 05 markus 82     void build(void);
481 22 Dec 05 markus 83
505 02 Feb 06 markus 84     ///
1227 13 Mar 08 peter 85     /// @return ith classifier
505 02 Feb 06 markus 86     ///
1079 13 Feb 08 peter 87     const Classifier& classifier(size_t i) const;
4200 19 Aug 22 peter 88
505 02 Feb 06 markus 89     ///
1227 13 Mar 08 peter 90     /// @return Number of classifiers in ensemble. Prior build(void)
1227 13 Mar 08 peter 91     /// is issued size is zero.
505 02 Feb 06 markus 92     ///
1273 10 Apr 08 jari 93     unsigned long size(void) const;
505 02 Feb 06 markus 94
505 02 Feb 06 markus 95     ///
553 07 Mar 06 peter 96     /// @brief Generate validation data for ensemble
505 02 Feb 06 markus 97     ///
553 07 Mar 06 peter 98     /// validate()[i][j] return averager for class @a i for sample @a j
553 07 Mar 06 peter 99     ///
505 02 Feb 06 markus 100     const std::vector<std::vector<statistics::Averager> >& validate(void);
4200 19 Aug 22 peter 101
655 22 Sep 06 peter 102     /**
655 22 Sep 06 peter 103        Predict a dataset using the ensemble.
4200 19 Aug 22 peter 104
655 22 Sep 06 peter 105        If @a data is a KernelLookup each column should correspond to a
655 22 Sep 06 peter 106        test sample and each row should correspond to a training
655 22 Sep 06 peter 107        sample. More exactly row \f$ i \f$ in @a data should correspond
655 22 Sep 06 peter 108        to the same sample as row/column \f$ i \f$ in the training
655 22 Sep 06 peter 109        kernel corresponds to.
655 22 Sep 06 peter 110     */
4200 19 Aug 22 peter 111     void predict(const Data& data,
720 26 Dec 06 jari 112                  std::vector<std::vector<statistics::Averager> > &);
505 02 Feb 06 markus 113
481 22 Dec 05 markus 114   private:
1079 13 Feb 08 peter 115     // no copying
722 27 Dec 06 markus 116     EnsembleBuilder(const EnsembleBuilder&);
722 27 Dec 06 markus 117     const EnsembleBuilder& operator=(const EnsembleBuilder&);
722 27 Dec 06 markus 118
4200 19 Aug 22 peter 119
1079 13 Feb 08 peter 120     const Classifier& mother_;
1079 13 Feb 08 peter 121     SubsetGenerator<Data>* subset_;
1079 13 Feb 08 peter 122     std::vector<Classifier*> classifier_;
1206 05 Mar 08 peter 123     KernelLookup test_data(const KernelLookup&, size_t k);
1206 05 Mar 08 peter 124     MatrixLookup test_data(const MatrixLookup&, size_t k);
1206 05 Mar 08 peter 125     MatrixLookupWeighted test_data(const MatrixLookupWeighted&, size_t k);
505 02 Feb 06 markus 126     std::vector<std::vector<statistics::Averager> > validation_result_;
481 22 Dec 05 markus 127
481 22 Dec 05 markus 128   };
1079 13 Feb 08 peter 129
4200 19 Aug 22 peter 130
1079 13 Feb 08 peter 131   // implementation
1079 13 Feb 08 peter 132
1954 07 May 09 jari 133   template <class Classifier, class Data>
1954 07 May 09 jari 134   EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
1954 07 May 09 jari 135                                                      const Data& data,
1954 07 May 09 jari 136                                                      const Sampler& sampler)
1954 07 May 09 jari 137     : mother_(sc),subset_(new SubsetGenerator<Data>(sampler,data))
1079 13 Feb 08 peter 138   {
1079 13 Feb 08 peter 139   }
1079 13 Feb 08 peter 140
1079 13 Feb 08 peter 141
1954 07 May 09 jari 142   template <class Classifier, class Data>
1954 07 May 09 jari 143   EnsembleBuilder<Classifier, Data>::EnsembleBuilder(const Classifier& sc,
1954 07 May 09 jari 144                                                      const Data& data,
1954 07 May 09 jari 145                                                      const Sampler& sampler,
1954 07 May 09 jari 146                                                      FeatureSelector& fs)
1079 13 Feb 08 peter 147     : mother_(sc),
1954 07 May 09 jari 148       subset_(new SubsetGenerator<Data>(sampler,data,fs))
1079 13 Feb 08 peter 149   {
1079 13 Feb 08 peter 150   }
1079 13 Feb 08 peter 151
1079 13 Feb 08 peter 152
1954 07 May 09 jari 153   template <class Classifier, class Data>
1954 07 May 09 jari 154   EnsembleBuilder<Classifier, Data>::~EnsembleBuilder(void)
1079 13 Feb 08 peter 155   {
1079 13 Feb 08 peter 156     for(size_t i=0; i<classifier_.size(); i++)
1079 13 Feb 08 peter 157       delete classifier_[i];
1079 13 Feb 08 peter 158     delete subset_;
1079 13 Feb 08 peter 159   }
1079 13 Feb 08 peter 160
1079 13 Feb 08 peter 161
1954 07 May 09 jari 162   template <class Classifier, class Data>
1954 07 May 09 jari 163   void EnsembleBuilder<Classifier, Data>::build(void)
1079 13 Feb 08 peter 164   {
1227 13 Mar 08 peter 165     if (classifier_.empty()){
1273 10 Apr 08 jari 166       for(unsigned long i=0; i<subset_->size();++i) {
1954 07 May 09 jari 167         Classifier* classifier = mother_.make_classifier();
4200 19 Aug 22 peter 168         classifier->train(subset_->training_data(i),
1227 13 Mar 08 peter 169                           subset_->training_target(i));
1227 13 Mar 08 peter 170         classifier_.push_back(classifier);
4200 19 Aug 22 peter 171       }
1227 13 Mar 08 peter 172     }
1079 13 Feb 08 peter 173   }
1079 13 Feb 08 peter 174
1079 13 Feb 08 peter 175
1954 07 May 09 jari 176   template <class Classifier, class Data>
1954 07 May 09 jari 177   const Classifier& EnsembleBuilder<Classifier, Data>::classifier(size_t i) const
1079 13 Feb 08 peter 178   {
1079 13 Feb 08 peter 179     return *(classifier_[i]);
1079 13 Feb 08 peter 180   }
1079 13 Feb 08 peter 181
1079 13 Feb 08 peter 182
1954 07 May 09 jari 183   template <class Classifier, class Data>
1954 07 May 09 jari 184   void EnsembleBuilder<Classifier, Data>::predict
1954 07 May 09 jari 185   (const Data& data, std::vector<std::vector<statistics::Averager> >& result)
1079 13 Feb 08 peter 186   {
1227 13 Mar 08 peter 187     result = std::vector<std::vector<statistics::Averager> >
4200 19 Aug 22 peter 188       (subset_->target().nof_classes(),
1227 13 Mar 08 peter 189        std::vector<statistics::Averager>(data.columns()));
1079 13 Feb 08 peter 190
4200 19 Aug 22 peter 191     utility::Matrix prediction;
4200 19 Aug 22 peter 192
4200 19 Aug 22 peter 193     for(unsigned long k=0;k<size();++k) {
1954 07 May 09 jari 194       Data sub_data = test_data(data, k);
1206 05 Mar 08 peter 195       classifier(k).predict(sub_data,prediction);
4200 19 Aug 22 peter 196       for(size_t i=0; i<prediction.rows();i++)
4200 19 Aug 22 peter 197         for(size_t j=0; j<prediction.columns();j++)
4200 19 Aug 22 peter 198           result[i][j].add(prediction(i,j));
1079 13 Feb 08 peter 199     }
1079 13 Feb 08 peter 200   }
1079 13 Feb 08 peter 201
4200 19 Aug 22 peter 202
1954 07 May 09 jari 203   template <class Classifier, class Data>
1954 07 May 09 jari 204   unsigned long EnsembleBuilder<Classifier, Data>::size(void) const
1206 05 Mar 08 peter 205   {
1206 05 Mar 08 peter 206     return classifier_.size();
1206 05 Mar 08 peter 207   }
1206 05 Mar 08 peter 208
1206 05 Mar 08 peter 209
1954 07 May 09 jari 210   template <class Classifier, class Data>
1954 07 May 09 jari 211   MatrixLookup EnsembleBuilder<Classifier,
1954 07 May 09 jari 212                                Data>::test_data(const MatrixLookup& data,
1206 05 Mar 08 peter 213                                                 size_t k)
1206 05 Mar 08 peter 214   {
1206 05 Mar 08 peter 215     return MatrixLookup(data, subset_->training_features(k), true);
1206 05 Mar 08 peter 216   }
1206 05 Mar 08 peter 217
4200 19 Aug 22 peter 218
1954 07 May 09 jari 219   template <class Classifier, class Data>
4200 19 Aug 22 peter 220   MatrixLookupWeighted
1954 07 May 09 jari 221   EnsembleBuilder<Classifier, Data>::test_data(const MatrixLookupWeighted& data,
1954 07 May 09 jari 222                                                size_t k)
1206 05 Mar 08 peter 223   {
4200 19 Aug 22 peter 224     return MatrixLookupWeighted(data, subset_->training_features(k),
2226 24 Mar 10 peter 225                                 utility::Index(data.columns()));
1206 05 Mar 08 peter 226   }
1206 05 Mar 08 peter 227
4200 19 Aug 22 peter 228
1954 07 May 09 jari 229   template <class Classifier, class Data>
1206 05 Mar 08 peter 230   KernelLookup
1954 07 May 09 jari 231   EnsembleBuilder<Classifier, Data>::test_data(const KernelLookup& kernel,
1954 07 May 09 jari 232                                                size_t k)
1206 05 Mar 08 peter 233   {
1206 05 Mar 08 peter 234     // weighted case
1206 05 Mar 08 peter 235     if (kernel.weighted()){
1206 05 Mar 08 peter 236       // no feature selection
1206 05 Mar 08 peter 237       if (kernel.data_weighted().rows()==subset_->training_features(k).size())
1206 05 Mar 08 peter 238         return KernelLookup(kernel, subset_->training_index(k), true);
1206 05 Mar 08 peter 239       MatrixLookupWeighted mlw = test_data(kernel.data_weighted(), k);
1206 05 Mar 08 peter 240       return subset_->training_data(k).test_kernel(mlw);
1206 05 Mar 08 peter 241
1206 05 Mar 08 peter 242     }
1206 05 Mar 08 peter 243     // unweighted case
1206 05 Mar 08 peter 244
1206 05 Mar 08 peter 245     // no feature selection
1206 05 Mar 08 peter 246     if (kernel.data().rows()==subset_->training_features(k).size())
1206 05 Mar 08 peter 247       return KernelLookup(kernel, subset_->training_index(k), true);
4200 19 Aug 22 peter 248
1206 05 Mar 08 peter 249     // feature selection
2138 24 Dec 09 peter 250     MatrixLookup ml = test_data(kernel.data(),k);
2138 24 Dec 09 peter 251     return subset_->training_data(k).test_kernel(ml);
1206 05 Mar 08 peter 252   }
1206 05 Mar 08 peter 253
4200 19 Aug 22 peter 254
1954 07 May 09 jari 255   template <class Classifier, class Data>
4200 19 Aug 22 peter 256   const std::vector<std::vector<statistics::Averager> >&
1954 07 May 09 jari 257   EnsembleBuilder<Classifier, Data>::validate(void)
1079 13 Feb 08 peter 258   {
1227 13 Mar 08 peter 259     // Don't recalculate validation_result_
1227 13 Mar 08 peter 260     if (!validation_result_.empty())
1227 13 Mar 08 peter 261       return validation_result_;
1079 13 Feb 08 peter 262
1227 13 Mar 08 peter 263     validation_result_ = std::vector<std::vector<statistics::Averager> >
4200 19 Aug 22 peter 264       (subset_->target().nof_classes(),
1227 13 Mar 08 peter 265        std::vector<statistics::Averager>(subset_->target().size()));
1227 13 Mar 08 peter 266
4200 19 Aug 22 peter 267     utility::Matrix prediction;
1273 10 Apr 08 jari 268     for(unsigned long k=0;k<size();k++) {
1079 13 Feb 08 peter 269       classifier(k).predict(subset_->validation_data(k),prediction);
4200 19 Aug 22 peter 270
1079 13 Feb 08 peter 271       // map results to indices of samples in training + validation data set
4200 19 Aug 22 peter 272        for(size_t i=0; i<prediction.rows();i++)
1079 13 Feb 08 peter 273          for(size_t j=0; j<prediction.columns();j++) {
1079 13 Feb 08 peter 274            validation_result_[i][subset_->validation_index(k)[j]].
1079 13 Feb 08 peter 275             add(prediction(i,j));
4200 19 Aug 22 peter 276          }
1079 13 Feb 08 peter 277     }
1079 13 Feb 08 peter 278     return validation_result_;
1079 13 Feb 08 peter 279   }
1079 13 Feb 08 peter 280
680 11 Oct 06 jari 281 }}} // of namespace classifier, yat, and theplu
481 22 Dec 05 markus 282
481 22 Dec 05 markus 283 #endif