#ifndef _theplu_yat_classifier_knn_
#define _theplu_yat_classifier_knn_

// $Id$

/*
  Copyright (C) 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér
  Copyright (C) 2009, 2010 Peter Johansson

  This file is part of the yat library, http://dev.thep.lu.se/yat

  The yat library is free software; you can redistribute it and/or
  modify it under the terms of the GNU General Public License as
  published by the Free Software Foundation; either version 3 of the
  License, or (at your option) any later version.

  The yat library is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  General Public License for more details.

  You should have received a copy of the GNU General Public License
  along with yat. If not, see <http://www.gnu.org/licenses/>.
*/

#include "DataLookup1D.h"
#include "DataLookupWeighted1D.h"
#include "KNN_Uniform.h"
#include "MatrixLookup.h"
#include "MatrixLookupWeighted.h"
#include "SupervisedClassifier.h"
#include "Target.h"
#include "yat/utility/concept_check.h"
#include "yat/utility/Exception.h"
#include "yat/utility/Matrix.h"
#include "yat/utility/Vector.h"
#include "yat/utility/VectorConstView.h"
#include "yat/utility/VectorView.h"
#include "yat/utility/yat_assert.h"

#include <boost/concept_check.hpp>

#include <cmath>
#include <limits>
#include <map>
#include <stdexcept>
#include <vector>

namespace theplu {
namespace yat {
namespace classifier {

  /**
     \brief Nearest Neighbor Classifier

     A sample is predicted based on the classes of its k nearest
     neighbors among the training data samples. KNN supports using
     different measures, for example, Euclidean distance, to define the
     distance between samples. KNN also supports different ways to
     weight the votes of the k nearest neighbors. For example, with a
     uniform vote a test sample's vote for each class is simply the
     number of nearest neighbors belonging to that class.

     The template argument Distance should be a class modelling the
     concept \ref concept_distance. The template argument
     NeighborWeighting should be a class modelling the concept \ref
     concept_neighbor_weighting.
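
     A minimal usage sketch is given below. The distance type
     statistics::EuclideanDistance and the MatrixLookup and Target
     constructors used to build the toy data are assumptions for
     illustration only; any class modelling \ref concept_distance can
     be plugged in the same way.

     \code
     utility::Matrix x(2, 4);          // 2 variables (rows), 4 samples (columns)
     // ... fill x with values ...
     MatrixLookup data(x);             // assumed lookup into x

     std::vector<std::string> labels;  // one class label per sample
     labels.push_back("A"); labels.push_back("A");
     labels.push_back("B"); labels.push_back("B");
     Target targets(labels);

     KNN<statistics::EuclideanDistance> knn;   // uniform voting by default
     knn.k(3);                                 // use the 3 nearest neighbors
     knn.train(data, targets);

     utility::Matrix votes;
     knn.predict(data, votes);                 // votes(c, j): votes for class c, sample j
     \endcode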
  */
  template <typename Distance, typename NeighborWeighting=KNN_Uniform>
  class KNN : public SupervisedClassifier
  {

  public:
    /**
       \brief Default constructor.

       The number of nearest neighbors (k) is set to 3. Distance and
       NeighborWeighting are initialized using their default
       constructors.
    */
    KNN(void);


    /**
       \brief Constructor using an initialized distance measure.

       The number of nearest neighbors (k) is set to
       3. NeighborWeighting is initialized using its default
       constructor. This constructor should be used if Distance has
       parameters and the user wants to specify the parameters by
       initializing Distance prior to constructing the KNN.
    */
    KNN(const Distance&);


    /**
       Destructor
    */
    virtual ~KNN();


    /**
       \brief Get the number of nearest neighbors.
       \return The number of neighbors.
    */
    unsigned int k() const;

    /**
       \brief Set the number of nearest neighbors.

       Sets the number of neighbors to \a k_in.
    */
    void k(unsigned int k_in);


    KNN<Distance,NeighborWeighting>* make_classifier(void) const;

    /**
       \brief Make predictions for unweighted test data.

       Predictions are calculated and returned in \a results. For
       each sample in \a data, \a results contains the weighted number
       of nearest neighbors which belong to each class. Numbers of
       nearest neighbors are weighted according to
       NeighborWeighting. If a class has no training samples, NaNs are
       returned for this class in \a results.
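
       A sketch of how \a results might be read out afterwards (the
       loop is illustration only, not part of the interface); the
       predicted class of a sample is the row with the largest vote:

       \code
       for (size_t j = 0; j < results.columns(); ++j) {
         size_t best = 0;
         for (size_t c = 1; c < results.rows(); ++c)
           if (results(c, j) > results(best, j))
             best = c;
         // best is now the predicted class index of sample j
       }
       \endcode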
    */
    void predict(const MatrixLookup& data, utility::Matrix& results) const;

    /**
       \brief Make predictions for weighted test data.

       Predictions are calculated and returned in \a results. For
       each sample in \a data, \a results contains the weighted
       number of nearest neighbors which belong to each class as in
       predict(const MatrixLookup& data, utility::Matrix& results).
       If a test and training sample pair has no variables with
       non-zero weights in common, there are no variables which can
       be used to calculate the distance between the two samples. In
       this case the distance between the two is set to infinity.
    */
    void predict(const MatrixLookupWeighted& data,
                 utility::Matrix& results) const;


    /**
       \brief Train the KNN using unweighted training data with known
       targets.

       For KNN there is no actual training; the entire training data
       set is stored with targets. KNN only stores references to \a data
       and \a targets as copying these would make the %classifier
       slow. If the number of training samples is smaller than k,
       k is set to the number of training samples.

       \note If \a data or \a targets go out of scope or are
       deleted, the KNN becomes invalid and further use is undefined
       unless it is trained again.
    */
    void train(const MatrixLookup& data, const Target& targets);

    /**
       \brief Train the KNN using weighted training data with known targets.

       See train(const MatrixLookup& data, const Target& targets) for
       additional information.
    */
    void train(const MatrixLookupWeighted& data, const Target& targets);

  private:

    const MatrixLookup* data_ml_;
    const MatrixLookupWeighted* data_mlw_;
    const Target* target_;

    // The number of neighbors
    unsigned int k_;

    Distance distance_;
    NeighborWeighting weighting_;

    void calculate_unweighted(const MatrixLookup&,
                              const MatrixLookup&,
                              utility::Matrix*) const;
    void calculate_weighted(const MatrixLookupWeighted&,
                            const MatrixLookupWeighted&,
                            utility::Matrix*) const;

    void predict_common(const utility::Matrix& distances,
                        utility::Matrix& prediction) const;

  };


  /**
     \brief Concept check for a \ref concept_neighbor_weighting

     This class is intended to be used in a <a
     href="\boost_url/concept_check/using_concept_check.htm">
     BOOST_CONCEPT_ASSERT </a>

     \code
     template<class NeighborWeighting>
     void some_function(double x)
     {
       BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
       ...
     }
     \endcode
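
     A class modelling the concept must be default constructible,
     assignable, and callable with the four arguments shown in
     BOOST_CONCEPT_USAGE below. As a sketch only (not a class provided
     by yat, and assuming Target::operator()(i) returns the class of
     sample \c i), an inverse-distance weighting could look like:

     \code
     struct InverseDistanceWeighting
     {
       void operator()(const utility::VectorBase& distance,
                       const std::vector<size_t>& k_sorted,
                       const Target& target,
                       utility::VectorMutable& prediction) const
       {
         // each of the k nearest neighbors votes for its own class,
         // weighted by one over its distance to the test sample
         for (size_t i = 0; i < k_sorted.size(); ++i)
           prediction(target(k_sorted[i])) += 1.0 / distance(k_sorted[i]);
       }
     };
     \endcode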

     \since New in yat 0.7
  */
  template <class T>
  class NeighborWeightingConcept
    : public boost::DefaultConstructible<T>, public boost::Assignable<T>
  {
  public:
    /**
       \brief function doing the concept test
    */
    BOOST_CONCEPT_USAGE(NeighborWeightingConcept)
    {
      T neighbor_weighting;
      utility::Vector vec;
      const utility::VectorBase& distance(vec);
      utility::VectorMutable& prediction(vec);
      std::vector<size_t> k_sorted;
      Target target;
      neighbor_weighting(distance, k_sorted, target, prediction);
    }
  private:
  };

  // template implementation

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN()
    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3)
  {
    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
  }

  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::KNN(const Distance& dist)
    : SupervisedClassifier(), data_ml_(0), data_mlw_(0), target_(0), k_(3),
      distance_(dist)
  {
    BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>));
    BOOST_CONCEPT_ASSERT((NeighborWeightingConcept<NeighborWeighting>));
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>::~KNN()
  {
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::calculate_unweighted
  (const MatrixLookup& training, const MatrixLookup& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        YAT_ASSERT(!std::isnan((*distances)(i,j)));
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::calculate_weighted
  (const MatrixLookupWeighted& training, const MatrixLookupWeighted& test,
   utility::Matrix* distances) const
  {
    for(size_t i=0; i<training.columns(); i++) {
      for(size_t j=0; j<test.columns(); j++) {
        (*distances)(i,j) = distance_(training.begin_column(i),
                                      training.end_column(i),
                                      test.begin_column(j));
        // If the distance is NaN (no common variables with non-zero weights),
        // the distance is set to infinity to be sorted as a neighbor at the end
        if(std::isnan((*distances)(i,j)))
          (*distances)(i,j)=std::numeric_limits<double>::infinity();
      }
    }
  }


  template <typename Distance, typename NeighborWeighting>
  unsigned int KNN<Distance, NeighborWeighting>::k() const
  {
    return k_;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::k(unsigned int k)
  {
    k_=k;
  }


  template <typename Distance, typename NeighborWeighting>
  KNN<Distance, NeighborWeighting>*
  KNN<Distance, NeighborWeighting>::make_classifier() const
  {
    // All private members should be copied here to generate an
    // identical but untrained classifier
    KNN* knn=new KNN<Distance, NeighborWeighting>(distance_);
    knn->weighting_=this->weighting_;
    knn->k(this->k());
    return knn;
  }


  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookup& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=&data;
    data_mlw_=0;
    target_=&target;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::train(const MatrixLookupWeighted& data,
                                               const Target& target)
  {
    utility::yat_assert<utility::runtime_error>
      (data.columns()==target.size(),
       "KNN::train called with different sizes of target and data");
    // k has to be at most the number of training samples.
    if(data.columns()<k_)
      k_=data.columns();
    data_ml_=0;
    data_mlw_=&data;
    target_=&target;
  }


  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookup& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances = 0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<utility::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_unweighted(*data_ml_,test,distances);
    }
    else if (data_mlw_ && !data_ml_) {
      // weighted training data
      utility::yat_assert<utility::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,MatrixLookupWeighted(test),
                         distances);
    }
    else {
      throw utility::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);
    if(distances)
      delete distances;
  }

  template <typename Distance, typename NeighborWeighting>
  void
  KNN<Distance, NeighborWeighting>::predict(const MatrixLookupWeighted& test,
                                            utility::Matrix& prediction) const
  {
    // matrix with training samples as rows and test samples as columns
    utility::Matrix* distances=0;
    // unweighted training data
    if(data_ml_ && !data_mlw_) {
      utility::yat_assert<utility::runtime_error>
        (data_ml_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_ml_->columns(),test.columns());
      calculate_weighted(MatrixLookupWeighted(*data_ml_),test,distances);
    }
    // weighted training data
    else if (data_mlw_ && !data_ml_) {
      utility::yat_assert<utility::runtime_error>
        (data_mlw_->rows()==test.rows(),
         "KNN::predict different number of rows in training and test data");
      distances=new utility::Matrix(data_mlw_->columns(),test.columns());
      calculate_weighted(*data_mlw_,test,distances);
    }
    else {
      throw utility::runtime_error("KNN::predict no training data");
    }

    prediction.resize(target_->nof_classes(),test.columns(),0.0);
    predict_common(*distances,prediction);

    if(distances)
      delete distances;
  }

  template <typename Distance, typename NeighborWeighting>
  void KNN<Distance, NeighborWeighting>::predict_common
  (const utility::Matrix& distances, utility::Matrix& prediction) const
  {
    for(size_t sample=0;sample<distances.columns();sample++) {
      std::vector<size_t> k_index;
      utility::VectorConstView dist=distances.column_const_view(sample);
      utility::sort_smallest_index(k_index,k_,dist);
      utility::VectorView pred=prediction.column_view(sample);
      weighting_(dist,k_index,*target_,pred);
    }

    // classes for which there are no training samples should be set
    // to nan in the predictions
    for(size_t c=0;c<target_->nof_classes(); c++)
      if(!target_->size(c))
        for(size_t j=0;j<prediction.columns();j++)
          prediction(c,j)=std::numeric_limits<double>::quiet_NaN();
  }
}}} // of namespace classifier, yat, and theplu

#endif