test/svm_multi_class.cc

Code
Comments
Other
Rev Date Author Line
1861 12 Mar 09 peter 1 // $Id$
1861 12 Mar 09 peter 2
1861 12 Mar 09 peter 3 /*
4359 23 Aug 23 peter 4   Copyright (C) 2009, 2012 Peter Johansson
1861 12 Mar 09 peter 5
1861 12 Mar 09 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
1861 12 Mar 09 peter 7
1861 12 Mar 09 peter 8   The yat library is free software; you can redistribute it and/or
1861 12 Mar 09 peter 9   modify it under the terms of the GNU General Public License as
1861 12 Mar 09 peter 10   published by the Free Software Foundation; either version 3 of the
1861 12 Mar 09 peter 11   License, or (at your option) any later version.
1861 12 Mar 09 peter 12
1861 12 Mar 09 peter 13   The yat library is distributed in the hope that it will be useful,
1861 12 Mar 09 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
1861 12 Mar 09 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
1861 12 Mar 09 peter 16   General Public License for more details.
1861 12 Mar 09 peter 17
1861 12 Mar 09 peter 18   You should have received a copy of the GNU General Public License
1861 12 Mar 09 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
1861 12 Mar 09 peter 20 */
1861 12 Mar 09 peter 21
2881 18 Nov 12 peter 22 #include <config.h>
2881 18 Nov 12 peter 23
1861 12 Mar 09 peter 24 #include "Suite.h"
1861 12 Mar 09 peter 25
1864 15 Mar 09 peter 26 #include "yat/classifier/Kernel_SEV.h"
1864 15 Mar 09 peter 27 #include "yat/classifier/KernelLookup.h"
1864 15 Mar 09 peter 28 #include "yat/classifier/MatrixLookupWeighted.h"
1864 15 Mar 09 peter 29 #include "yat/classifier/PolynomialKernelFunction.h"
1861 12 Mar 09 peter 30 #include "yat/classifier/SvmMultiClass.h"
1861 12 Mar 09 peter 31
1864 15 Mar 09 peter 32 #include <fstream>
1864 15 Mar 09 peter 33 #include <string>
1864 15 Mar 09 peter 34
1861 12 Mar 09 peter 35 using namespace theplu::yat;
1861 12 Mar 09 peter 36
1861 12 Mar 09 peter 37 void test_construction(test::Suite&);
1864 15 Mar 09 peter 38 void test_predict(test::Suite&);
1865 16 Mar 09 peter 39 void test_predict2(test::Suite&);
1861 12 Mar 09 peter 40
1861 12 Mar 09 peter 41 int main( int argc, char* argv[])
4200 19 Aug 22 peter 42 {
1861 12 Mar 09 peter 43   test::Suite suite(argc, argv);
1861 12 Mar 09 peter 44   suite.err() << "testing SvmMultiClass" << std::endl;
1861 12 Mar 09 peter 45   test_construction(suite);
1864 15 Mar 09 peter 46   test_predict(suite);
1865 16 Mar 09 peter 47   test_predict2(suite);
1861 12 Mar 09 peter 48
1861 12 Mar 09 peter 49   return suite.return_value();
1861 12 Mar 09 peter 50 }
1861 12 Mar 09 peter 51
1861 12 Mar 09 peter 52 void test_construction(test::Suite& suite)
1861 12 Mar 09 peter 53 {
1861 12 Mar 09 peter 54   using classifier::SvmMultiClass;
1861 12 Mar 09 peter 55   SvmMultiClass svm;
1861 12 Mar 09 peter 56   svm.set_C(100);
1861 12 Mar 09 peter 57   suite.add(suite.equal(svm.C(), 100));
1861 12 Mar 09 peter 58   svm.max_epochs(1000);
1861 12 Mar 09 peter 59   suite.add(svm.max_epochs()==1000);
1861 12 Mar 09 peter 60   SvmMultiClass svm2(svm);
1862 12 Mar 09 peter 61   suite.add(suite.equal(svm.C(),svm2.C()));
1861 12 Mar 09 peter 62   suite.add(svm.max_epochs()==svm2.max_epochs());
1861 12 Mar 09 peter 63   SvmMultiClass* svm3 = svm2.make_classifier();
1862 12 Mar 09 peter 64   suite.add(suite.equal(svm3->C(),svm2.C()));
1861 12 Mar 09 peter 65   suite.add(svm3->max_epochs()==svm2.max_epochs());
1861 12 Mar 09 peter 66   delete svm3;
1861 12 Mar 09 peter 67 }
1864 15 Mar 09 peter 68
1864 15 Mar 09 peter 69
1864 15 Mar 09 peter 70 void test_predict(test::Suite& suite)
1864 15 Mar 09 peter 71 {
1864 15 Mar 09 peter 72   using namespace classifier;
1864 15 Mar 09 peter 73   std::string file = test::filename("data/sorlie_centroid_data.txt");
1864 15 Mar 09 peter 74   std::ifstream is(file.c_str());
1864 15 Mar 09 peter 75   suite.err() << "load data `" << file << "'" << std::endl;
1864 15 Mar 09 peter 76   MatrixLookupWeighted data(is, '\t');
1864 15 Mar 09 peter 77   is.close();
1864 15 Mar 09 peter 78   PolynomialKernelFunction linear;
1864 15 Mar 09 peter 79   suite.err() << "calculating kernel" << std::endl;
1864 15 Mar 09 peter 80   Kernel_SEV kernel_raw(data, linear);
1864 15 Mar 09 peter 81   KernelLookup kernel(kernel_raw);
1864 15 Mar 09 peter 82   file = test::filename("data/sorlie_centroid_classes.txt");
1864 15 Mar 09 peter 83   suite.err() << "load classes `" << file << "'" << std::endl;
1864 15 Mar 09 peter 84   is.open(file.c_str());
1864 15 Mar 09 peter 85   Target target(is);
1864 15 Mar 09 peter 86   is.close();
1864 15 Mar 09 peter 87
1864 15 Mar 09 peter 88   SvmMultiClass svm;
1864 15 Mar 09 peter 89   suite.err() << "training svm" << std::endl;
1864 15 Mar 09 peter 90   svm.train(kernel, target);
1864 15 Mar 09 peter 91
1864 15 Mar 09 peter 92   utility::Matrix result;
1864 15 Mar 09 peter 93   svm.predict(kernel, result);
1864 15 Mar 09 peter 94
1864 15 Mar 09 peter 95   if (!suite.add(result.rows()==5 && result.columns()==79)) {
1864 15 Mar 09 peter 96     suite.err() << "ERROR: incorrect dimension in result Matrix\n"
1864 15 Mar 09 peter 97                 << "found " << result.rows() << "x" << result.columns() << "\n"
1864 15 Mar 09 peter 98                 << "expected 5x79\n";
1864 15 Mar 09 peter 99   }
1864 15 Mar 09 peter 100
1864 15 Mar 09 peter 101   // we expect perfect predictions on training data
1864 15 Mar 09 peter 102   for (size_t i=0; i<79; ++i) {
1864 15 Mar 09 peter 103     for (size_t j=0; j<5; ++j) {
1864 15 Mar 09 peter 104       if (target(i)==j && result(j,i)<0) {
1864 15 Mar 09 peter 105         suite.err() << "result(" << j << "," << i << ") is "
1864 15 Mar 09 peter 106                     << result(j,i) << " expected greater than 0" << std::endl;
1864 15 Mar 09 peter 107         suite.add(false);
1864 15 Mar 09 peter 108       }
1864 15 Mar 09 peter 109       else if (target(i)!=j && result(j,i)>0) {
1864 15 Mar 09 peter 110         suite.err() << "result(" << j << "," << i << ") is "
1864 15 Mar 09 peter 111                     << result(j,i) << " expected smaller than 0" << std::endl;
1864 15 Mar 09 peter 112         suite.add(false);
1864 15 Mar 09 peter 113       }
1864 15 Mar 09 peter 114     }
1864 15 Mar 09 peter 115   }
1864 15 Mar 09 peter 116 }
1865 16 Mar 09 peter 117
1865 16 Mar 09 peter 118 void test_predict2(test::Suite& suite)
1865 16 Mar 09 peter 119 {
1865 16 Mar 09 peter 120   using namespace classifier;
1865 16 Mar 09 peter 121   std::string file = test::filename("data/sorlie_centroid_data.txt");
1865 16 Mar 09 peter 122   std::ifstream is(file.c_str());
1865 16 Mar 09 peter 123   suite.err() << "load data `" << file << "'" << std::endl;
1865 16 Mar 09 peter 124   MatrixLookupWeighted data(is, '\t');
1865 16 Mar 09 peter 125   is.close();
1865 16 Mar 09 peter 126   PolynomialKernelFunction linear;
1865 16 Mar 09 peter 127   suite.err() << "calculating kernel" << std::endl;
1865 16 Mar 09 peter 128   Kernel_SEV kernel_raw(data, linear);
1865 16 Mar 09 peter 129   file = test::filename("data/sorlie_centroid_classes.txt");
1865 16 Mar 09 peter 130   suite.err() << "load classes `" << file << "'" << std::endl;
1865 16 Mar 09 peter 131   is.open(file.c_str());
1865 16 Mar 09 peter 132   Target target(is);
1865 16 Mar 09 peter 133   is.close();
1865 16 Mar 09 peter 134
1865 16 Mar 09 peter 135   std::vector<size_t> index;
1865 16 Mar 09 peter 136   for (size_t i=0; i<50; ++i)
1865 16 Mar 09 peter 137     index.push_back(i);
1865 16 Mar 09 peter 138   for (size_t i=70; i<79; ++i)
1865 16 Mar 09 peter 139     index.push_back(i);
1865 16 Mar 09 peter 140   utility::Index train_index(index);
4200 19 Aug 22 peter 141
1865 16 Mar 09 peter 142   Target target_train(target, train_index);
1865 16 Mar 09 peter 143   KernelLookup kernel_train(kernel_raw, train_index, train_index);
1865 16 Mar 09 peter 144   SvmMultiClass svm;
1865 16 Mar 09 peter 145   suite.err() << "training svm" << std::endl;
1865 16 Mar 09 peter 146   svm.train(kernel_train, target_train);
1865 16 Mar 09 peter 147
1865 16 Mar 09 peter 148   index.clear();
1865 16 Mar 09 peter 149   for (size_t i=50; i<70; ++i)
1865 16 Mar 09 peter 150     index.push_back(i);
1865 16 Mar 09 peter 151   utility::Index test_index(index);
1865 16 Mar 09 peter 152
1865 16 Mar 09 peter 153   KernelLookup kernel_test(kernel_raw, train_index, test_index);
1865 16 Mar 09 peter 154   utility::Matrix result;
1865 16 Mar 09 peter 155   suite.err() << "Predicting on test data" << std::endl;
1865 16 Mar 09 peter 156   svm.predict(kernel_test, result);
1865 16 Mar 09 peter 157
1865 16 Mar 09 peter 158   if (!suite.add(result.rows()==5 && result.columns()==20)) {
1865 16 Mar 09 peter 159     suite.err() << "ERROR: incorrect dimension in result Matrix\n"
1865 16 Mar 09 peter 160                 << "found " << result.rows() << "x" << result.columns() << "\n"
1865 16 Mar 09 peter 161                 << "expected 5x79\n";
1865 16 Mar 09 peter 162   }
1865 16 Mar 09 peter 163   if (!suite.add(std::isnan(result(3, 0))) ) {
1865 16 Mar 09 peter 164     suite.err() << "ERROR: expected result(4,0) to be nan\n"
1865 16 Mar 09 peter 165                 << "  found " << result(4,0) << std::endl;
1865 16 Mar 09 peter 166   }
1865 16 Mar 09 peter 167
1865 16 Mar 09 peter 168 }