KNN.h
#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id: KNN.h 2384 2010-12-22 14:03:36Z peter $

/*
  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2009, 2010 Peter Johansson

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 3 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with yat. If not, see <http://www.gnu.org/licenses/>.
*/

#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/concept_check.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"
#include "yat/utility/VectorConstView.h"
#include "yat/utility/VectorView.h"
#include "yat/utility/yat_assert.h"

#include <boost/concept_check.hpp>

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {
  /// K-nearest-neighbor classifier. Distance is the functor used to
  /// measure distances between samples, and NeighborWeighting decides
  /// how the k nearest training samples contribute to the prediction
  /// (KNN_Uniform by default).
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    /// Default constructor; the number of neighbors, k, defaults to 3.
    KNN(void);

    /// Constructor taking an initialized Distance functor; k defaults to 3.
    KNN(const Distance&);

    virtual ~KNN();

    /// Returns the number of neighbors, k.
    unsigned int k() const;

    /// Sets the number of neighbors to k_in.
    void k(unsigned int k_in);

    /// Creates an identical but untrained copy of this classifier.
    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    /// Predicts the test samples (columns) in data. The results matrix
    /// is resized to number of classes x number of test samples. Throws
    /// if the classifier has not been trained.
    void predict(const MatrixLookup& data, utility::Matrix& results) const;

    /// As above, but for weighted test data.
    void predict(const MatrixLookupWeighted& data,
                 utility::Matrix& results) const;

    /// Trains the classifier on unweighted data; columns are samples and
    /// targets holds their class labels. k is capped at the number of
    /// training samples.
    void train(const MatrixLookup& data, const Target& targets);

    /// As above, but for weighted training data.
    void train(const MatrixLookupWeighted& data, const Target& targets);

  private:

    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    void predict_common(const utility::Matrix& distances,
                        utility::Matrix& prediction) const;

  };
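
  /*
    Usage sketch: a hedged example of how this class might be used. It
    assumes a distance functor such as statistics::EuclideanDistance and
    pre-built MatrixLookup objects (training and test, with samples as
    columns) together with a Target describing the training classes.

      KNN<statistics::EuclideanDistance> knn; // KNN_Uniform voting
      knn.k(5);                               // use the 5 nearest neighbors
      knn.train(training, targets);           // MatrixLookup + Target
      utility::Matrix prediction;
      knn.predict(test, prediction);          // classes x test samples
  */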

  /// Concept checking class for the NeighborWeighting template argument
  /// of KNN. A model of the concept must be default constructible and
  /// assignable, and callable with a read-only vector of distances, the
  /// indices of the k nearest training samples, the training Target, and
  /// a writable prediction vector.
  template <class T>
  class NeighborWeightingConcept
    : public boost::DefaultConstructible<T>, public boost::Assignable<T>
  {
  public:
    BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
    {
      T neighbor_weighting;
      utility::Vector vec;
      const utility::VectorBase& distance(vec);
      utility::VectorMutable& prediction(vec);
      std::vector<size_t> k_sorted;
      Target target;
      neighbor_weighting(distance, k_sorted, target, prediction);
    }
  private:
  };
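
  /*
    A minimal sketch of a type modelling this concept: a hypothetical
    uniform-vote weighting, assuming Target::operator()(i) returns the
    class index of sample i (not a class shipped with yat).

      struct UniformVote
      {
        void operator()(const utility::VectorBase& distance,
                        const std::vector<size_t>& k_sorted,
                        const Target& target,
                        utility::VectorMutable& prediction) const
        {
          // each of the k nearest neighbors casts one vote for its class
          for (size_t i=0; i<k_sorted.size(); ++i)
            prediction(target(k_sorted[i])) += 1.0;
        }
      };
  */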

  // template implementation

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN()
    : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
  {
    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
  }

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
      distance_(dist)
  {
    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
    //    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }


  template <typename Distance, typename NeighborWeighting>
  void  KNN<Distance, NeighborWeighting>::calculate_unweighted
  (const MatrixLookup& training, const MatrixLookup& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        YAT_ASSERT(!std::isnan((*distances)(i,j)));
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::calculate_weighted
  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        // If the distance is NaN (no variables in common with non-zero
        // weights), it is set to infinity so that such training samples
        // sort last, i.e. are treated as the most distant neighbors.
        if(std::isnan((*distances)(i,j)))
          (*distances)(i,j)=std::numeric_limits<double>::infinity();
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
  {
    k_=k;
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>*
  KNN<Distance, NeighborWeighting>::make_classifier() const
  {
    // All private members should be copied here to generate an
    // identical but untrained classifier
    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
    knn->weighting_=this->weighting_;
    knn->k(this->k());
    return knn;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=&data;
    data_mlw_=0;
    target_=&target;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=0;
    data_mlw_=&data;
    target_=&target;
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances = 0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<utility::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_unweighted(*data_ml_,test,distances);
    }
    else if (data_mlw_ && !data_ml_) {
      // weighted training data
      utility::yat_assert<utility::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
                         distances);
    }
    else {
      throw utility::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);
    if(distances)
      delete distances;
  }

  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances=0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<utility::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
    }
    // weighted training data
    else if (data_mlw_ && !data_ml_) {
      utility::yat_assert<utility::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,test,distances);
    }
    else {
      throw utility::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);

    if(distances)
      delete distances;
  }

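  /*
    Illustration of predict_common below (assuming KNN_Uniform, where
    each of the k nearest neighbors casts one unweighted vote for its
    class): with k_ = 2 and a distance column (0.1, 0.5, 0.3, 0.9), the
    two smallest distances select training samples 0 and 2; if those
    samples belong to classes 1 and 0, the prediction column becomes
    (1, 1) over classes 0 and 1.
  */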
  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict_common
  (const utility::Matrix& distances, utility::Matrix& prediction) const
  {
    for(size_t sample=0;sample<distances.columns();sample++) {
      std::vector<size_t> k_index;
      utility::VectorConstView dist=distances.column_const_view(sample);
      utility::sort_smallest_index(k_index,k_,dist);
      utility::VectorView pred=prediction.column_view(sample);
      weighting_(dist,k_index,*target_,pred);
    }

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_->nof_classes(); c++)
      if(!target_->size(c))
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }
}}} // of namespace classifier, yat, and theplu

#endif
