yat
0.8.3pre
|
00001 #ifndef _theplu_yat_classifier_ncc_ 00002 #define _theplu_yat_classifier_ncc_ 00003 00004 // $Id: NCC.h 2384 2010-12-22 14:03:36Z peter $ 00005 00006 /* 00007 Copyright (C) 2005 Peter Johansson, Markus Ringnér 00008 Copyright (C) 2006, 2007, 2008 Jari Häkkinen, Peter Johansson, Markus Ringnér 00009 Copyright (C) 2009, 2010 Peter Johansson 00010 00011 This file is part of the yat library, http://dev.thep.lu.se/yat 00012 00013 The yat library is free software; you can redistribute it and/or 00014 modify it under the terms of the GNU General Public License as 00015 published by the Free Software Foundation; either version 3 of the 00016 License, or (at your option) any later version. 00017 00018 The yat library is distributed in the hope that it will be useful, 00019 but WITHOUT ANY WARRANTY; without even the implied warranty of 00020 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU 00021 General Public License for more details. 00022 00023 You should have received a copy of the GNU General Public License 00024 along with yat. If not, see <http://www.gnu.org/licenses/>. 00025 */ 00026 00027 #include "MatrixLookup.h" 00028 #include "MatrixLookupWeighted.h" 00029 #include "SupervisedClassifier.h" 00030 #include "Target.h" 00031 00032 #include "yat/statistics/Averager.h" 00033 #include "yat/statistics/AveragerWeighted.h" 00034 #include "yat/utility/concept_check.h" 00035 #include "yat/utility/Exception.h" 00036 #include "yat/utility/Matrix.h" 00037 #include "yat/utility/MatrixWeighted.h" 00038 #include "yat/utility/Vector.h" 00039 #include "yat/utility/stl_utility.h" 00040 #include "yat/utility/yat_assert.h" 00041 00042 #include <boost/concept_check.hpp> 00043 00044 #include <iterator> 00045 #include <map> 00046 #include <cmath> 00047 00048 namespace theplu { 00049 namespace yat { 00050 namespace classifier { 00051 00052 00064 template <typename Distance> 00065 class NCC : public SupervisedClassifier 00066 { 00067 00068 public: 00074 NCC(void); 00075 00083 NCC(const Distance&); 00084 00085 00089 virtual ~NCC(void); 00090 00096 const utility::Matrix& centroids(void) const; 00097 00098 NCC<Distance>* make_classifier(void) const; 00099 00117 void predict(const MatrixLookup& data, utility::Matrix& results) const; 00118 00140 void predict(const MatrixLookupWeighted& data, utility::Matrix& results) const; 00141 00150 void train(const MatrixLookup& data, const Target& targets); 00151 00152 00165 void train(const MatrixLookupWeighted& data, const Target& targets); 00166 00167 00168 private: 00169 00170 void predict_unweighted(const MatrixLookup&, utility::Matrix&) const; 00171 void predict_weighted(const MatrixLookupWeighted&, utility::Matrix&) const; 00172 00173 utility::Matrix centroids_; 00174 bool centroids_nan_; 00175 Distance distance_; 00176 }; 00177 00178 // templates 00179 00180 template <typename Distance> 00181 NCC<Distance>::NCC() 00182 : SupervisedClassifier(), centroids_nan_(false) 00183 { 00184 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>)); 00185 } 00186 00187 template <typename Distance> 00188 NCC<Distance>::NCC(const Distance& dist) 00189 : SupervisedClassifier(), centroids_nan_(false), distance_(dist) 00190 { 00191 BOOST_CONCEPT_ASSERT((utility::DistanceConcept<Distance>)); 00192 } 00193 00194 00195 template <typename Distance> 00196 NCC<Distance>::~NCC() 00197 { 00198 } 00199 00200 00201 template <typename Distance> 00202 const utility::Matrix& NCC<Distance>::centroids(void) const 00203 { 00204 return centroids_; 00205 } 00206 00207 00208 template <typename Distance> 00209 NCC<Distance>* 00210 NCC<Distance>::make_classifier() const 00211 { 00212 // All private members should be copied here to generate an 00213 // identical but untrained classifier 00214 return new NCC<Distance>(distance_); 00215 } 00216 00217 template <typename Distance> 00218 void NCC<Distance>::train(const MatrixLookup& data, const Target& target) 00219 { 00220 centroids_.resize(data.rows(), target.nof_classes()); 00221 for(size_t i=0; i<data.rows(); i++) { 00222 std::vector<statistics::Averager> class_averager; 00223 class_averager.resize(target.nof_classes()); 00224 for(size_t j=0; j<data.columns(); j++) { 00225 class_averager[target(j)].add(data(i,j)); 00226 } 00227 for(size_t c=0;c<target.nof_classes();c++) { 00228 centroids_(i,c) = class_averager[c].mean(); 00229 } 00230 } 00231 } 00232 00233 00234 template <typename Distance> 00235 void NCC<Distance>::train(const MatrixLookupWeighted& data, const Target& target) 00236 { 00237 centroids_.resize(data.rows(), target.nof_classes()); 00238 for(size_t i=0; i<data.rows(); i++) { 00239 std::vector<statistics::AveragerWeighted> class_averager; 00240 class_averager.resize(target.nof_classes()); 00241 for(size_t j=0; j<data.columns(); j++) 00242 class_averager[target(j)].add(data.data(i,j),data.weight(i,j)); 00243 for(size_t c=0;c<target.nof_classes();c++) { 00244 if(class_averager[c].sum_w()==0) { 00245 centroids_nan_=true; 00246 } 00247 centroids_(i,c) = class_averager[c].mean(); 00248 } 00249 } 00250 } 00251 00252 00253 template <typename Distance> 00254 void NCC<Distance>::predict(const MatrixLookup& test, 00255 utility::Matrix& prediction) const 00256 { 00257 utility::yat_assert<utility::runtime_error> 00258 (centroids_.rows()==test.rows(), 00259 "NCC::predict test data with incorrect number of rows"); 00260 00261 prediction.resize(centroids_.columns(), test.columns()); 00262 00263 // If weighted training data has resulted in NaN in centroids: weighted calculations 00264 if(centroids_nan_) { 00265 predict_weighted(MatrixLookupWeighted(test),prediction); 00266 } 00267 // If unweighted training data: unweighted calculations 00268 else { 00269 predict_unweighted(test,prediction); 00270 } 00271 } 00272 00273 template <typename Distance> 00274 void NCC<Distance>::predict(const MatrixLookupWeighted& test, 00275 utility::Matrix& prediction) const 00276 { 00277 utility::yat_assert<utility::runtime_error> 00278 (centroids_.rows()==test.rows(), 00279 "NCC::predict test data with incorrect number of rows"); 00280 00281 prediction.resize(centroids_.columns(), test.columns()); 00282 predict_weighted(test,prediction); 00283 } 00284 00285 00286 template <typename Distance> 00287 void NCC<Distance>::predict_unweighted(const MatrixLookup& test, 00288 utility::Matrix& prediction) const 00289 { 00290 for(size_t j=0; j<test.columns();j++) 00291 for(size_t k=0; k<centroids_.columns();k++) 00292 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 00293 centroids_.begin_column(k)); 00294 } 00295 00296 template <typename Distance> 00297 void NCC<Distance>::predict_weighted(const MatrixLookupWeighted& test, 00298 utility::Matrix& prediction) const 00299 { 00300 utility::MatrixWeighted weighted_centroids(centroids_); 00301 for(size_t j=0; j<test.columns();j++) 00302 for(size_t k=0; k<centroids_.columns();k++) 00303 prediction(k,j) = distance_(test.begin_column(j), test.end_column(j), 00304 weighted_centroids.begin_column(k)); 00305 } 00306 00307 00308 }}} // of namespace classifier, yat, and theplu 00309 00310 #endif