00001 #ifndef _theplu_yat_classifier_knn_
00002 #define _theplu_yat_classifier_knn_
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #include "DataLookup1D.h"
00027 #include "DataLookupWeighted1D.h"
00028 #include "KNN_Uniform.h"
00029 #include "MatrixLookup.h"
00030 #include "MatrixLookupWeighted.h"
00031 #include "SupervisedClassifier.h"
00032 #include "Target.h"
00033 #include "yat/utility/concept_check.h"
00034 #include "yat/utility/Exception.h"
00035 #include "yat/utility/Matrix.h"
00036 #include "yat/utility/Vector.h"
00037 #include "yat/utility/VectorConstView.h"
00038 #include "yat/utility/VectorView.h"
00039 #include "yat/utility/yat_assert.h"
00040
00041 #include <boost/concept_check.hpp>
00042
00043 #include <cmath>
00044 #include <limits>
00045 #include <map>
00046 #include <stdexcept>
00047 #include <vector>
00048
00049 namespace theplu {
00050 namespace yat {
00051 namespace classifier {
00052
/**
   \brief k-nearest-neighbor classifier.

   Samples are the columns of the data matrices. The classifier is
   lazy: train() only stores references to the training data and
   targets, and all work is done in predict().

   \tparam Distance functor called as
   distance(first1, last1, first2) on two data columns; its result is
   stored as a double in the distance matrix.
   \tparam NeighborWeighting functor called as
   weighting(distances, k_sorted_indices, target, prediction_column);
   it must model NeighborWeightingConcept. Defaults to KNN_Uniform.
*/
template <typename Distance, typename NeighborWeighting=KNN_Uniform>
class KNN : public SupervisedClassifier
{

public:
  /**
     \brief Default constructor.

     The Distance and NeighborWeighting objects are default
     constructed and the number of neighbors k is set to 3.
  */
  KNN(void);


  /**
     \brief Constructor using a copy of the supplied distance functor.

     The number of neighbors k is set to 3.
  */
  KNN(const Distance&);


  /**
     \brief Destructor. Training data and targets are only referenced,
     so nothing is released here.
  */
  virtual ~KNN();


  /**
     \return the number of nearest neighbors used in predictions.
     Note that train() may have reduced it (see train()).
  */
  unsigned int k() const;

  /**
     \brief Set the number of nearest neighbors used in predictions.
  */
  void k(unsigned int k_in);


  /**
     \brief Create a new, untrained KNN with the same distance,
     neighbor weighting and k as this one.

     The object is allocated with new; the caller is responsible for
     deleting it.
  */
  KNN<Distance,NeighborWeighting>* make_classifier(void) const;

  /**
     \brief Predict the test samples in \a data.

     \a results is resized to (number of classes) x (number of test
     samples); column j holds the output of the neighbor weighting for
     test sample j. Classes without any training samples are set to NaN.
     If the classifier was trained on weighted data, \a data is treated
     as a MatrixLookupWeighted.

     \throw utility::runtime_error if the classifier has not been trained.
  */
  void predict(const MatrixLookup& data , utility::Matrix& results) const;

  /**
     \brief Predict the (weighted) test samples in \a data.

     Same output layout as the unweighted overload. Distances that
     evaluate to NaN (e.g. no overlapping non-zero weights) are treated
     as infinitely large. If the classifier was trained on unweighted
     data, the training data is treated as a MatrixLookupWeighted.

     \throw utility::runtime_error if the classifier has not been trained.
  */
  void predict(const MatrixLookupWeighted& data,
               utility::Matrix& results) const;


  /**
     \brief Train the classifier on unweighted data.

     No copy is made: pointers to \a data and \a targets are stored, so
     both must stay alive as long as predict() is used. If \a data has
     fewer columns (samples) than k, k is reduced to that number.
     Matching sizes of \a data and \a targets are checked with
     utility::yat_assert<utility::runtime_error>.
  */
  void train(const MatrixLookup& data, const Target& targets);

  /**
     \brief Train the classifier on weighted data.

     Same storage and k-adjustment behaviour as the unweighted overload.
  */
  void train(const MatrixLookupWeighted& data, const Target& targets);

private:

  // Non-owning pointers to the training data; at most one is non-null
  // after train(), which selects the unweighted or weighted code path.
  const MatrixLookup* data_ml_;
  const MatrixLookupWeighted* data_mlw_;
  // Non-owning pointer to the training targets.
  const Target* target_;


  // Number of nearest neighbors used in predictions.
  unsigned int k_;

  Distance distance_;
  NeighborWeighting weighting_;

  // Fill the (training samples x test samples) distance matrix.
  void calculate_unweighted(const MatrixLookup&,
                            const MatrixLookup&,
                            utility::Matrix*) const;
  // As above, but NaN distances are replaced by +infinity.
  void calculate_weighted(const MatrixLookupWeighted&,
                          const MatrixLookupWeighted&,
                          utility::Matrix*) const;

  // Turn a distance matrix into predictions via the neighbor weighting.
  void predict_common(const utility::Matrix& distances,
                      utility::Matrix& prediction) const;

};
00193
00194
00213 template <class T>
00214 class NeighborWeightingConcept
00215 : public boost::DefaultConstructible<T>, public boost::Assignable<T>
00216 {
00217 public:
00221 BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
00222 {
00223 T neighbor_weighting;
00224 utility::Vector vec;
00225 const utility::VectorBase& distance(vec);
00226 utility::VectorMutable& prediction(vec);
00227 std::vector<size_t> k_sorted;
00228 Target target;
00229 neighbor_weighting(distance, k_sorted, target, prediction);
00230 }
00231 private:
00232 };
00233
00234
00235
00236 template <typename Distance, typename NeighborWeighting>
00237 KNN<Distance, NeighborWeighting>::KNN()
00238 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
00239 {
00240 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
00241 BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
00242 }
00243
00244 template <typename Distance, typename NeighborWeighting>
00245 KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
00246 : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
00247 distance_(dist)
00248 {
00249 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
00250
00251 }
00252
00253
00254 template <typename Distance, typename NeighborWeighting>
00255 KNN<Distance, NeighborWeighting>::~KNN()
00256 {
00257 }
00258
00259
00260 template <typename Distance, typename NeighborWeighting>
00261 void KNN<Distance, NeighborWeighting>::calculate_unweighted
00262 (const MatrixLookup& training, const MatrixLookup& test,
00263 utility::Matrix* distances) const
00264 {
00265 for(size_t i=0; i<training.columns(); i++) {
00266 for(size_t j=0; j<test.columns(); j++) {
00267 (*distances)(i,j) = distance_(training.begin_column(i),
00268 training.end_column(i),
00269 test.begin_column(j));
00270 YAT_ASSERT(!std::isnan((*distances)(i,j)));
00271 }
00272 }
00273 }
00274
00275
00276 template <typename Distance, typename NeighborWeighting>
00277 void
00278 KNN<Distance, NeighborWeighting>::calculate_weighted
00279 (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
00280 utility::Matrix* distances) const
00281 {
00282 for(size_t i=0; i<training.columns(); i++) {
00283 for(size_t j=0; j<test.columns(); j++) {
00284 (*distances)(i,j) = distance_(training.begin_column(i),
00285 training.end_column(i),
00286 test.begin_column(j));
00287
00288
00289 if(std::isnan((*distances)(i,j)))
00290 (*distances)(i,j)=std::numeric_limits<double>::infinity();
00291 }
00292 }
00293 }
00294
00295
00296 template <typename Distance, typename NeighborWeighting>
00297 unsigned int KNN<Distance, NeighborWeighting>::k() const
00298 {
00299 return k_;
00300 }
00301
00302 template <typename Distance, typename NeighborWeighting>
00303 void KNN<Distance, NeighborWeighting>::k(unsigned int k)
00304 {
00305 k_=k;
00306 }
00307
00308
00309 template <typename Distance, typename NeighborWeighting>
00310 KNN<Distance, NeighborWeighting>*
00311 KNN<Distance, NeighborWeighting>::make_classifier() const
00312 {
00313
00314
00315 KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
00316 knn->weighting_=this->weighting_;
00317 knn->k(this->k());
00318 return knn;
00319 }
00320
00321
00322 template <typename Distance, typename NeighborWeighting>
00323 void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
00324 const Target& target)
00325 {
00326 utility::yat_assert<utility::runtime_error>
00327 (data.columns()==target.size(),
00328 "KNN::train called with different sizes of target and data");
00329
00330 if(data.columns()<k_)
00331 k_=data.columns();
00332 data_ml_=&data;
00333 data_mlw_=0;
00334 target_=⌖
00335 }
00336
00337 template <typename Distance, typename NeighborWeighting>
00338 void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
00339 const Target& target)
00340 {
00341 utility::yat_assert<utility::runtime_error>
00342 (data.columns()==target.size(),
00343 "KNN::train called with different sizes of target and data");
00344
00345 if(data.columns()<k_)
00346 k_=data.columns();
00347 data_ml_=0;
00348 data_mlw_=&data;
00349 target_=⌖
00350 }
00351
00352
00353 template <typename Distance, typename NeighborWeighting>
00354 void
00355 KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
00356 utility::Matrix& prediction) const
00357 {
00358
00359 utility::Matrix* distances = 0;
00360
00361 if(data_ml_ && !data_mlw_) {
00362 utility::yat_assert<utility::runtime_error>
00363 (data_ml_->rows()==test.rows(),
00364 "KNN::predict different number of rows in training and test data");
00365 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00366 calculate_unweighted(*data_ml_,test,distances);
00367 }
00368 else if (data_mlw_ && !data_ml_) {
00369
00370 utility::yat_assert<utility::runtime_error>
00371 (data_mlw_->rows()==test.rows(),
00372 "KNN::predict different number of rows in training and test data");
00373 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00374 calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
00375 distances);
00376 }
00377 else {
00378 throw utility::runtime_error("KNN::predict no training data");
00379 }
00380
00381 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00382 predict_common(*distances,prediction);
00383 if(distances)
00384 delete distances;
00385 }
00386
00387 template <typename Distance, typename NeighborWeighting>
00388 void
00389 KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
00390 utility::Matrix& prediction) const
00391 {
00392
00393 utility::Matrix* distances=0;
00394
00395 if(data_ml_ && !data_mlw_) {
00396 utility::yat_assert<utility::runtime_error>
00397 (data_ml_->rows()==test.rows(),
00398 "KNN::predict different number of rows in training and test data");
00399 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00400 calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
00401 }
00402
00403 else if (data_mlw_ && !data_ml_) {
00404 utility::yat_assert<utility::runtime_error>
00405 (data_mlw_->rows()==test.rows(),
00406 "KNN::predict different number of rows in training and test data");
00407 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00408 calculate_weighted(*data_mlw_,test,distances);
00409 }
00410 else {
00411 throw utility::runtime_error("KNN::predict no training data");
00412 }
00413
00414 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00415 predict_common(*distances,prediction);
00416
00417 if(distances)
00418 delete distances;
00419 }
00420
00421 template <typename Distance, typename NeighborWeighting>
00422 void KNN<Distance, NeighborWeighting>::predict_common
00423 (const utility::Matrix& distances, utility::Matrix& prediction) const
00424 {
00425 for(size_t sample=0;sample<distances.columns();sample++) {
00426 std::vector<size_t> k_index;
00427 utility::VectorConstView dist=distances.column_const_view(sample);
00428 utility::sort_smallest_index(k_index,k_,dist);
00429 utility::VectorView pred=prediction.column_view(sample);
00430 weighting_(dist,k_index,*target_,pred);
00431 }
00432
00433
00434
00435 for(size_t c=0;c<target_->nof_classes(); c++)
00436 if(!target_->size(c))
00437 for(size_t j=0;j<prediction.columns();j++)
00438 prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
00439 }
00440 }}}
00441
00442 #endif