yat/classifier/KNN.h

Code
Comments
Other
Rev Date Author Line
3552 03 Jan 17 peter 1 #ifndef _theplu_yat_classifier_knn_
3552 03 Jan 17 peter 2 #define _theplu_yat_classifier_knn_
902 27 Sep 07 markus 3
902 27 Sep 07 markus 4 // $Id$
902 27 Sep 07 markus 5
999 23 Dec 07 jari 6 /*
2119 12 Dec 09 peter 7   Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 8   Copyright (C) 2009, 2010 Peter Johansson
999 23 Dec 07 jari 9
1437 25 Aug 08 peter 10   This file is part of the yat library, http://dev.thep.lu.se/yat
999 23 Dec 07 jari 11
999 23 Dec 07 jari 12   The yat library is free software; you can redistribute it and/or
999 23 Dec 07 jari 13   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 14   published by the Free Software Foundation; either version 3 of the
999 23 Dec 07 jari 15   License, or (at your option) any later version.
999 23 Dec 07 jari 16
999 23 Dec 07 jari 17   The yat library is distributed in the hope that it will be useful,
999 23 Dec 07 jari 18   but WITHOUT ANY WARRANTY; without even the implied warranty of
999 23 Dec 07 jari 19   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
999 23 Dec 07 jari 20   General Public License for more details.
999 23 Dec 07 jari 21
999 23 Dec 07 jari 22   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 23   along with yat. If not, see <http://www.gnu.org/licenses/>.
999 23 Dec 07 jari 24 */
999 23 Dec 07 jari 25
1050 07 Feb 08 peter 26 #include "DataLookup1D.h"
902 27 Sep 07 markus 27 #include "DataLookupWeighted1D.h"
1112 21 Feb 08 markus 28 #include "KNN_Uniform.h"
948 08 Oct 07 markus 29 #include "MatrixLookup.h"
902 27 Sep 07 markus 30 #include "MatrixLookupWeighted.h"
902 27 Sep 07 markus 31 #include "SupervisedClassifier.h"
902 27 Sep 07 markus 32 #include "Target.h"
2334 15 Oct 10 peter 33 #include "yat/utility/concept_check.h"
2210 05 Mar 10 peter 34 #include "yat/utility/Exception.h"
1121 22 Feb 08 peter 35 #include "yat/utility/Matrix.h"
2340 16 Oct 10 peter 36 #include "yat/utility/Vector.h"
2337 15 Oct 10 peter 37 #include "yat/utility/VectorConstView.h"
2337 15 Oct 10 peter 38 #include "yat/utility/VectorView.h"
916 30 Sep 07 peter 39 #include "yat/utility/yat_assert.h"
902 27 Sep 07 markus 40
2334 15 Oct 10 peter 41 #include <boost/concept_check.hpp>
2334 15 Oct 10 peter 42
916 30 Sep 07 peter 43 #include <cmath>
2055 08 Sep 09 peter 44 #include <limits>
902 27 Sep 07 markus 45 #include <map>
936 05 Oct 07 peter 46 #include <stdexcept>
2337 15 Oct 10 peter 47 #include <vector>
902 27 Sep 07 markus 48
902 27 Sep 07 markus 49 namespace theplu {
902 27 Sep 07 markus 50 namespace yat {
902 27 Sep 07 markus 51 namespace classifier {
902 27 Sep 07 markus 52
1188 29 Feb 08 markus 53   /**
1189 29 Feb 08 markus 54      \brief Nearest Neighbor Classifier
3552 03 Jan 17 peter 55
1188 29 Feb 08 markus 56      A sample is predicted based on the classes of its k nearest
1188 29 Feb 08 markus 57      neighbors among the training data samples. KNN supports using
1188 29 Feb 08 markus 58      different measures, for example, Euclidean distance, to define
1188 29 Feb 08 markus 59      distance between samples. KNN also supports using different ways to
1188 29 Feb 08 markus 60      weight the votes of the k nearest neighbors. For example, using a
1188 29 Feb 08 markus 61      uniform vote a test sample gets a vote for each class which is the
1188 29 Feb 08 markus 62      number of nearest neighbors belonging to the class.
3552 03 Jan 17 peter 63
1188 29 Feb 08 markus 64      The template argument Distance should be a class modelling the
1188 29 Feb 08 markus 65      concept \ref concept_distance. The template argument
1188 29 Feb 08 markus 66      NeighborWeighting should be a class modelling the concept \ref
1188 29 Feb 08 markus 67      concept_neighbor_weighting.
1188 29 Feb 08 markus 68   */
1112 21 Feb 08 markus 69   template <typename Distance, typename NeighborWeighting=KNN_Uniform>
902 27 Sep 07 markus 70   class KNN : public SupervisedClassifier
902 27 Sep 07 markus 71   {
3552 03 Jan 17 peter 72
902 27 Sep 07 markus 73   public:
1188 29 Feb 08 markus 74     /**
1189 29 Feb 08 markus 75        \brief Default constructor.
3552 03 Jan 17 peter 76
1188 29 Feb 08 markus 77        The number of nearest neighbors (k) is set to 3. Distance and
1188 29 Feb 08 markus 78        NeighborWeighting are initialized using their default
1188 29 Feb 08 markus 79        constructuors.
1188 29 Feb 08 markus 80     */
1157 26 Feb 08 markus 81     KNN(void);
948 08 Oct 07 markus 82
948 08 Oct 07 markus 83
1188 29 Feb 08 markus 84     /**
1189 29 Feb 08 markus 85        \brief Constructor using an intialized distance measure.
3552 03 Jan 17 peter 86
1189 29 Feb 08 markus 87        The number of nearest neighbors (k) is set to
1189 29 Feb 08 markus 88        3. NeighborWeighting is initialized using its default
1189 29 Feb 08 markus 89        constructor. This constructor should be used if Distance has
1189 29 Feb 08 markus 90        parameters and the user wants to specify the parameters by
1189 29 Feb 08 markus 91        initializing Distance prior to constructing the KNN.
3552 03 Jan 17 peter 92     */
1158 26 Feb 08 markus 93     KNN(const Distance&);
1158 26 Feb 08 markus 94
1158 26 Feb 08 markus 95
1188 29 Feb 08 markus 96     /**
1188 29 Feb 08 markus 97        Destructor
1188 29 Feb 08 markus 98     */
902 27 Sep 07 markus 99     virtual ~KNN();
3552 03 Jan 17 peter 100
3552 03 Jan 17 peter 101
1188 29 Feb 08 markus 102     /**
1188 29 Feb 08 markus 103        \brief Get the number of nearest neighbors.
1188 29 Feb 08 markus 104        \return The number of neighbors.
1188 29 Feb 08 markus 105     */
1271 09 Apr 08 peter 106     unsigned int k() const;
902 27 Sep 07 markus 107
1188 29 Feb 08 markus 108     /**
1188 29 Feb 08 markus 109        \brief Set the number of nearest neighbors.
3552 03 Jan 17 peter 110
3552 03 Jan 17 peter 111        Sets the number of neighbors to \a k_in.
1188 29 Feb 08 markus 112     */
1271 09 Apr 08 peter 113     void k(unsigned int k_in);
902 27 Sep 07 markus 114
902 27 Sep 07 markus 115
1157 26 Feb 08 markus 116     KNN<Distance,NeighborWeighting>* make_classifier(void) const;
3552 03 Jan 17 peter 117
1188 29 Feb 08 markus 118     /**
1189 29 Feb 08 markus 119        \brief Make predictions for unweighted test data.
3552 03 Jan 17 peter 120
1188 29 Feb 08 markus 121        Predictions are calculated and returned in \a results.  For
1188 29 Feb 08 markus 122        each sample in \a data, \a results contains the weighted number
1188 29 Feb 08 markus 123        of nearest neighbors which belong to each class. Numbers of
1188 29 Feb 08 markus 124        nearest neighbors are weighted according to
1188 29 Feb 08 markus 125        NeighborWeighting. If a class has no training samples NaN's are
1188 29 Feb 08 markus 126        returned for this class in \a results.
1188 29 Feb 08 markus 127     */
1188 29 Feb 08 markus 128     void predict(const MatrixLookup& data , utility::Matrix& results) const;
902 27 Sep 07 markus 129
3552 03 Jan 17 peter 130     /**
1189 29 Feb 08 markus 131         \brief Make predictions for weighted test data.
3552 03 Jan 17 peter 132
1188 29 Feb 08 markus 133         Predictions are calculated and returned in \a results. For
1188 29 Feb 08 markus 134         each sample in \a data, \a results contains the weighted
1188 29 Feb 08 markus 135         number of nearest neighbors which belong to each class as in
1188 29 Feb 08 markus 136         predict(const MatrixLookup& data, utility::Matrix& results).
1188 29 Feb 08 markus 137         If a test and training sample pair has no variables with
1188 29 Feb 08 markus 138         non-zero weights in common, there are no variables which can
1188 29 Feb 08 markus 139         be used to calculate the distance between the two samples. In
1188 29 Feb 08 markus 140         this case the distance between the two is set to infinity.
1188 29 Feb 08 markus 141     */
3552 03 Jan 17 peter 142     void predict(const MatrixLookupWeighted& data,
2336 15 Oct 10 peter 143                  utility::Matrix& results) const;
1157 26 Feb 08 markus 144
1160 26 Feb 08 markus 145
1188 29 Feb 08 markus 146     /**
1189 29 Feb 08 markus 147        \brief Train the KNN using unweighted training data with known
3552 03 Jan 17 peter 148        targets.
3552 03 Jan 17 peter 149
1188 29 Feb 08 markus 150        For KNN there is no actual training; the entire training data
1188 29 Feb 08 markus 151        set is stored with targets. KNN only stores references to \a data
1188 29 Feb 08 markus 152        and \a targets as copying these would make the %classifier
1188 29 Feb 08 markus 153        slow. If the number of training samples set is smaller than k,
1188 29 Feb 08 markus 154        k is set to the number of training samples.
3552 03 Jan 17 peter 155
1188 29 Feb 08 markus 156        \note If \a data or \a targets go out of scope ore are
1188 29 Feb 08 markus 157        deleted, the KNN becomes invalid and further use is undefined
1188 29 Feb 08 markus 158        unless it is trained again.
1188 29 Feb 08 markus 159     */
1188 29 Feb 08 markus 160     void train(const MatrixLookup& data, const Target& targets);
3552 03 Jan 17 peter 161
3552 03 Jan 17 peter 162     /**
3552 03 Jan 17 peter 163        \brief Train the KNN using weighted training data with known targets.
3552 03 Jan 17 peter 164
1188 29 Feb 08 markus 165        See train(const MatrixLookup& data, const Target& targets) for
1188 29 Feb 08 markus 166        additional information.
1188 29 Feb 08 markus 167     */
1188 29 Feb 08 markus 168     void train(const MatrixLookupWeighted& data, const Target& targets);
3552 03 Jan 17 peter 169
902 27 Sep 07 markus 170   private:
3552 03 Jan 17 peter 171
1160 26 Feb 08 markus 172     const MatrixLookup* data_ml_;
1160 26 Feb 08 markus 173     const MatrixLookupWeighted* data_mlw_;
1157 26 Feb 08 markus 174     const Target* target_;
902 27 Sep 07 markus 175
1112 21 Feb 08 markus 176     // The number of neighbors
1271 09 Apr 08 peter 177     unsigned int k_;
902 27 Sep 07 markus 178
1050 07 Feb 08 peter 179     Distance distance_;
1112 21 Feb 08 markus 180     NeighborWeighting weighting_;
1112 21 Feb 08 markus 181
1107 19 Feb 08 markus 182     void calculate_unweighted(const MatrixLookup&,
1107 19 Feb 08 markus 183                               const MatrixLookup&,
1121 22 Feb 08 peter 184                               utility::Matrix*) const;
1107 19 Feb 08 markus 185     void calculate_weighted(const MatrixLookupWeighted&,
1107 19 Feb 08 markus 186                             const MatrixLookupWeighted&,
1121 22 Feb 08 peter 187                             utility::Matrix*) const;
1160 26 Feb 08 markus 188
3552 03 Jan 17 peter 189     void predict_common(const utility::Matrix& distances,
1160 26 Feb 08 markus 190                         utility::Matrix& prediction) const;
1160 26 Feb 08 markus 191
902 27 Sep 07 markus 192   };
3552 03 Jan 17 peter 193
3552 03 Jan 17 peter 194
2340 16 Oct 10 peter 195   /**
2340 16 Oct 10 peter 196      \brief Concept check for a \ref concept_neighbor_weighting
2340 16 Oct 10 peter 197
3552 03 Jan 17 peter 198      This class is intended to be used in a <a
2340 16 Oct 10 peter 199      href="\boost_url/concept_check/using_concept_check.htm">
2340 16 Oct 10 peter 200      BOOST_CONCEPT_ASSERT </a>
2340 16 Oct 10 peter 201
2340 16 Oct 10 peter 202      \code
2340 16 Oct 10 peter 203      template<class Distance>
2340 16 Oct 10 peter 204      void some_function(double x)
2340 16 Oct 10 peter 205      {
2340 16 Oct 10 peter 206      BOOST_CONCEPT_ASSERT((DistanceConcept<Distance>));
2340 16 Oct 10 peter 207      ...
2340 16 Oct 10 peter 208      }
2340 16 Oct 10 peter 209      \endcode
2340 16 Oct 10 peter 210
2340 16 Oct 10 peter 211      \since New in yat 0.7
2340 16 Oct 10 peter 212   */
2340 16 Oct 10 peter 213   template <class T>
3552 03 Jan 17 peter 214   class NeighborWeightingConcept
2340 16 Oct 10 peter 215     : public boost::DefaultConstructible<T>, public boost::Assignable<T>
2340 16 Oct 10 peter 216   {
2340 16 Oct 10 peter 217   public:
2340 16 Oct 10 peter 218     /**
2340 16 Oct 10 peter 219        \brief function doing the concept test
2340 16 Oct 10 peter 220      */
2340 16 Oct 10 peter 221     BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
2340 16 Oct 10 peter 222     {
2340 16 Oct 10 peter 223       T neighbor_weighting;
2340 16 Oct 10 peter 224       utility::Vector vec;
2340 16 Oct 10 peter 225       const utility::VectorBase& distance(vec);
2340 16 Oct 10 peter 226       utility::VectorMutable& prediction(vec);
2340 16 Oct 10 peter 227       std::vector<size_t> k_sorted;
2340 16 Oct 10 peter 228       Target target;
2340 16 Oct 10 peter 229       neighbor_weighting(distance, k_sorted, target, prediction);
2340 16 Oct 10 peter 230     }
2340 16 Oct 10 peter 231   private:
2340 16 Oct 10 peter 232   };
2340 16 Oct 10 peter 233
2340 16 Oct 10 peter 234   // template implementation
3552 03 Jan 17 peter 235
1112 21 Feb 08 markus 236   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 237   KNN<Distance, NeighborWeighting>::KNN()
1160 26 Feb 08 markus 238     : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
948 08 Oct 07 markus 239   {
2334 15 Oct 10 peter 240     BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
2340 16 Oct 10 peter 241     BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
948 08 Oct 07 markus 242   }
948 08 Oct 07 markus 243
1158 26 Feb 08 markus 244   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 245   KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
3552 03 Jan 17 peter 246     : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
2335 15 Oct 10 peter 247       distance_(dist)
1158 26 Feb 08 markus 248   {
2334 15 Oct 10 peter 249     BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
3553 03 Jan 17 peter 250     BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
1158 26 Feb 08 markus 251   }
1158 26 Feb 08 markus 252
3552 03 Jan 17 peter 253
1112 21 Feb 08 markus 254   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 255   KNN<Distance, NeighborWeighting>::~KNN()
902 27 Sep 07 markus 256   {
902 27 Sep 07 markus 257   }
1107 19 Feb 08 markus 258
3552 03 Jan 17 peter 259
1112 21 Feb 08 markus 260   template <typename Distance, typename NeighborWeighting>
1112 21 Feb 08 markus 261   void  KNN<Distance, NeighborWeighting>::calculate_unweighted
1112 21 Feb 08 markus 262   (const MatrixLookup& training, const MatrixLookup& test,
1121 22 Feb 08 peter 263    utility::Matrix* distances) const
1107 19 Feb 08 markus 264   {
1107 19 Feb 08 markus 265     for(size_t i=0; i<training.columns(); i++) {
1107 19 Feb 08 markus 266       for(size_t j=0; j<test.columns(); j++) {
3552 03 Jan 17 peter 267         (*distances)(i,j) =  distance_(training.begin_column(i),
3552 03 Jan 17 peter 268                                       training.end_column(i),
1160 26 Feb 08 markus 269                                       test.begin_column(j));
1875 19 Mar 09 peter 270         YAT_ASSERT(!std::isnan((*distances)(i,j)));
948 08 Oct 07 markus 271       }
1107 19 Feb 08 markus 272     }
1107 19 Feb 08 markus 273   }
1160 26 Feb 08 markus 274
3552 03 Jan 17 peter 275
1112 21 Feb 08 markus 276   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 277   void
1112 21 Feb 08 markus 278   KNN<Distance, NeighborWeighting>::calculate_weighted
1112 21 Feb 08 markus 279   (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
1121 22 Feb 08 peter 280    utility::Matrix* distances) const
1107 19 Feb 08 markus 281   {
3552 03 Jan 17 peter 282     for(size_t i=0; i<training.columns(); i++) {
1107 19 Feb 08 markus 283       for(size_t j=0; j<test.columns(); j++) {
3552 03 Jan 17 peter 284         (*distances)(i,j) =  distance_(training.begin_column(i),
3552 03 Jan 17 peter 285                                       training.end_column(i),
1160 26 Feb 08 markus 286                                       test.begin_column(j));
1156 26 Feb 08 markus 287         // If the distance is NaN (no common variables with non-zero weights),
1156 26 Feb 08 markus 288         // the distance is set to infinity to be sorted as a neighbor at the end
3552 03 Jan 17 peter 289         if(std::isnan((*distances)(i,j)))
1156 26 Feb 08 markus 290           (*distances)(i,j)=std::numeric_limits<double>::infinity();
948 08 Oct 07 markus 291       }
902 27 Sep 07 markus 292     }
902 27 Sep 07 markus 293   }
3552 03 Jan 17 peter 294
3552 03 Jan 17 peter 295
1112 21 Feb 08 markus 296   template <typename Distance, typename NeighborWeighting>
1271 09 Apr 08 peter 297   unsigned int KNN<Distance, NeighborWeighting>::k() const
902 27 Sep 07 markus 298   {
902 27 Sep 07 markus 299     return k_;
902 27 Sep 07 markus 300   }
902 27 Sep 07 markus 301
1112 21 Feb 08 markus 302   template <typename Distance, typename NeighborWeighting>
1274 10 Apr 08 peter 303   void KNN<Distance, NeighborWeighting>::k(unsigned int k)
902 27 Sep 07 markus 304   {
902 27 Sep 07 markus 305     k_=k;
902 27 Sep 07 markus 306   }
902 27 Sep 07 markus 307
902 27 Sep 07 markus 308
1112 21 Feb 08 markus 309   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 310   KNN<Distance, NeighborWeighting>*
3552 03 Jan 17 peter 311   KNN<Distance, NeighborWeighting>::make_classifier() const
3552 03 Jan 17 peter 312   {
1164 26 Feb 08 markus 313     // All private members should be copied here to generate an
1164 26 Feb 08 markus 314     // identical but untrained classifier
1164 26 Feb 08 markus 315     KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
1164 26 Feb 08 markus 316     knn->weighting_=this->weighting_;
1157 26 Feb 08 markus 317     knn->k(this->k());
902 27 Sep 07 markus 318     return knn;
902 27 Sep 07 markus 319   }
3552 03 Jan 17 peter 320
3552 03 Jan 17 peter 321
1112 21 Feb 08 markus 322   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 323   void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
1157 26 Feb 08 markus 324                                                const Target& target)
3552 03 Jan 17 peter 325   {
2210 05 Mar 10 peter 326     utility::yat_assert<utility::runtime_error>
1157 26 Feb 08 markus 327       (data.columns()==target.size(),
1157 26 Feb 08 markus 328        "KNN::train called with different sizes of target and data");
1157 26 Feb 08 markus 329     // k has to be at most the number of training samples.
3552 03 Jan 17 peter 330     if(data.columns()<k_)
1157 26 Feb 08 markus 331       k_=data.columns();
1160 26 Feb 08 markus 332     data_ml_=&data;
1160 26 Feb 08 markus 333     data_mlw_=0;
1157 26 Feb 08 markus 334     target_=&target;
902 27 Sep 07 markus 335   }
902 27 Sep 07 markus 336
1157 26 Feb 08 markus 337   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 338   void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
1157 26 Feb 08 markus 339                                                const Target& target)
3552 03 Jan 17 peter 340   {
2210 05 Mar 10 peter 341     utility::yat_assert<utility::runtime_error>
1157 26 Feb 08 markus 342       (data.columns()==target.size(),
1157 26 Feb 08 markus 343        "KNN::train called with different sizes of target and data");
1157 26 Feb 08 markus 344     // k has to be at most the number of training samples.
3552 03 Jan 17 peter 345     if(data.columns()<k_)
1157 26 Feb 08 markus 346       k_=data.columns();
1160 26 Feb 08 markus 347     data_ml_=0;
1160 26 Feb 08 markus 348     data_mlw_=&data;
1157 26 Feb 08 markus 349     target_=&target;
1157 26 Feb 08 markus 350   }
902 27 Sep 07 markus 351
1157 26 Feb 08 markus 352
1112 21 Feb 08 markus 353   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 354   void
2335 15 Oct 10 peter 355   KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
2335 15 Oct 10 peter 356                                             utility::Matrix& prediction) const
3552 03 Jan 17 peter 357   {
1160 26 Feb 08 markus 358     // matrix with training samples as rows and test samples as columns
1160 26 Feb 08 markus 359     utility::Matrix* distances = 0;
1160 26 Feb 08 markus 360     // unweighted training data
1160 26 Feb 08 markus 361     if(data_ml_ && !data_mlw_) {
2210 05 Mar 10 peter 362       utility::yat_assert<utility::runtime_error>
1160 26 Feb 08 markus 363         (data_ml_->rows()==test.rows(),
2336 15 Oct 10 peter 364          "KNN::predict different number of rows in training and test data");
1160 26 Feb 08 markus 365       distances=new utility::Matrix(data_ml_->columns(),test.columns());
1160 26 Feb 08 markus 366       calculate_unweighted(*data_ml_,test,distances);
1160 26 Feb 08 markus 367     }
1160 26 Feb 08 markus 368     else if (data_mlw_ && !data_ml_) {
1160 26 Feb 08 markus 369       // weighted training data
2210 05 Mar 10 peter 370       utility::yat_assert<utility::runtime_error>
1160 26 Feb 08 markus 371         (data_mlw_->rows()==test.rows(),
2336 15 Oct 10 peter 372          "KNN::predict different number of rows in training and test data");
1160 26 Feb 08 markus 373       distances=new utility::Matrix(data_mlw_->columns(),test.columns());
1160 26 Feb 08 markus 374       calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
3552 03 Jan 17 peter 375                          distances);
1160 26 Feb 08 markus 376     }
1160 26 Feb 08 markus 377     else {
2210 05 Mar 10 peter 378       throw utility::runtime_error("KNN::predict no training data");
1160 26 Feb 08 markus 379     }
1031 04 Feb 08 markus 380
1160 26 Feb 08 markus 381     prediction.resize(target_->nof_classes(),test.columns(),0.0);
1160 26 Feb 08 markus 382     predict_common(*distances,prediction);
1160 26 Feb 08 markus 383     if(distances)
1160 26 Feb 08 markus 384       delete distances;
1160 26 Feb 08 markus 385   }
1160 26 Feb 08 markus 386
1160 26 Feb 08 markus 387   template <typename Distance, typename NeighborWeighting>
3552 03 Jan 17 peter 388   void
2335 15 Oct 10 peter 389   KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
2335 15 Oct 10 peter 390                                             utility::Matrix& prediction) const
3552 03 Jan 17 peter 391   {
1160 26 Feb 08 markus 392     // matrix with training samples as rows and test samples as columns
3552 03 Jan 17 peter 393     utility::Matrix* distances=0;
1160 26 Feb 08 markus 394     // unweighted training data
3552 03 Jan 17 peter 395     if(data_ml_ && !data_mlw_) {
2210 05 Mar 10 peter 396       utility::yat_assert<utility::runtime_error>
1160 26 Feb 08 markus 397         (data_ml_->rows()==test.rows(),
3552 03 Jan 17 peter 398          "KNN::predict different number of rows in training and test data");
1160 26 Feb 08 markus 399       distances=new utility::Matrix(data_ml_->columns(),test.columns());
3552 03 Jan 17 peter 400       calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
1160 26 Feb 08 markus 401     }
1160 26 Feb 08 markus 402     // weighted training data
1160 26 Feb 08 markus 403     else if (data_mlw_ && !data_ml_) {
2210 05 Mar 10 peter 404       utility::yat_assert<utility::runtime_error>
1160 26 Feb 08 markus 405         (data_mlw_->rows()==test.rows(),
3552 03 Jan 17 peter 406          "KNN::predict different number of rows in training and test data");
1160 26 Feb 08 markus 407       distances=new utility::Matrix(data_mlw_->columns(),test.columns());
3552 03 Jan 17 peter 408       calculate_weighted(*data_mlw_,test,distances);
1160 26 Feb 08 markus 409     }
1160 26 Feb 08 markus 410     else {
2210 05 Mar 10 peter 411       throw utility::runtime_error("KNN::predict no training data");
1160 26 Feb 08 markus 412     }
1160 26 Feb 08 markus 413
1160 26 Feb 08 markus 414     prediction.resize(target_->nof_classes(),test.columns(),0.0);
1160 26 Feb 08 markus 415     predict_common(*distances,prediction);
3552 03 Jan 17 peter 416
1160 26 Feb 08 markus 417     if(distances)
1160 26 Feb 08 markus 418       delete distances;
1160 26 Feb 08 markus 419   }
3552 03 Jan 17 peter 420
1160 26 Feb 08 markus 421   template <typename Distance, typename NeighborWeighting>
1160 26 Feb 08 markus 422   void KNN<Distance, NeighborWeighting>::predict_common
1160 26 Feb 08 markus 423   (const utility::Matrix& distances, utility::Matrix& prediction) const
3552 03 Jan 17 peter 424   {
1160 26 Feb 08 markus 425     for(size_t sample=0;sample<distances.columns();sample++) {
902 27 Sep 07 markus 426       std::vector<size_t> k_index;
1160 26 Feb 08 markus 427       utility::VectorConstView dist=distances.column_const_view(sample);
1112 21 Feb 08 markus 428       utility::sort_smallest_index(k_index,k_,dist);
1112 21 Feb 08 markus 429       utility::VectorView pred=prediction.column_view(sample);
1157 26 Feb 08 markus 430       weighting_(dist,k_index,*target_,pred);
902 27 Sep 07 markus 431     }
3552 03 Jan 17 peter 432
1142 25 Feb 08 markus 433     // classes for which there are no training samples should be set
1142 25 Feb 08 markus 434     // to nan in the predictions
3552 03 Jan 17 peter 435     for(size_t c=0;c<target_->nof_classes(); c++)
3552 03 Jan 17 peter 436       if(!target_->size(c))
1142 25 Feb 08 markus 437         for(size_t j=0;j<prediction.columns();j++)
1142 25 Feb 08 markus 438           prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
902 27 Sep 07 markus 439   }
902 27 Sep 07 markus 440 }}} // of namespace classifier, yat, and theplu
902 27 Sep 07 markus 441
902 27 Sep 07 markus 442 #endif