yat/classifier/SvmMultiClass.cc

Code
Comments
Other
Rev Date Author Line
1855 07 Mar 09 peter 1 // $Id$
1855 07 Mar 09 peter 2
1855 07 Mar 09 peter 3 /*
4359 23 Aug 23 peter 4   Copyright (C) 2009, 2012 Peter Johansson
1855 07 Mar 09 peter 5
1855 07 Mar 09 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
1855 07 Mar 09 peter 7
1855 07 Mar 09 peter 8   The yat library is free software; you can redistribute it and/or
1855 07 Mar 09 peter 9   modify it under the terms of the GNU General Public License as
1855 07 Mar 09 peter 10   published by the Free Software Foundation; either version 3 of the
1855 07 Mar 09 peter 11   License, or (at your option) any later version.
1855 07 Mar 09 peter 12
1855 07 Mar 09 peter 13   The yat library is distributed in the hope that it will be useful,
1855 07 Mar 09 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
1855 07 Mar 09 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
1855 07 Mar 09 peter 16   General Public License for more details.
1855 07 Mar 09 peter 17
1855 07 Mar 09 peter 18   You should have received a copy of the GNU General Public License
1855 07 Mar 09 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
1855 07 Mar 09 peter 20 */
1855 07 Mar 09 peter 21
2881 18 Nov 12 peter 22 #include <config.h>
2881 18 Nov 12 peter 23
1855 07 Mar 09 peter 24 #include "SvmMultiClass.h"
1855 07 Mar 09 peter 25 #include "KernelLookup.h"
1855 07 Mar 09 peter 26
1855 07 Mar 09 peter 27 #include "yat/utility/Matrix.h"
1855 07 Mar 09 peter 28 #include "yat/utility/VectorConstView.h"
1855 07 Mar 09 peter 29 #include "yat/utility/VectorView.h"
1855 07 Mar 09 peter 30
1855 07 Mar 09 peter 31 #include <cassert>
2055 08 Sep 09 peter 32 #include <limits>
1855 07 Mar 09 peter 33
1855 07 Mar 09 peter 34 namespace theplu {
1855 07 Mar 09 peter 35 namespace yat {
4200 19 Aug 22 peter 36 namespace classifier {
1855 07 Mar 09 peter 37
1855 07 Mar 09 peter 38   SvmMultiClass::SvmMultiClass(void)
1855 07 Mar 09 peter 39   {}
1855 07 Mar 09 peter 40
1855 07 Mar 09 peter 41
1861 12 Mar 09 peter 42   double SvmMultiClass::C(void) const
1861 12 Mar 09 peter 43   {
1861 12 Mar 09 peter 44     return prototype_.C();
1861 12 Mar 09 peter 45   }
1861 12 Mar 09 peter 46
1861 12 Mar 09 peter 47
1855 07 Mar 09 peter 48   SvmMultiClass* SvmMultiClass::make_classifier(void) const
1855 07 Mar 09 peter 49   {
1861 12 Mar 09 peter 50     return new SvmMultiClass(*this);
1855 07 Mar 09 peter 51   }
1855 07 Mar 09 peter 52
1855 07 Mar 09 peter 53
1861 12 Mar 09 peter 54   unsigned long int SvmMultiClass::max_epochs(void) const
1861 12 Mar 09 peter 55   {
1861 12 Mar 09 peter 56     return prototype_.max_epochs();
1861 12 Mar 09 peter 57   }
1861 12 Mar 09 peter 58
4200 19 Aug 22 peter 59
1861 12 Mar 09 peter 60   void SvmMultiClass::max_epochs(unsigned long int n)
1861 12 Mar 09 peter 61   {
1861 12 Mar 09 peter 62     prototype_.max_epochs(n);
1861 12 Mar 09 peter 63   }
1861 12 Mar 09 peter 64
1861 12 Mar 09 peter 65
4200 19 Aug 22 peter 66   void SvmMultiClass::predict(const KernelLookup& input,
1855 07 Mar 09 peter 67                               utility::Matrix& prediction) const
1855 07 Mar 09 peter 68   {
1855 07 Mar 09 peter 69     assert(svm_.size());
1855 07 Mar 09 peter 70     prediction.resize(svm_.size(),input.columns(),
1855 07 Mar 09 peter 71                       std::numeric_limits<double>::quiet_NaN());
1855 07 Mar 09 peter 72     for (size_t i=0; i<svm_.size(); ++i) {
1865 16 Mar 09 peter 73       if (svm_[i].trained()) {
1865 16 Mar 09 peter 74         yat::utility::Matrix tmp;
1865 16 Mar 09 peter 75         svm_[i].predict(input, tmp);
1865 16 Mar 09 peter 76         prediction.row_view(i) = tmp.row_const_view(0);
1865 16 Mar 09 peter 77       }
1855 07 Mar 09 peter 78     }
1855 07 Mar 09 peter 79
1855 07 Mar 09 peter 80   }
1855 07 Mar 09 peter 81
1855 07 Mar 09 peter 82
1861 12 Mar 09 peter 83   void SvmMultiClass::set_C(double c)
1861 12 Mar 09 peter 84   {
1861 12 Mar 09 peter 85     prototype_.set_C(c);
1861 12 Mar 09 peter 86   }
1861 12 Mar 09 peter 87
1861 12 Mar 09 peter 88
4200 19 Aug 22 peter 89   void SvmMultiClass::train(const KernelLookup& kernel, const Target& targ)
1855 07 Mar 09 peter 90   {
1861 12 Mar 09 peter 91     svm_.clear();
1861 12 Mar 09 peter 92     svm_.resize(targ.nof_classes(), prototype_);
1855 07 Mar 09 peter 93     Target target(targ);
1855 07 Mar 09 peter 94     for (size_t i=0; i<target.nof_classes(); ++i)
1855 07 Mar 09 peter 95       target.set_binary(i, false);
1855 07 Mar 09 peter 96
1855 07 Mar 09 peter 97     for (size_t i=0; i<target.nof_classes(); ++i) {
1860 08 Mar 09 peter 98       if (target.size(i)) {
1860 08 Mar 09 peter 99         target.set_binary(i, true);
1860 08 Mar 09 peter 100         svm_[i].train(kernel, target);
1860 08 Mar 09 peter 101         target.set_binary(i, false);
1860 08 Mar 09 peter 102       }
1860 08 Mar 09 peter 103       else {
1860 08 Mar 09 peter 104         svm_[i].reset(); // don't train empty class
1860 08 Mar 09 peter 105       }
4200 19 Aug 22 peter 106     }
1855 07 Mar 09 peter 107   }
1855 07 Mar 09 peter 108
1855 07 Mar 09 peter 109 }}} // of namespace classifier, yat, and theplu