yat  0.17pre
KNN.h
1 #ifndef _theplu_yat_classifier_knn_
2 #define _theplu_yat_classifier_knn_
3 
4 // $Id: KNN.h 3562 2017-01-04 01:16:07Z peter $
5 
6 /*
7  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
8  Copyright (C) 2009, 2010, 2017 Peter Johansson
9 
10  This file is part of the yat library, http://dev.thep.lu.se/yat
11 
12  The yat library is free software; you can redistribute it and/or
13  modify it under the terms of the GNU General Public License as
14  published by the Free Software Foundation; either version 3 of the
15  License, or (at your option) any later version.
16 
17  The yat library is distributed in the hope that it will be useful,
18  but WITHOUT ANY WARRANTY; without even the implied warranty of
19  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
20  General Public License for more details.
21 
22  You should have received a copy of the GNU General Public License
23  along with yat. If not, see <http://www.gnu.org/licenses/>.
24 */
25 
26 #include "DataLookup1D.h"
27 #include "DataLookupWeighted1D.h"
28 #include "KNN_Uniform.h"
29 #include "MatrixLookup.h"
30 #include "MatrixLookupWeighted.h"
31 #include "SupervisedClassifier.h"
32 #include "Target.h"
33 #include "yat/utility/concept_check.h"
34 #include "yat/utility/Exception.h"
35 #include "yat/utility/Matrix.h"
36 #include "yat/utility/Vector.h"
37 #include "yat/utility/VectorConstView.h"
38 #include "yat/utility/VectorView.h"
39 #include "yat/utility/yat_assert.h"
40 
41 #include <boost/concept_check.hpp>
42 
43 #include <cmath>
44 #include <limits>
45 #include <map>
46 #include <stdexcept>
47 #include <vector>
48 
49 namespace theplu {
50 namespace yat {
51 namespace classifier {
52 
69  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
70  class KNN : public SupervisedClassifier
71  {
72 
73  public:
81  KNN(void);
82 
83 
93  KNN(const Distance&);
94 
95 
99  virtual ~KNN();
100 
101 
106  unsigned int k() const;
107 
113  void k(unsigned int k_in);
114 
115 
117 
128  void predict(const MatrixLookup& data , utility::Matrix& results) const;
129 
142  void predict(const MatrixLookupWeighted& data,
143  utility::Matrix& results) const;
144 
145 
160  void train(const MatrixLookup& data, const Target& targets);
161 
168  void train(const MatrixLookupWeighted& data, const Target& targets);
169 
170  private:
171 
172  const MatrixLookup* data_ml_;
173  const MatrixLookupWeighted* data_mlw_;
174  const Target* target_;
175 
176  // The number of neighbors
177  unsigned int k_;
178 
179  Distance distance_;
180  NeighborWeighting weighting_;
181 
182  void calculate_unweighted(const MatrixLookup&,
183  const MatrixLookup&,
184  utility::Matrix*) const;
185  void calculate_weighted(const MatrixLookupWeighted&,
186  const MatrixLookupWeighted&,
187  utility::Matrix*) const;
188 
189  void predict_common(const utility::Matrix& distances,
190  utility::Matrix& prediction) const;
191 
192  };
193 
194 
213  template <class T>
215  : public boost::DefaultConstructible<T>, public boost::Assignable<T>
216  {
217  public:
222  {
223  T neighbor_weighting;
224  utility::Vector vec;
225  const utility::VectorBase& distance(vec);
226  utility::VectorMutable& prediction(vec);
227  std::vector<size_t> k_sorted;
228  Target target;
229  neighbor_weighting(distance, k_sorted, target, prediction);
230  }
231  private:
232  };
233 
234  // template implementation
235 
236  template <typename Distance, typename NeighborWeighting>
238  : SupervisedClassifier(),data_ml_(0),data_mlw_(0),target_(0),k_(3)
239  {
240  BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
241  BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
242  }
243 
244  template <typename Distance, typename NeighborWeighting>
246  : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
247  distance_(dist)
248  {
249  BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
250  BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
251  }
252 
253 
254  template <typename Distance, typename NeighborWeighting>
256  {
257  }
258 
259 
260  template <typename Distance, typename NeighborWeighting>
262  (const MatrixLookup& training, const MatrixLookup& test,
263  utility::Matrix* distances) const
264  {
265  for(size_t i=0; i<training.columns(); i++) {
266  for(size_t j=0; j<test.columns(); j++) {
267  (*distances)(i,j) = distance_(training.begin_column(i),
268  training.end_column(i),
269  test.begin_column(j));
270  YAT_ASSERT(!std::isnan((*distances)(i,j)));
271  }
272  }
273  }
274 
275 
276  template <typename Distance, typename NeighborWeighting>
277  void
279  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
280  utility::Matrix* distances) const
281  {
282  for(size_t i=0; i<training.columns(); i++) {
283  for(size_t j=0; j<test.columns(); j++) {
284  (*distances)(i,j) = distance_(training.begin_column(i),
285  training.end_column(i),
286  test.begin_column(j));
287  // If the distance is NaN (no common variables with non-zero weights),
288  // the distance is set to infinity to be sorted as a neighbor at the end
289  if(std::isnan((*distances)(i,j)))
290  (*distances)(i,j)=std::numeric_limits<double>::infinity();
291  }
292  }
293  }
294 
295 
296  template <typename Distance, typename NeighborWeighting>
298  {
299  return k_;
300  }
301 
302  template <typename Distance, typename NeighborWeighting>
304  {
305  k_=k;
306  }
307 
308 
309  template <typename Distance, typename NeighborWeighting>
312  {
313  // All private members should be copied here to generate an
314  // identical but untrained classifier
315  KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
316  knn->weighting_=this->weighting_;
317  knn->k(this->k());
318  return knn;
319  }
320 
321 
322  template <typename Distance, typename NeighborWeighting>
324  const Target& target)
325  {
326  utility::yat_assert<utility::runtime_error>
327  (data.columns()==target.size(),
328  "KNN::train called with different sizes of target and data");
329  // k has to be at most the number of training samples.
330  if(data.columns()<k_)
331  k_=data.columns();
332  data_ml_=&data;
333  data_mlw_=0;
334  target_=&target;
335  }
336 
337  template <typename Distance, typename NeighborWeighting>
339  const Target& target)
340  {
341  utility::yat_assert<utility::runtime_error>
342  (data.columns()==target.size(),
343  "KNN::train called with different sizes of target and data");
344  // k has to be at most the number of training samples.
345  if(data.columns()<k_)
346  k_=data.columns();
347  data_ml_=0;
348  data_mlw_=&data;
349  target_=&target;
350  }
351 
352 
353  template <typename Distance, typename NeighborWeighting>
354  void
356  utility::Matrix& prediction) const
357  {
358  // matrix with training samples as rows and test samples as columns
359  utility::Matrix* distances = 0;
360  // unweighted training data
361  if(data_ml_ && !data_mlw_) {
362  utility::yat_assert<utility::runtime_error>
363  (data_ml_->rows()==test.rows(),
364  "KNN::predict different number of rows in training and test data");
365  distances=new utility::Matrix(data_ml_->columns(),test.columns());
366  calculate_unweighted(*data_ml_,test,distances);
367  }
368  else if (data_mlw_ && !data_ml_) {
369  // weighted training data
370  utility::yat_assert<utility::runtime_error>
371  (data_mlw_->rows()==test.rows(),
372  "KNN::predict different number of rows in training and test data");
373  distances=new utility::Matrix(data_mlw_->columns(),test.columns());
374  calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
375  distances);
376  }
377  else {
378  throw utility::runtime_error("KNN::predict no training data");
379  }
380 
381  prediction.resize(target_->nof_classes(),test.columns(),0.0);
382  predict_common(*distances,prediction);
383  if(distances)
384  delete distances;
385  }
386 
387  template <typename Distance, typename NeighborWeighting>
388  void
390  utility::Matrix& prediction) const
391  {
392  // matrix with training samples as rows and test samples as columns
393  utility::Matrix* distances=0;
394  // unweighted training data
395  if(data_ml_ && !data_mlw_) {
396  utility::yat_assert<utility::runtime_error>
397  (data_ml_->rows()==test.rows(),
398  "KNN::predict different number of rows in training and test data");
399  distances=new utility::Matrix(data_ml_->columns(),test.columns());
400  calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
401  }
402  // weighted training data
403  else if (data_mlw_ && !data_ml_) {
404  utility::yat_assert<utility::runtime_error>
405  (data_mlw_->rows()==test.rows(),
406  "KNN::predict different number of rows in training and test data");
407  distances=new utility::Matrix(data_mlw_->columns(),test.columns());
408  calculate_weighted(*data_mlw_,test,distances);
409  }
410  else {
411  throw utility::runtime_error("KNN::predict no training data");
412  }
413 
414  prediction.resize(target_->nof_classes(),test.columns(),0.0);
415  predict_common(*distances,prediction);
416 
417  if(distances)
418  delete distances;
419  }
420 
421  template <typename Distance, typename NeighborWeighting>
423  (const utility::Matrix& distances, utility::Matrix& prediction) const
424  {
425  for(size_t sample=0;sample<distances.columns();sample++) {
426  std::vector<size_t> k_index;
427  utility::VectorConstView dist=distances.column_const_view(sample);
428  utility::sort_smallest_index(k_index,k_,dist);
429  utility::VectorView pred=prediction.column_view(sample);
430  weighting_(dist,k_index,*target_,pred);
431  }
432 
433  // classes for which there are no training samples should be set
434  // to nan in the predictions
435  for(size_t c=0;c<target_->nof_classes(); c++)
436  if(!target_->size(c))
437  for(size_t j=0;j<prediction.columns();j++)
438  prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
439  }
440 }}} // of namespace classifier, yat, and theplu
441 
442 #endif
const VectorConstView column_const_view(size_t) const
General view into utility::Matrix.
Definition: MatrixLookup.h:70
virtual ~KNN()
Definition: KNN.h:255
Class for containing sample labels.
Definition: Target.h:47
BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
Function that performs the concept test.
Definition: KNN.h:221
Nearest Neighbor Classifier.
Definition: KNN.h:70
The Department of Theoretical Physics namespace as we define it.
This is the yat interface to gsl_vector_view.
Definition: VectorView.h:79
void predict(const MatrixLookup &data, utility::Matrix &results) const
Make predictions for unweighted test data.
Definition: KNN.h:355
void train(const MatrixLookup &data, const Target &targets)
Train the KNN using unweighted training data with known targets.
Definition: KNN.h:323
KNN(void)
Default constructor.
Definition: KNN.h:237
void resize(size_t r, size_t c, double init_value=0)
Resize Matrix.
void sort_smallest_index(std::vector< size_t > &sort_index, size_t k, const VectorBase &invec)
Interface class for supervised classifiers that use data in a matrix format.
Definition: SupervisedClassifier.h:56
unsigned int k() const
Get the number of nearest neighbors.
Definition: KNN.h:297
Read-only view.
Definition: VectorConstView.h:56
Class used for all runtime errors detected within the yat library.
Definition: Exception.h:38
This is the yat interface to GSL vector.
Definition: Vector.h:59
This is the yat interface to GSL vector.
Definition: VectorBase.h:55
General view into utility::MatrixWeighted.
Definition: MatrixLookupWeighted.h:63
KNN< Distance, NeighborWeighting > * make_classifier(void) const
Create an untrained copy of the classifier.
Definition: KNN.h:311
VectorView column_view(size_t i)
This is the mutable interface to GSL vector.
Definition: VectorMutable.h:56
const_column_iterator end_column(size_t) const
const_column_iterator begin_column(size_t) const
const_column_iterator end_column(size_t) const
Concept check for a Distance.
Definition: concept_check.h:290
Interface to GSL matrix.
Definition: Matrix.h:74
const_column_iterator begin_column(size_t) const
Concept check for a Neighbor Weighting Method.
Definition: KNN.h:214
size_t nof_classes(void) const
size_t columns(void) const

Generated on Wed Jul 17 2019 02:25:31 for yat by  doxygen 1.8.11