test/svm.cc

Code
Comments
Other
Rev Date Author Line
31 16 Jan 04 peter 1 // $Id$
31 16 Jan 04 peter 2
675 10 Oct 06 jari 3 /*
2119 12 Dec 09 peter 4   Copyright (C) 2004, 2005 Jari Häkkinen, Peter Johansson
2119 12 Dec 09 peter 5   Copyright (C) 2006 Jari Häkkinen, Peter Johansson, Markus Ringnér
4359 23 Aug 23 peter 6   Copyright (C) 2007 Peter Johansson
4359 23 Aug 23 peter 7   Copyright (C) 2008 Jari Häkkinen, Peter Johansson
4359 23 Aug 23 peter 8   Copyright (C) 2012 Peter Johansson
31 16 Jan 04 peter 9
1437 25 Aug 08 peter 10   This file is part of the yat library, http://dev.thep.lu.se/yat
675 10 Oct 06 jari 11
675 10 Oct 06 jari 12   The yat library is free software; you can redistribute it and/or
675 10 Oct 06 jari 13   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 14   published by the Free Software Foundation; either version 3 of the
675 10 Oct 06 jari 15   License, or (at your option) any later version.
675 10 Oct 06 jari 16
675 10 Oct 06 jari 17   The yat library is distributed in the hope that it will be useful,
675 10 Oct 06 jari 18   but WITHOUT ANY WARRANTY; without even the implied warranty of
675 10 Oct 06 jari 19   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
675 10 Oct 06 jari 20   General Public License for more details.
675 10 Oct 06 jari 21
675 10 Oct 06 jari 22   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 23   along with yat. If not, see <http://www.gnu.org/licenses/>.
675 10 Oct 06 jari 24 */
675 10 Oct 06 jari 25
2881 18 Nov 12 peter 26 #include <config.h>
2881 18 Nov 12 peter 27
1248 19 Mar 08 peter 28 #include "Suite.h"
1248 19 Mar 08 peter 29
675 10 Oct 06 jari 30 #include "yat/classifier/SVM.h"
675 10 Oct 06 jari 31 #include "yat/classifier/Kernel.h"
675 10 Oct 06 jari 32 #include "yat/classifier/KernelLookup.h"
675 10 Oct 06 jari 33 #include "yat/classifier/Kernel_SEV.h"
675 10 Oct 06 jari 34 #include "yat/classifier/Kernel_MEV.h"
1133 23 Feb 08 peter 35 #include "yat/classifier/MatrixLookup.h"
675 10 Oct 06 jari 36 #include "yat/classifier/PolynomialKernelFunction.h"
747 11 Feb 07 peter 37 #include "yat/classifier/Target.h"
1121 22 Feb 08 peter 38 #include "yat/utility/Matrix.h"
1120 21 Feb 08 peter 39 #include "yat/utility/Vector.h"
675 10 Oct 06 jari 40
442 15 Dec 05 jari 41 #include <cassert>
42 26 Feb 04 jari 42 #include <fstream>
31 16 Jan 04 peter 43 #include <iostream>
31 16 Jan 04 peter 44 #include <cstdlib>
442 15 Dec 05 jari 45 #include <limits>
31 16 Jan 04 peter 46
680 11 Oct 06 jari 47 using namespace theplu::yat;
31 16 Jan 04 peter 48
1248 19 Mar 08 peter 49 int main( int argc, char* argv[])
4200 19 Aug 22 peter 50 {
1248 19 Mar 08 peter 51   test::Suite suite(argc, argv);
1248 19 Mar 08 peter 52   suite.err() << "testing svm" << std::endl;
209 03 Nov 04 peter 53
1121 22 Feb 08 peter 54   utility::Matrix data2_core(2,3);
527 01 Mar 06 peter 55   data2_core(0,0)=0;
527 01 Mar 06 peter 56   data2_core(1,0)=0;
527 01 Mar 06 peter 57   data2_core(0,1)=0;
527 01 Mar 06 peter 58   data2_core(1,1)=1;
527 01 Mar 06 peter 59   data2_core(0,2)=1;
527 01 Mar 06 peter 60   data2_core(1,2)=0;
527 01 Mar 06 peter 61   classifier::MatrixLookup data2(data2_core);
509 18 Feb 06 peter 62   std::vector<std::string> label;
509 18 Feb 06 peter 63   label.reserve(3);
509 18 Feb 06 peter 64   label.push_back("-1");
509 18 Feb 06 peter 65   label.push_back("1");
509 18 Feb 06 peter 66   label.push_back("1");
509 18 Feb 06 peter 67   classifier::Target target2(label);
4200 19 Aug 22 peter 68   classifier::KernelFunction* kf2 = new classifier::PolynomialKernelFunction();
453 15 Dec 05 peter 69   classifier::Kernel_MEV kernel2(data2,*kf2);
323 26 May 05 peter 70   assert(kernel2.size()==3);
463 16 Dec 05 peter 71   assert(target2.size()==3);
527 01 Mar 06 peter 72   for (size_t i=0; i<3; i++){
527 01 Mar 06 peter 73     for (size_t j=0; j<3; j++)
1248 19 Mar 08 peter 74       suite.err() << kernel2(i,j) << " ";
1248 19 Mar 08 peter 75     suite.err() << std::endl;
527 01 Mar 06 peter 76   }
475 22 Dec 05 peter 77   classifier::KernelLookup kv2(kernel2);
1248 19 Mar 08 peter 78   suite.err() << "testing with linear kernel" << std::endl;
463 16 Dec 05 peter 79   assert(kv2.rows()==target2.size());
1100 18 Feb 08 peter 80   classifier::SVM classifier2;
1248 19 Mar 08 peter 81   suite.err() << "training...";
1100 18 Feb 08 peter 82   classifier2.train(kv2, target2);
1248 19 Mar 08 peter 83   suite.err() << " done!" << std::endl;
31 16 Jan 04 peter 84
475 22 Dec 05 peter 85   double tmp=0;
4200 19 Aug 22 peter 86   for (size_t i=0; i<target2.size(); i++)
514 20 Feb 06 peter 87     if (target2.binary(i))
509 18 Feb 06 peter 88       tmp += classifier2.alpha()(i);
509 18 Feb 06 peter 89     else
509 18 Feb 06 peter 90       tmp -= classifier2.alpha()(i);
509 18 Feb 06 peter 91
475 22 Dec 05 peter 92   if (tmp){
1248 19 Mar 08 peter 93     suite.err() << "ERROR: found " << tmp << " expected zero" << std::endl;
323 26 May 05 peter 94     return -1;
323 26 May 05 peter 95   }
116 19 Jul 04 peter 96
1672 22 Dec 08 peter 97   // tol defined on learning precision
509 18 Feb 06 peter 98   double tol=1e-6;
1672 22 Dec 08 peter 99   if (!suite.equal_fix(classifier2.alpha()(1), 2.0, tol) ||
1672 22 Dec 08 peter 100       !suite.equal_fix(classifier2.alpha()(2), 2.0, tol) ) {
1248 19 Mar 08 peter 101     suite.err() << "wrong alpha" << std::endl;
1248 19 Mar 08 peter 102     suite.err() << "alpha: " << classifier2.alpha() <<  std::endl;
1248 19 Mar 08 peter 103     suite.err() << "expected: 4 2 2" <<  std::endl;
323 26 May 05 peter 104
323 26 May 05 peter 105     return -1;
323 26 May 05 peter 106   }
323 26 May 05 peter 107
323 26 May 05 peter 108
4200 19 Aug 22 peter 109
1251 03 Apr 08 peter 110   std::ifstream is(test::filename("data/nm_data_centralized.txt").c_str());
1121 22 Feb 08 peter 111   utility::Matrix data_core(is);
323 26 May 05 peter 112   is.close();
323 26 May 05 peter 113
527 01 Mar 06 peter 114   classifier::MatrixLookup data(data_core);
527 01 Mar 06 peter 115
4200 19 Aug 22 peter 116   classifier::KernelFunction* kf = new classifier::PolynomialKernelFunction();
453 15 Dec 05 peter 117   classifier::Kernel_SEV kernel(data,*kf);
323 26 May 05 peter 118
323 26 May 05 peter 119
1251 03 Apr 08 peter 120   is.open(test::filename("data/nm_target_bin.txt").c_str());
475 22 Dec 05 peter 121    classifier::Target target(is);
323 26 May 05 peter 122    is.close();
323 26 May 05 peter 123
1251 03 Apr 08 peter 124    is.open(test::filename("data/nm_alpha_linear_matlab.txt").c_str());
1120 21 Feb 08 peter 125    theplu::yat::utility::Vector alpha_matlab(is);
323 26 May 05 peter 126    is.close();
323 26 May 05 peter 127
475 22 Dec 05 peter 128   classifier::KernelLookup kv(kernel);
1100 18 Feb 08 peter 129    theplu::yat::classifier::SVM svm;
1100 18 Feb 08 peter 130   svm.train(kv, target);
323 26 May 05 peter 131
1120 21 Feb 08 peter 132    theplu::yat::utility::Vector alpha = svm.alpha();
4200 19 Aug 22 peter 133
323 26 May 05 peter 134   // Comparing alpha to alpha_matlab
1672 22 Dec 08 peter 135   if (!suite.equal_range_fix(alpha.begin(), alpha.end(),
1672 22 Dec 08 peter 136                              alpha_matlab.begin(), 1e-6) ) {
1248 19 Mar 08 peter 137     suite.err() << "Difference to matlab alphas too large\n";
1248 19 Mar 08 peter 138      suite.add(false);
323 26 May 05 peter 139   }
323 26 May 05 peter 140
1672 22 Dec 08 peter 141
323 26 May 05 peter 142    // Comparing output to target
1120 21 Feb 08 peter 143   theplu::yat::utility::Vector output(svm.output());
323 26 May 05 peter 144    double slack = 0;
323 26 May 05 peter 145    for (unsigned int i=0; i<target.size(); i++){
475 22 Dec 05 peter 146      if (output(i)*target(i) < 1){
514 20 Feb 06 peter 147       if (target.binary(i))
1672 22 Dec 08 peter 148         slack = 1 - output(i);
509 18 Feb 06 peter 149       else
1672 22 Dec 08 peter 150         slack = 1 + output(i);
1672 22 Dec 08 peter 151       double slack_bound=2e-7;
1672 22 Dec 08 peter 152       if (slack > slack_bound || std::isnan(slack)){
1672 22 Dec 08 peter 153         suite.err() << "Slack too large. Is the bias correct?\n";
1672 22 Dec 08 peter 154         suite.err() << "slack: " << slack << std::endl;
1672 22 Dec 08 peter 155         suite.err() << "expected less than " << slack_bound << std::endl;
1672 22 Dec 08 peter 156         suite.add(false);
1672 22 Dec 08 peter 157       }
323 26 May 05 peter 158      }
323 26 May 05 peter 159    }
4200 19 Aug 22 peter 160
337 03 Jun 05 peter 161   delete kf;
337 03 Jun 05 peter 162   delete kf2;
337 03 Jun 05 peter 163
1248 19 Mar 08 peter 164   return suite.return_value();
31 16 Jan 04 peter 165 }