00001 #ifndef _theplu_yat_classifier_knn_
00002 #define _theplu_yat_classifier_knn_
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #include "DataLookup1D.h"
00027 #include "DataLookupWeighted1D.h"
00028 #include "KNN_Uniform.h"
00029 #include "MatrixLookup.h"
00030 #include "MatrixLookupWeighted.h"
00031 #include "SupervisedClassifier.h"
00032 #include "Target.h"
00033 #include "yat/utility/Matrix.h"
00034 #include "yat/utility/yat_assert.h"
00035
00036 #include <cmath>
00037 #include <limits>
00038 #include <map>
00039 #include <stdexcept>
00040
00041 namespace theplu {
00042 namespace yat {
00043 namespace classifier {
00044
/**
   \brief k-nearest-neighbour classifier.

   For every test sample the distance to every training sample is
   computed with \a Distance; the k training samples with the smallest
   distance are selected and \a NeighborWeighting turns them into one
   score per class.

   \tparam Distance functor invoked as
   distance(first1, last1, first2), where [first1, last1) iterates a
   training column and first2 iterates the corresponding test column.
   \tparam NeighborWeighting functor invoked as
   weighting(distances, neighbour_indices, target, prediction_column);
   defaults to KNN_Uniform.
*/
template <typename Distance, typename NeighborWeighting=KNN_Uniform>
class KNN : public SupervisedClassifier
{

public:
  /**
     \brief Default constructor.

     No training data is attached and k is 3. The Distance and
     NeighborWeighting members are default-initialized.
  */
  KNN(void);


  /**
     \brief Construct with a given distance functor.

     A copy of the argument is used as the distance; no training data
     is attached and k is 3.
  */
  KNN(const Distance&);


  /**
     \brief Destructor.

     Training data and target are only referenced (train stores their
     addresses), so nothing is released here.
  */
  virtual ~KNN();


  /**
     \return the number of nearest neighbours, k.
  */
  unsigned int k() const;

  /**
     \brief Set the number of nearest neighbours, k.

     NOTE: train() lowers k to the number of training samples when the
     training data has fewer samples than k.
  */
  void k(unsigned int k_in);


  /**
     \brief Create a new, untrained classifier of the same kind.

     The returned object is allocated with new and gets copies of this
     object's distance, neighbour weighting and k. The caller takes
     ownership.
  */
  KNN<Distance,NeighborWeighting>* make_classifier(void) const;

  /**
     \brief Predict class scores for unweighted test data.

     Samples are columns of \a data. \a results is resized to
     (number of classes) x (number of test samples). Rows of classes
     that have no training samples are set to NaN.
  */
  void predict(const MatrixLookup& data , utility::Matrix& results) const;

  /**
     \brief Predict class scores for weighted test data.

     Same output layout as the unweighted overload. Distances that
     evaluate to NaN are treated as infinitely large.
  */
  void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;


  /**
     \brief Train on unweighted data.

     Only the addresses of \a data and \a targets are stored, so both
     must outlive every subsequent call to predict. If \a data has fewer
     samples (columns) than k, k is reduced to the number of samples.
  */
  void train(const MatrixLookup& data, const Target& targets);

  /**
     \brief Train on weighted data.

     Same storage and k-clamping rules as the unweighted overload.
  */
  void train(const MatrixLookupWeighted& data, const Target& targets);

private:

  // Unweighted training data; non-null only after train(MatrixLookup).
  const MatrixLookup* data_ml_;
  // Weighted training data; non-null only after train(MatrixLookupWeighted).
  const MatrixLookupWeighted* data_mlw_;
  // Class labels of the training samples (not owned).
  const Target* target_;


  // Number of nearest neighbours.
  unsigned int k_;

  Distance distance_;
  NeighborWeighting weighting_;

  // Fill a (training samples) x (test samples) matrix with pairwise
  // distances; element (i,j) is the distance between training column i
  // and test column j.
  void calculate_unweighted(const MatrixLookup&,
                            const MatrixLookup&,
                            utility::Matrix*) const;
  // As calculate_unweighted, but NaN distances are stored as +infinity.
  void calculate_weighted(const MatrixLookupWeighted&,
                          const MatrixLookupWeighted&,
                          utility::Matrix*) const;

  // Turn a distance matrix into per-class predictions (one column per
  // test sample) using the k nearest neighbours and weighting_.
  void predict_common(const utility::Matrix& distances,
                      utility::Matrix& prediction) const;

};
00184
00185
00186
00187
00188 template <typename Distance, typename NeighborWeighting>
00189 KNN<Distance, NeighborWeighting>::KNN()
00190 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
00191 {
00192 }
00193
00194 template <typename Distance, typename NeighborWeighting>
00195 KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
00196 : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3), distance_(dist)
00197 {
00198 }
00199
00200
00201 template <typename Distance, typename NeighborWeighting>
00202 KNN<Distance, NeighborWeighting>::~KNN()
00203 {
00204 }
00205
00206
00207 template <typename Distance, typename NeighborWeighting>
00208 void KNN<Distance, NeighborWeighting>::calculate_unweighted
00209 (const MatrixLookup& training, const MatrixLookup& test,
00210 utility::Matrix* distances) const
00211 {
00212 for(size_t i=0; i<training.columns(); i++) {
00213 for(size_t j=0; j<test.columns(); j++) {
00214 (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
00215 test.begin_column(j));
00216 YAT_ASSERT(!std::isnan((*distances)(i,j)));
00217 }
00218 }
00219 }
00220
00221
00222 template <typename Distance, typename NeighborWeighting>
00223 void
00224 KNN<Distance, NeighborWeighting>::calculate_weighted
00225 (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
00226 utility::Matrix* distances) const
00227 {
00228 for(size_t i=0; i<training.columns(); i++) {
00229 for(size_t j=0; j<test.columns(); j++) {
00230 (*distances)(i,j) = distance_(training.begin_column(i), training.end_column(i),
00231 test.begin_column(j));
00232
00233
00234 if(std::isnan((*distances)(i,j)))
00235 (*distances)(i,j)=std::numeric_limits<double>::infinity();
00236 }
00237 }
00238 }
00239
00240
00241 template <typename Distance, typename NeighborWeighting>
00242 unsigned int KNN<Distance, NeighborWeighting>::k() const
00243 {
00244 return k_;
00245 }
00246
00247 template <typename Distance, typename NeighborWeighting>
00248 void KNN<Distance, NeighborWeighting>::k(unsigned int k)
00249 {
00250 k_=k;
00251 }
00252
00253
00254 template <typename Distance, typename NeighborWeighting>
00255 KNN<Distance, NeighborWeighting>*
00256 KNN<Distance, NeighborWeighting>::make_classifier() const
00257 {
00258
00259
00260 KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
00261 knn->weighting_=this->weighting_;
00262 knn->k(this->k());
00263 return knn;
00264 }
00265
00266
00267 template <typename Distance, typename NeighborWeighting>
00268 void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
00269 const Target& target)
00270 {
00271 utility::yat_assert<std::runtime_error>
00272 (data.columns()==target.size(),
00273 "KNN::train called with different sizes of target and data");
00274
00275 if(data.columns()<k_)
00276 k_=data.columns();
00277 data_ml_=&data;
00278 data_mlw_=0;
00279 target_=⌖
00280 }
00281
00282 template <typename Distance, typename NeighborWeighting>
00283 void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
00284 const Target& target)
00285 {
00286 utility::yat_assert<std::runtime_error>
00287 (data.columns()==target.size(),
00288 "KNN::train called with different sizes of target and data");
00289
00290 if(data.columns()<k_)
00291 k_=data.columns();
00292 data_ml_=0;
00293 data_mlw_=&data;
00294 target_=⌖
00295 }
00296
00297
00298 template <typename Distance, typename NeighborWeighting>
00299 void KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
00300 utility::Matrix& prediction) const
00301 {
00302
00303 utility::Matrix* distances = 0;
00304
00305 if(data_ml_ && !data_mlw_) {
00306 utility::yat_assert<std::runtime_error>
00307 (data_ml_->rows()==test.rows(),
00308 "KNN::predict different number of rows in training and test data");
00309 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00310 calculate_unweighted(*data_ml_,test,distances);
00311 }
00312 else if (data_mlw_ && !data_ml_) {
00313
00314 utility::yat_assert<std::runtime_error>
00315 (data_mlw_->rows()==test.rows(),
00316 "KNN::predict different number of rows in training and test data");
00317 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00318 calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
00319 distances);
00320 }
00321 else {
00322 std::runtime_error("KNN::predict no training data");
00323 }
00324
00325 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00326 predict_common(*distances,prediction);
00327 if(distances)
00328 delete distances;
00329 }
00330
00331 template <typename Distance, typename NeighborWeighting>
00332 void KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
00333 utility::Matrix& prediction) const
00334 {
00335
00336 utility::Matrix* distances=0;
00337
00338 if(data_ml_ && !data_mlw_) {
00339 utility::yat_assert<std::runtime_error>
00340 (data_ml_->rows()==test.rows(),
00341 "KNN::predict different number of rows in training and test data");
00342 distances=new utility::Matrix(data_ml_->columns(),test.columns());
00343 calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
00344 }
00345
00346 else if (data_mlw_ && !data_ml_) {
00347 utility::yat_assert<std::runtime_error>
00348 (data_mlw_->rows()==test.rows(),
00349 "KNN::predict different number of rows in training and test data");
00350 distances=new utility::Matrix(data_mlw_->columns(),test.columns());
00351 calculate_weighted(*data_mlw_,test,distances);
00352 }
00353 else {
00354 std::runtime_error("KNN::predict no training data");
00355 }
00356
00357 prediction.resize(target_->nof_classes(),test.columns(),0.0);
00358 predict_common(*distances,prediction);
00359
00360 if(distances)
00361 delete distances;
00362 }
00363
00364 template <typename Distance, typename NeighborWeighting>
00365 void KNN<Distance, NeighborWeighting>::predict_common
00366 (const utility::Matrix& distances, utility::Matrix& prediction) const
00367 {
00368 for(size_t sample=0;sample<distances.columns();sample++) {
00369 std::vector<size_t> k_index;
00370 utility::VectorConstView dist=distances.column_const_view(sample);
00371 utility::sort_smallest_index(k_index,k_,dist);
00372 utility::VectorView pred=prediction.column_view(sample);
00373 weighting_(dist,k_index,*target_,pred);
00374 }
00375
00376
00377
00378 for(size_t c=0;c<target_->nof_classes(); c++)
00379 if(!target_->size(c))
00380 for(size_t j=0;j<prediction.columns();j++)
00381 prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
00382 }
00383
00384
00385 }}}
00386
00387 #endif
00388