00001 #ifndef _theplu_yat_classifier_ncc_
00002 #define _theplu_yat_classifier_ncc_
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 #include "MatrixLookup.h"
00028 #include "MatrixLookupWeighted.h"
00029 #include "SupervisedClassifier.h"
00030 #include "Target.h"
00031
00032 #include "yat/statistics/Averager.h"
00033 #include "yat/statistics/AveragerWeighted.h"
00034 #include "yat/utility/Matrix.h"
00035 #include "yat/utility/MatrixWeighted.h"
00036 #include "yat/utility/Vector.h"
00037 #include "yat/utility/stl_utility.h"
00038 #include "yat/utility/yat_assert.h"
00039
00040 #include <iterator>
00041 #include <map>
00042 #include <cmath>
00043 #include <stdexcept>
00044
00045 namespace theplu {
00046 namespace yat {
00047 namespace classifier {
00048
00049
00061 template <typename Distance>
00062 class NCC : public SupervisedClassifier
00063 {
00064
00065 public:
00071 NCC(void);
00072
00080 NCC(const Distance&);
00081
00082
00086 virtual ~NCC(void);
00087
00093 const utility::Matrix& centroids(void) const;
00094
00095 NCC<Distance>* make_classifier(void) const;
00096
00114 void predict(const MatrixLookup& data, utility::Matrix& results) const;
00115
00137 void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const;
00138
00147 void train(const MatrixLookup& data, const Target& targets);
00148
00149
00162 void train(const MatrixLookupWeighted& data, const Target& targets);
00163
00164
00165 private:
00166
00167 void predict_unweighted(const MatrixLookup&, utility::Matrix&) const;
00168 void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const;
00169
00170 utility::Matrix centroids_;
00171 bool centroids_nan_;
00172 Distance distance_;
00173 };
00174
00175
00176
00177 template <typename Distance>
00178 NCC<Distance>::NCC()
00179 : SupervisedClassifier(), centroids_nan_(false)
00180 {
00181 }
00182
00183 template <typename Distance>
00184 NCC<Distance>::NCC(const Distance& dist)
00185 : SupervisedClassifier(), centroids_nan_(false), distance_(dist)
00186 {
00187 }
00188
00189
00190 template <typename Distance>
00191 NCC<Distance>::~NCC()
00192 {
00193 }
00194
00195
00196 template <typename Distance>
00197 const utility::Matrix& NCC<Distance>::centroids(void) const
00198 {
00199 return centroids_;
00200 }
00201
00202
00203 template <typename Distance>
00204 NCC<Distance>*
00205 NCC<Distance>::make_classifier() const
00206 {
00207
00208
00209 return new NCC<Distance>(distance_);
00210 }
00211
00212 template <typename Distance>
00213 void NCC<Distance>::train(const MatrixLookup& data, const Target& target)
00214 {
00215 centroids_.resize(data.rows(), target.nof_classes());
00216 for(size_t i=0; i<data.rows(); i++) {
00217 std::vector<statistics::Averager> class_averager;
00218 class_averager.resize(target.nof_classes());
00219 for(size_t j=0; j<data.columns(); j++) {
00220 class_averager[target(j)].add(data(i,j));
00221 }
00222 for(size_t c=0;c<target.nof_classes();c++) {
00223 centroids_(i,c) = class_averager[c].mean();
00224 }
00225 }
00226 }
00227
00228
00229 template <typename Distance>
00230 void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target)
00231 {
00232 centroids_.resize(data.rows(), target.nof_classes());
00233 for(size_t i=0; i<data.rows(); i++) {
00234 std::vector<statistics::AveragerWeighted> class_averager;
00235 class_averager.resize(target.nof_classes());
00236 for(size_t j=0; j<data.columns(); j++)
00237 class_averager[target(j)].add(data.data(i,j),data.weight(i,j));
00238 for(size_t c=0;c<target.nof_classes();c++) {
00239 if(class_averager[c].sum_w()==0) {
00240 centroids_nan_=true;
00241 }
00242 centroids_(i,c) = class_averager[c].mean();
00243 }
00244 }
00245 }
00246
00247
00248 template <typename Distance>
00249 void NCC<Distance>::predict(const MatrixLookup& test,
00250 utility::Matrix& prediction) const
00251 {
00252 utility::yat_assert<std::runtime_error>
00253 (centroids_.rows()==test.rows(),
00254 "NCC::predict test data with incorrect number of rows");
00255
00256 prediction.resize(centroids_.columns(), test.columns());
00257
00258
00259 if(centroids_nan_) {
00260 predict_weighted(MatrixLookupWeighted(test),prediction);
00261 }
00262
00263 else {
00264 predict_unweighted(test,prediction);
00265 }
00266 }
00267
00268 template <typename Distance>
00269 void NCC<Distance>::predict(const MatrixLookupWeighted& test,
00270 utility::Matrix& prediction) const
00271 {
00272 utility::yat_assert<std::runtime_error>
00273 (centroids_.rows()==test.rows(),
00274 "NCC::predict test data with incorrect number of rows");
00275
00276 prediction.resize(centroids_.columns(), test.columns());
00277 predict_weighted(test,prediction);
00278 }
00279
00280
00281 template <typename Distance>
00282 void NCC<Distance>::predict_unweighted(const MatrixLookup& test,
00283 utility::Matrix& prediction) const
00284 {
00285 for(size_t j=0; j<test.columns();j++)
00286 for(size_t k=0; k<centroids_.columns();k++)
00287 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
00288 centroids_.begin_column(k));
00289 }
00290
00291 template <typename Distance>
00292 void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test,
00293 utility::Matrix& prediction) const
00294 {
00295 utility::MatrixWeighted weighted_centroids(centroids_);
00296 for(size_t j=0; j<test.columns();j++)
00297 for(size_t k=0; k<centroids_.columns();k++)
00298 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j),
00299 weighted_centroids.begin_column(k));
00300 }
00301
00302
00303 }}}
00304
00305 #endif