test/ensemble.cc

Code
Comments
Other
Rev Date Author Line
518 21 Feb 06 peter 1 // $Id$
518 21 Feb 06 peter 2
675 10 Oct 06 jari 3 /*
2119 12 Dec 09 peter 4   Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 5   Copyright (C) 2007 Peter Johansson
4359 23 Aug 23 peter 6   Copyright (C) 2008 Jari Häkkinen, Peter Johansson
4359 23 Aug 23 peter 7   Copyright (C) 2009, 2012 Peter Johansson
518 21 Feb 06 peter 8
1437 25 Aug 08 peter 9   This file is part of the yat library, http://dev.thep.lu.se/yat
675 10 Oct 06 jari 10
675 10 Oct 06 jari 11   The yat library is free software; you can redistribute it and/or
675 10 Oct 06 jari 12   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 13   published by the Free Software Foundation; either version 3 of the
675 10 Oct 06 jari 14   License, or (at your option) any later version.
675 10 Oct 06 jari 15
675 10 Oct 06 jari 16   The yat library is distributed in the hope that it will be useful,
675 10 Oct 06 jari 17   but WITHOUT ANY WARRANTY; without even the implied warranty of
675 10 Oct 06 jari 18   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
675 10 Oct 06 jari 19   General Public License for more details.
675 10 Oct 06 jari 20
675 10 Oct 06 jari 21   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 22   along with yat. If not, see <http://www.gnu.org/licenses/>.
675 10 Oct 06 jari 23 */
675 10 Oct 06 jari 24
2881 18 Nov 12 peter 25 #include <config.h>
2881 18 Nov 12 peter 26
1236 15 Mar 08 peter 27 #include "Suite.h"
1236 15 Mar 08 peter 28
1121 22 Feb 08 peter 29 #include "yat/utility/Matrix.h"
675 10 Oct 06 jari 30 #include "yat/classifier/SubsetGenerator.h"
675 10 Oct 06 jari 31 #include "yat/classifier/CrossValidationSampler.h"
675 10 Oct 06 jari 32 #include "yat/classifier/EnsembleBuilder.h"
675 10 Oct 06 jari 33 #include "yat/classifier/Kernel.h"
675 10 Oct 06 jari 34 #include "yat/classifier/KernelLookup.h"
675 10 Oct 06 jari 35 #include "yat/classifier/Kernel_SEV.h"
675 10 Oct 06 jari 36 #include "yat/classifier/Kernel_MEV.h"
675 10 Oct 06 jari 37 #include "yat/classifier/MatrixLookup.h"
1206 05 Mar 08 peter 38 #include "yat/classifier/MatrixLookupWeighted.h"
675 10 Oct 06 jari 39 #include "yat/classifier/NCC.h"
675 10 Oct 06 jari 40 #include "yat/classifier/PolynomialKernelFunction.h"
675 10 Oct 06 jari 41 #include "yat/classifier/SVM.h"
824 19 Mar 07 peter 42 #include "yat/statistics/AUC.h"
1100 18 Feb 08 peter 43 #include "yat/statistics/EuclideanDistance.h"
675 10 Oct 06 jari 44
518 21 Feb 06 peter 45 #include <cassert>
518 21 Feb 06 peter 46 #include <fstream>
518 21 Feb 06 peter 47 #include <iostream>
518 21 Feb 06 peter 48 #include <cstdlib>
518 21 Feb 06 peter 49 #include <limits>
518 21 Feb 06 peter 50
518 21 Feb 06 peter 51
1236 15 Mar 08 peter 52 int main(int argc, char* argv[])
4200 19 Aug 22 peter 53 {
680 11 Oct 06 jari 54   using namespace theplu::yat;
1236 15 Mar 08 peter 55   test::Suite suite(argc, argv);
4200 19 Aug 22 peter 56
1236 15 Mar 08 peter 57   suite.err() << "testing ensemble" << std::endl;
518 21 Feb 06 peter 58
1236 15 Mar 08 peter 59   suite.err() << "loading data" << std::endl;
1251 03 Apr 08 peter 60   std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
1121 22 Feb 08 peter 61   utility::Matrix data_core(is);
518 21 Feb 06 peter 62   is.close();
518 21 Feb 06 peter 63
1236 15 Mar 08 peter 64   suite.err() << "create MatrixLookup" << std::endl;
527 01 Mar 06 peter 65   classifier::MatrixLookup data(data_core);
4200 19 Aug 22 peter 66   classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction();
1236 15 Mar 08 peter 67   suite.err() << "Building kernel" << std::endl;
518 21 Feb 06 peter 68   classifier::Kernel_SEV kernel(data,*kf);
518 21 Feb 06 peter 69
518 21 Feb 06 peter 70
1236 15 Mar 08 peter 71   suite.err() << "load target" << std::endl;
1251 03 Apr 08 peter 72   is.open(test::filename("data/nm_target_bin.txt").c_str());
518 21 Feb 06 peter 73    classifier::Target target(is);
518 21 Feb 06 peter 74    is.close();
619 04 Sep 06 peter 75   assert(data.columns()==target.size());
518 21 Feb 06 peter 76
1100 18 Feb 08 peter 77   {
1236 15 Mar 08 peter 78     suite.err() << "create ensemble of ncc" << std::endl;
1157 26 Feb 08 markus 79     classifier::NCC<statistics::EuclideanDistance> ncc;
1100 18 Feb 08 peter 80     classifier::CrossValidationSampler sampler(target,3,3);
1100 18 Feb 08 peter 81     classifier::SubsetGenerator<classifier::MatrixLookup> subdata(sampler,data);
1100 18 Feb 08 peter 82     classifier::EnsembleBuilder<classifier::SupervisedClassifier,
1100 18 Feb 08 peter 83       classifier::MatrixLookup> ensemble(ncc, data, sampler);
1236 15 Mar 08 peter 84     suite.err() << "build ensemble" << std::endl;
1100 18 Feb 08 peter 85     ensemble.build();
1206 05 Mar 08 peter 86     std::vector<std::vector<statistics::Averager> > result;
1206 05 Mar 08 peter 87     ensemble.predict(data, result);
1100 18 Feb 08 peter 88   }
1100 18 Feb 08 peter 89
1206 05 Mar 08 peter 90   {
1236 15 Mar 08 peter 91     suite.err() << "create ensemble of ncc" << std::endl;
1206 05 Mar 08 peter 92     classifier::MatrixLookupWeighted data_weighted(data);
1206 05 Mar 08 peter 93     classifier::NCC<statistics::EuclideanDistance> ncc;
1206 05 Mar 08 peter 94     classifier::CrossValidationSampler sampler(target,3,3);
4200 19 Aug 22 peter 95     classifier::SubsetGenerator<classifier::MatrixLookupWeighted>
1206 05 Mar 08 peter 96       subdata(sampler,data_weighted);
1206 05 Mar 08 peter 97     classifier::EnsembleBuilder<classifier::SupervisedClassifier,
1206 05 Mar 08 peter 98       classifier::MatrixLookupWeighted> ensemble(ncc, data_weighted, sampler);
1236 15 Mar 08 peter 99     suite.err() << "build ensemble" << std::endl;
1206 05 Mar 08 peter 100     ensemble.build();
1206 05 Mar 08 peter 101     std::vector<std::vector<statistics::Averager> > result;
1206 05 Mar 08 peter 102     ensemble.predict(data_weighted, result);
1206 05 Mar 08 peter 103   }
1206 05 Mar 08 peter 104
1236 15 Mar 08 peter 105   suite.err() << "create KernelLookup" << std::endl;
518 21 Feb 06 peter 106   classifier::KernelLookup kernel_lookup(kernel);
1236 15 Mar 08 peter 107   suite.err() << "create svm" << std::endl;
1161 26 Feb 08 peter 108    classifier::SVM svm;
1236 15 Mar 08 peter 109   suite.err() << "create Subsets" << std::endl;
615 31 Aug 06 peter 110   classifier::CrossValidationSampler sampler(target,3,3);
1087 14 Feb 08 peter 111   classifier::SubsetGenerator<classifier::KernelLookup> cv(sampler,
1072 12 Feb 08 peter 112                                                            kernel_lookup);
1161 26 Feb 08 peter 113
1236 15 Mar 08 peter 114   suite.err() << "create ensemble" << std::endl;
4200 19 Aug 22 peter 115   classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup>
1087 14 Feb 08 peter 116     ensemble(svm, kernel_lookup, sampler);
1236 15 Mar 08 peter 117   suite.err() << "build ensemble" << std::endl;
619 04 Sep 06 peter 118   ensemble.build();
1161 26 Feb 08 peter 119   utility::Vector out(target.size(),0);
2138 24 Dec 09 peter 120   for (size_t i = 0; i<out.size(); ++i) {
4200 19 Aug 22 peter 121     out(i)=ensemble.validate()[0][i].mean();
2138 24 Dec 09 peter 122   }
824 19 Mar 07 peter 123   statistics::AUC roc;
1236 15 Mar 08 peter 124   suite.err() << roc.score(target,out) << std::endl;
1161 26 Feb 08 peter 125
2138 24 Dec 09 peter 126   std::vector<std::vector<statistics::Averager> > result;
2138 24 Dec 09 peter 127   ensemble.predict(kernel_lookup, result);
2138 24 Dec 09 peter 128   for (size_t i = 0; i<result.size(); ++i) {
2138 24 Dec 09 peter 129     for (size_t j=0; j<result[0].size(); ++j) {
2138 24 Dec 09 peter 130       if (!suite.add(result[i][j].variance() > 0)) {
2138 24 Dec 09 peter 131         suite.err() << "error: element " << i << " " << j << "\n";
2138 24 Dec 09 peter 132         suite.err() << "expected finite prediction varince\n";
2138 24 Dec 09 peter 133         suite.err() << "found: " << result[i][j].variance() << "\n";
2138 24 Dec 09 peter 134       }
4200 19 Aug 22 peter 135     }
2138 24 Dec 09 peter 136   }
2138 24 Dec 09 peter 137
1227 13 Mar 08 peter 138   {
2138 24 Dec 09 peter 139     suite.err() << "test ensemble of SVMs with weighted kernel" << std::endl;
2138 24 Dec 09 peter 140     classifier::MatrixLookupWeighted wdata(data_core);
2138 24 Dec 09 peter 141     classifier::Kernel_SEV kernel(wdata, *kf);
2138 24 Dec 09 peter 142     classifier::KernelLookup wkl(kernel);
4200 19 Aug 22 peter 143     classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup>
2138 24 Dec 09 peter 144       ensemble(svm, wkl, sampler);
2138 24 Dec 09 peter 145     suite.err() << "build ensemble" << std::endl;
2138 24 Dec 09 peter 146     ensemble.build();
4200 19 Aug 22 peter 147     ensemble.validate();
2138 24 Dec 09 peter 148     std::vector<std::vector<statistics::Averager> > result;
2138 24 Dec 09 peter 149     ensemble.predict(wkl, result);
2138 24 Dec 09 peter 150   }
2138 24 Dec 09 peter 151
2138 24 Dec 09 peter 152   {
1236 15 Mar 08 peter 153     suite.err() << "create ensemble" << std::endl;
4200 19 Aug 22 peter 154     classifier::EnsembleBuilder<classifier::SVM, classifier::KernelLookup>
1227 13 Mar 08 peter 155       ensemble(svm, kernel_lookup, sampler);
1236 15 Mar 08 peter 156     suite.err() << "test validate() before build()\n";
1227 13 Mar 08 peter 157     ensemble.validate();
1227 13 Mar 08 peter 158     std::vector<std::vector<statistics::Averager> > result;
1236 15 Mar 08 peter 159     suite.err() << "test predict() before build()\n";
1227 13 Mar 08 peter 160     ensemble.predict(kernel_lookup, result);
1227 13 Mar 08 peter 161   }
518 21 Feb 06 peter 162   delete kf;
518 21 Feb 06 peter 163
1236 15 Mar 08 peter 164   return suite.return_value();
518 21 Feb 06 peter 165 }