yat  0.8.3pre
NCC.h
00001 #ifndef _theplu_yat_classifier_ncc_ 
00002 #define _theplu_yat_classifier_ncc_ 
00003 
00004 // $Id: NCC.h 2384 2010-12-22 14:03:36Z peter $
00005 
00006 /*
00007   Copyright (C) 2005 Peter Johansson, Markus Ringnér
00008   Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
00009   Copyright (C) 2009, 2010 Peter Johansson
00010 
00011   This file is part of the yat library, http://dev.thep.lu.se/yat
00012 
00013   The yat library is free software; you can redistribute it and/or
00014   modify it under the terms of the GNU General Public License as
00015   published by the Free Software Foundation; either version 3 of the
00016   License, or (at your option) any later version.
00017 
00018   The yat library is distributed in the hope that it will be useful,
00019   but WITHOUT ANY WARRANTY; without even the implied warranty of
00020   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
00021   General Public License for more details.
00022 
00023   You should have received a copy of the GNU General Public License
00024   along with yat. If not, see <http://www.gnu.org/licenses/>.
00025 */
00026 
00027 #include "MatrixLookup.h"
00028 #include "MatrixLookupWeighted.h"
00029 #include "SupervisedClassifier.h"
00030 #include "Target.h"
00031 
00032 #include "yat/statistics/Averager.h"
00033 #include "yat/statistics/AveragerWeighted.h"
00034 #include "yat/utility/concept_check.h"
00035 #include "yat/utility/Exception.h"
00036 #include "yat/utility/Matrix.h"
00037 #include "yat/utility/MatrixWeighted.h"
00038 #include "yat/utility/Vector.h"
00039 #include "yat/utility/stl_utility.h"
00040 #include "yat/utility/yat_assert.h"
00041 
00042 #include <boost/concept_check.hpp>
00043 
00044 #include <iterator>
00045 #include <map>
00046 #include <cmath>
00047 
00048 namespace theplu {
00049 namespace yat {
00050 namespace classifier {  
00051 
00052 
00064   template <typename Distance>
00065   class NCC : public SupervisedClassifier
00066   {
00067   
00068   public:
00074     NCC(void);
00075     
00083     NCC(const Distance&);
00084 
00085 
00089     virtual ~NCC(void);
00090 
00096     const utility::Matrix& centroids(void) const;
00097 
00098     NCC<Distance>* make_classifier(void) const;
00099         
00117     void predict(const MatrixLookup& data, utility::Matrix& results) const;
00118     
00140     void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
00141 
00150     void train(const MatrixLookup& data, const Target& targets);
00151 
00152 
00165     void train(const MatrixLookupWeighted& data, const Target& targets);
00166 
00167     
00168   private:
00169 
00170     void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
00171     void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;   
00172 
00173     utility::Matrix centroids_;
00174     bool centroids_nan_;
00175     Distance distance_;
00176   };  
00177 
00178   // templates
00179 
00180   template <typename Distance>
00181   NCC<Distance>::NCC() 
00182     : SupervisedClassifier(), centroids_nan_(false)
00183   {
00184     BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
00185   }
00186 
00187   template <typename Distance>
00188   NCC<Distance>::NCC(const Distance& dist) 
00189     : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
00190   {
00191     BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
00192   }
00193 
00194 
00195   template <typename Distance>
00196   NCC<Distance>::~NCC()   
00197   {
00198   }
00199 
00200 
00201   template <typename Distance>
00202   const utility::Matrix& NCC<Distance>::centroids(void) const
00203   {
00204     return centroids_;
00205   }
00206   
00207 
00208   template <typename Distance>
00209   NCC<Distance>* 
00210   NCC<Distance>::make_classifier() const 
00211   {     
00212     // All private members should be copied here to generate an
00213     // identical but untrained classifier
00214     return new NCC<Distance>(distance_);
00215   }
00216 
00217   template <typename Distance>
00218   void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
00219   {   
00220     centroids_.resize(data.rows(), target.nof_classes());
00221     for(size_t i=0; i<data.rows(); i++) {
00222       std::vector<statistics::Averager> class_averager;
00223       class_averager.resize(target.nof_classes());
00224       for(size_t j=0; j<data.columns(); j++) {
00225         class_averager[target(j)].add(data(i,j));
00226       }
00227       for(size_t c=0;c<target.nof_classes();c++) {
00228         centroids_(i,c) = class_averager[c].mean();
00229       }
00230     }
00231   }
00232 
00233 
00234   template <typename Distance>
00235   void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
00236   {   
00237     centroids_.resize(data.rows(), target.nof_classes());
00238     for(size_t i=0; i<data.rows(); i++) {
00239       std::vector<statistics::AveragerWeighted> class_averager;
00240       class_averager.resize(target.nof_classes());
00241       for(size_t j=0; j<data.columns(); j++) 
00242         class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
00243       for(size_t c=0;c<target.nof_classes();c++) {
00244         if(class_averager[c].sum_w()==0) {
00245           centroids_nan_=true;
00246         }
00247         centroids_(i,c) = class_averager[c].mean();
00248       }
00249     }
00250   }
00251 
00252 
00253   template <typename Distance>
00254   void NCC<Distance>::predict(const MatrixLookup& test,                     
00255                               utility::Matrix& prediction) const
00256   {   
00257     utility::yat_assert<utility::runtime_error>
00258       (centroids_.rows()==test.rows(),
00259        "NCC::predict test data with incorrect number of rows");
00260     
00261     prediction.resize(centroids_.columns(), test.columns());
00262 
00263     // If weighted training data has resulted in NaN in centroids: weighted calculations
00264     if(centroids_nan_) { 
00265       predict_weighted(MatrixLookupWeighted(test),prediction);
00266     }
00267     // If unweighted training data: unweighted calculations
00268     else {
00269       predict_unweighted(test,prediction);
00270     }
00271   }
00272 
00273   template <typename Distance>
00274   void NCC<Distance>::predict(const MatrixLookupWeighted& test,                     
00275                               utility::Matrix& prediction) const
00276   {   
00277     utility::yat_assert<utility::runtime_error>
00278       (centroids_.rows()==test.rows(),
00279        "NCC::predict test data with incorrect number of rows");
00280     
00281     prediction.resize(centroids_.columns(), test.columns());
00282     predict_weighted(test,prediction);
00283   }
00284 
00285   
00286   template <typename Distance>
00287   void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 
00288                                          utility::Matrix& prediction) const
00289   {
00290     for(size_t j=0; j<test.columns();j++)
00291       for(size_t k=0; k<centroids_.columns();k++) 
00292         prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
00293                                     centroids_.begin_column(k));
00294   }
00295   
00296   template <typename Distance>
00297   void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 
00298                                           utility::Matrix& prediction) const
00299   {
00300     utility::MatrixWeighted weighted_centroids(centroids_);
00301     for(size_t j=0; j<test.columns();j++) 
00302       for(size_t k=0; k<centroids_.columns();k++)
00303         prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 
00304                                     weighted_centroids.begin_column(k));
00305   }
00306 
00307       
00308 }}} // of namespace classifier, yat, and theplu
00309 
00310 #endif

Generated on Thu Dec 20 2012 03:12:57 for yat by  doxygen 1.8.0-20120409