yat
0.8.3pre
|
00001 #ifndef _theplu_yat_classifier_knn_ 00002 #define _theplu_yat_classifier_knn_ 00003 00004 // $Id: KNN.h 2384 2010-12-22 14:03:36Z peter $ 00005 00006 /* 00007 Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér 00008 Copyright (C) 2009, 2010 Peter Johansson 00009 00010 This file is part of the yat library, http://dev.thep.lu.se/yat 00011 00012 The yat library is free software; you can redistribute it and/or 00013 modify it under the terms of the GNU General Public License as 00014 published by the Free Software Foundation; either version 3 of the 00015 License, or (at your option) any later version. 00016 00017 The yat library is distributed in the hope that it will be useful, 00018 but WITHOUT ANY WARRANTY; without even the implied warranty of 00019 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00020 General Public License for more details. 00021 00022 You should have received a copy of the GNU General Public License 00023 along with yat. If not, see <http://www.gnu.org/licenses/>. 00024 */ 00025 00026 #include "DataLookup1D.h" 00027 #include "DataLookupWeighted1D.h" 00028 #include "KNN_Uniform.h" 00029 #include "MatrixLookup.h" 00030 #include "MatrixLookupWeighted.h" 00031 #include "SupervisedClassifier.h" 00032 #include "Target.h" 00033 #include "yat/utility/concept_check.h" 00034 #include "yat/utility/Exception.h" 00035 #include "yat/utility/Matrix.h" 00036 #include "yat/utility/Vector.h" 00037 #include "yat/utility/VectorConstView.h" 00038 #include "yat/utility/VectorView.h" 00039 #include "yat/utility/yat_assert.h" 00040 00041 #include <boost/concept_check.hpp> 00042 00043 #include <cmath> 00044 #include <limits> 00045 #include <map> 00046 #include <stdexcept> 00047 #include <vector> 00048 00049 namespace theplu { 00050 namespace yat { 00051 namespace classifier { 00052 00069 template <typename Distance, typename NeighborWeighting=KNN_Uniform> 00070 class KNN : public SupervisedClassifier 00071 { 00072 00073 public: 00081 KNN(void); 00082 00083 00093 KNN(const Distance&); 00094 00095 00099 virtual ~KNN(); 00100 00101 00106 unsigned int k() const; 00107 00113 void k(unsigned int k_in); 00114 00115 00116 KNN<Distance,NeighborWeighting>* make_classifier(void) const; 00117 00128 void predict(const MatrixLookup& data , utility::Matrix& results) const; 00129 00142 void predict(const MatrixLookupWeighted& data, 00143 utility::Matrix& results) const; 00144 00145 00160 void train(const MatrixLookup& data, const Target& targets); 00161 00168 void train(const MatrixLookupWeighted& data, const Target& targets); 00169 00170 private: 00171 00172 const MatrixLookup* data_ml_; 00173 const MatrixLookupWeighted* data_mlw_; 00174 const Target* target_; 00175 00176 // The number of neighbors 00177 unsigned int k_; 00178 00179 Distance distance_; 00180 NeighborWeighting weighting_; 00181 00182 void calculate_unweighted(const MatrixLookup&, 00183 const MatrixLookup&, 00184 utility::Matrix*) const; 00185 void calculate_weighted(const MatrixLookupWeighted&, 00186 const MatrixLookupWeighted&, 00187 utility::Matrix*) const; 00188 00189 void predict_common(const utility::Matrix& distances, 00190 utility::Matrix& prediction) const; 00191 00192 }; 00193 00194 00213 template <class T> 00214 class NeighborWeightingConcept 00215 : public boost::DefaultConstructible<T>, public boost::Assignable<T> 00216 { 00217 public: 00221 BOOST_CONCEPT_USAGE(NeighborWeightingConcept) 00222 { 00223 T neighbor_weighting; 00224 utility::Vector vec; 00225 const utility::VectorBase& distance(vec); 00226 utility::VectorMutable& prediction(vec); 00227 std::vector<size_t> k_sorted; 00228 Target target; 00229 neighbor_weighting(distance, k_sorted, target, prediction); 00230 } 00231 private: 00232 }; 00233 00234 // template implementation 00235 00236 template <typename Distance, typename NeighborWeighting> 00237 KNN<Distance, NeighborWeighting>::KNN() 00238 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3) 00239 { 00240 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>)); 00241 BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>)); 00242 } 00243 00244 template <typename Distance, typename NeighborWeighting> 00245 KNN<Distance, NeighborWeighting>::KNN(const Distance& dist) 00246 : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3), 00247 distance_(dist) 00248 { 00249 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>)); 00250 // BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>)); 00251 } 00252 00253 00254 template <typename Distance, typename NeighborWeighting> 00255 KNN<Distance, NeighborWeighting>::~KNN() 00256 { 00257 } 00258 00259 00260 template <typename Distance, typename NeighborWeighting> 00261 void KNN<Distance, NeighborWeighting>::calculate_unweighted 00262 (const MatrixLookup& training, const MatrixLookup& test, 00263 utility::Matrix* distances) const 00264 { 00265 for(size_t i=0; i<training.columns(); i++) { 00266 for(size_t j=0; j<test.columns(); j++) { 00267 (*distances)(i,j) = distance_(training.begin_column(i), 00268 training.end_column(i), 00269 test.begin_column(j)); 00270 YAT_ASSERT(!std::isnan((*distances)(i,j))); 00271 } 00272 } 00273 } 00274 00275 00276 template <typename Distance, typename NeighborWeighting> 00277 void 00278 KNN<Distance, NeighborWeighting>::calculate_weighted 00279 (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test, 00280 utility::Matrix* distances) const 00281 { 00282 for(size_t i=0; i<training.columns(); i++) { 00283 for(size_t j=0; j<test.columns(); j++) { 00284 (*distances)(i,j) = distance_(training.begin_column(i), 00285 training.end_column(i), 00286 test.begin_column(j)); 00287 // If the distance is NaN (no common variables with non-zero weights), 00288 // the distance is set to infinity to be sorted as a neighbor at the end 00289 if(std::isnan((*distances)(i,j))) 00290 (*distances)(i,j)=std::numeric_limits<double>::infinity(); 00291 } 00292 } 00293 } 00294 00295 00296 template <typename Distance, typename NeighborWeighting> 00297 unsigned int KNN<Distance, NeighborWeighting>::k() const 00298 { 00299 return k_; 00300 } 00301 00302 template <typename Distance, typename NeighborWeighting> 00303 void KNN<Distance, NeighborWeighting>::k(unsigned int k) 00304 { 00305 k_=k; 00306 } 00307 00308 00309 template <typename Distance, typename NeighborWeighting> 00310 KNN<Distance, NeighborWeighting>* 00311 KNN<Distance, NeighborWeighting>::make_classifier() const 00312 { 00313 // All private members should be copied here to generate an 00314 // identical but untrained classifier 00315 KNN* knn=new KNN<Distance, NeighborWeighting>(distance_); 00316 knn->weighting_=this->weighting_; 00317 knn->k(this->k()); 00318 return knn; 00319 } 00320 00321 00322 template <typename Distance, typename NeighborWeighting> 00323 void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data, 00324 const Target& target) 00325 { 00326 utility::yat_assert<utility::runtime_error> 00327 (data.columns()==target.size(), 00328 "KNN::train called with different sizes of target and data"); 00329 // k has to be at most the number of training samples. 00330 if(data.columns()<k_) 00331 k_=data.columns(); 00332 data_ml_=&data; 00333 data_mlw_=0; 00334 target_=⌖ 00335 } 00336 00337 template <typename Distance, typename NeighborWeighting> 00338 void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data, 00339 const Target& target) 00340 { 00341 utility::yat_assert<utility::runtime_error> 00342 (data.columns()==target.size(), 00343 "KNN::train called with different sizes of target and data"); 00344 // k has to be at most the number of training samples. 00345 if(data.columns()<k_) 00346 k_=data.columns(); 00347 data_ml_=0; 00348 data_mlw_=&data; 00349 target_=⌖ 00350 } 00351 00352 00353 template <typename Distance, typename NeighborWeighting> 00354 void 00355 KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test, 00356 utility::Matrix& prediction) const 00357 { 00358 // matrix with training samples as rows and test samples as columns 00359 utility::Matrix* distances = 0; 00360 // unweighted training data 00361 if(data_ml_ && !data_mlw_) { 00362 utility::yat_assert<utility::runtime_error> 00363 (data_ml_->rows()==test.rows(), 00364 "KNN::predict different number of rows in training and test data"); 00365 distances=new utility::Matrix(data_ml_->columns(),test.columns()); 00366 calculate_unweighted(*data_ml_,test,distances); 00367 } 00368 else if (data_mlw_ && !data_ml_) { 00369 // weighted training data 00370 utility::yat_assert<utility::runtime_error> 00371 (data_mlw_->rows()==test.rows(), 00372 "KNN::predict different number of rows in training and test data"); 00373 distances=new utility::Matrix(data_mlw_->columns(),test.columns()); 00374 calculate_weighted(*data_mlw_,MatrixLookupWeighted(test), 00375 distances); 00376 } 00377 else { 00378 throw utility::runtime_error("KNN::predict no training data"); 00379 } 00380 00381 prediction.resize(target_->nof_classes(),test.columns(),0.0); 00382 predict_common(*distances,prediction); 00383 if(distances) 00384 delete distances; 00385 } 00386 00387 template <typename Distance, typename NeighborWeighting> 00388 void 00389 KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test, 00390 utility::Matrix& prediction) const 00391 { 00392 // matrix with training samples as rows and test samples as columns 00393 utility::Matrix* distances=0; 00394 // unweighted training data 00395 if(data_ml_ && !data_mlw_) { 00396 utility::yat_assert<utility::runtime_error> 00397 (data_ml_->rows()==test.rows(), 00398 "KNN::predict different number of rows in training and test data"); 00399 distances=new utility::Matrix(data_ml_->columns(),test.columns()); 00400 calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances); 00401 } 00402 // weighted training data 00403 else if (data_mlw_ && !data_ml_) { 00404 utility::yat_assert<utility::runtime_error> 00405 (data_mlw_->rows()==test.rows(), 00406 "KNN::predict different number of rows in training and test data"); 00407 distances=new utility::Matrix(data_mlw_->columns(),test.columns()); 00408 calculate_weighted(*data_mlw_,test,distances); 00409 } 00410 else { 00411 throw utility::runtime_error("KNN::predict no training data"); 00412 } 00413 00414 prediction.resize(target_->nof_classes(),test.columns(),0.0); 00415 predict_common(*distances,prediction); 00416 00417 if(distances) 00418 delete distances; 00419 } 00420 00421 template <typename Distance, typename NeighborWeighting> 00422 void KNN<Distance, NeighborWeighting>::predict_common 00423 (const utility::Matrix& distances, utility::Matrix& prediction) const 00424 { 00425 for(size_t sample=0;sample<distances.columns();sample++) { 00426 std::vector<size_t> k_index; 00427 utility::VectorConstView dist=distances.column_const_view(sample); 00428 utility::sort_smallest_index(k_index,k_,dist); 00429 utility::VectorView pred=prediction.column_view(sample); 00430 weighting_(dist,k_index,*target_,pred); 00431 } 00432 00433 // classes for which there are no training samples should be set 00434 // to nan in the predictions 00435 for(size_t c=0;c<target_->nof_classes(); c++) 00436 if(!target_->size(c)) 00437 for(size_t j=0;j<prediction.columns();j++) 00438 prediction(c,j)=std::numeric_limits<double>::quiet_NaN(); 00439 } 00440 }}} // of namespace classifier, yat, and theplu 00441 00442 #endif