yat/classifier/CrossValidationSampler.cc

Code
Comments
Other
Rev Date Author Line
612 30 Aug 06 peter 1 // $Id$
612 30 Aug 06 peter 2
612 30 Aug 06 peter 3 /*
4359 23 Aug 23 peter 4   Copyright (C) 2006 Jari Häkkinen, Peter Johansson
4359 23 Aug 23 peter 5   Copyright (C) 2007 Peter Johansson
4359 23 Aug 23 peter 6   Copyright (C) 2008 Jari Häkkinen, Peter Johansson
4359 23 Aug 23 peter 7   Copyright (C) 2012 Peter Johansson
612 30 Aug 06 peter 8
1437 25 Aug 08 peter 9   This file is part of the yat library, http://dev.thep.lu.se/yat
612 30 Aug 06 peter 10
675 10 Oct 06 jari 11   The yat library is free software; you can redistribute it and/or
675 10 Oct 06 jari 12   modify it under the terms of the GNU General Public License as
1486 09 Sep 08 jari 13   published by the Free Software Foundation; either version 3 of the
675 10 Oct 06 jari 14   License, or (at your option) any later version.
612 30 Aug 06 peter 15
675 10 Oct 06 jari 16   The yat library is distributed in the hope that it will be useful,
675 10 Oct 06 jari 17   but WITHOUT ANY WARRANTY; without even the implied warranty of
675 10 Oct 06 jari 18   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
612 30 Aug 06 peter 19   General Public License for more details.
612 30 Aug 06 peter 20
612 30 Aug 06 peter 21   You should have received a copy of the GNU General Public License
1487 10 Sep 08 jari 22   along with yat. If not, see <http://www.gnu.org/licenses/>.
612 30 Aug 06 peter 23 */
612 30 Aug 06 peter 24
2881 18 Nov 12 peter 25 #include <config.h>
2881 18 Nov 12 peter 26
680 11 Oct 06 jari 27 #include "CrossValidationSampler.h"
680 11 Oct 06 jari 28 #include "Target.h"
675 10 Oct 06 jari 29 #include "yat/random/random.h"
612 30 Aug 06 peter 30
612 30 Aug 06 peter 31 #include <algorithm>
612 30 Aug 06 peter 32 #include <cassert>
612 30 Aug 06 peter 33 #include <utility>
612 30 Aug 06 peter 34 #include <vector>
612 30 Aug 06 peter 35
612 30 Aug 06 peter 36 namespace theplu {
680 11 Oct 06 jari 37 namespace yat {
4200 19 Aug 22 peter 38 namespace classifier {
612 30 Aug 06 peter 39
4200 19 Aug 22 peter 40   CrossValidationSampler::CrossValidationSampler(const Target& target,
4200 19 Aug 22 peter 41                                                  const size_t N,
612 30 Aug 06 peter 42                                                  const size_t k)
823 19 Mar 07 peter 43     : Sampler(target, N), k_(k)
4200 19 Aug 22 peter 44   {
612 30 Aug 06 peter 45     assert(target.size()>1);
612 30 Aug 06 peter 46     build(target, N, k);
612 30 Aug 06 peter 47   }
612 30 Aug 06 peter 48
612 30 Aug 06 peter 49   CrossValidationSampler::~CrossValidationSampler()
612 30 Aug 06 peter 50   {
612 30 Aug 06 peter 51   }
612 30 Aug 06 peter 52
612 30 Aug 06 peter 53   void CrossValidationSampler::build(const Target& target, size_t N, size_t k)
612 30 Aug 06 peter 54   {
612 30 Aug 06 peter 55     std::vector<std::pair<size_t,size_t> > v;
612 30 Aug 06 peter 56     for (size_t i=0; i<target.size(); i++)
612 30 Aug 06 peter 57       v.push_back(std::make_pair(target(i),i));
612 30 Aug 06 peter 58     // sorting with respect to class
612 30 Aug 06 peter 59     std::sort(v.begin(),v.end());
4200 19 Aug 22 peter 60
612 30 Aug 06 peter 61     // my_begin[i] is index of first sample of class i
612 30 Aug 06 peter 62     std::vector<size_t> my_begin;
612 30 Aug 06 peter 63     my_begin.reserve(target.nof_classes());
612 30 Aug 06 peter 64     my_begin.push_back(0);
612 30 Aug 06 peter 65     for (size_t i=1; i<target.size(); i++)
612 30 Aug 06 peter 66       while (v[i].first > my_begin.size()-1)
612 30 Aug 06 peter 67         my_begin.push_back(i);
612 30 Aug 06 peter 68     my_begin.push_back(target.size());
612 30 Aug 06 peter 69
612 30 Aug 06 peter 70     for (size_t i=0; i<N; ) {
612 30 Aug 06 peter 71       // shuffle indices within class each class
1002 14 Jan 08 peter 72       for (size_t j=0; j+1<my_begin.size(); ++j)
1004 23 Jan 08 peter 73         random::random_shuffle(v.begin()+my_begin[j],v.begin()+my_begin[j+1]);
4200 19 Aug 22 peter 74
612 30 Aug 06 peter 75       for (size_t part=0; part<k && i<N; i++, part++) {
612 30 Aug 06 peter 76         std::vector<size_t> training_index;
612 30 Aug 06 peter 77         std::vector<size_t> validation_index;
612 30 Aug 06 peter 78         for (size_t j=0; j<v.size(); j++) {
612 30 Aug 06 peter 79           if (j%k==part)
612 30 Aug 06 peter 80             validation_index.push_back(v[j].second);
612 30 Aug 06 peter 81           else
612 30 Aug 06 peter 82             training_index.push_back(v[j].second);
612 30 Aug 06 peter 83         }
612 30 Aug 06 peter 84
1134 23 Feb 08 peter 85         training_index_.push_back(utility::Index(training_index));
1134 23 Feb 08 peter 86         validation_index_.push_back(utility::Index(validation_index));
612 30 Aug 06 peter 87       }
612 30 Aug 06 peter 88     }
612 30 Aug 06 peter 89     assert(training_index_.size()==N);
612 30 Aug 06 peter 90     assert(validation_index_.size()==N);
4200 19 Aug 22 peter 91
615 31 Aug 06 peter 92     for (size_t i=0; i<N; ++i){
615 31 Aug 06 peter 93       training_target_.push_back(Target(target,training_index_[i]));
615 31 Aug 06 peter 94       validation_target_.push_back(Target(target,validation_index_[i]));
615 31 Aug 06 peter 95     }
615 31 Aug 06 peter 96     assert(training_target_.size()==N);
615 31 Aug 06 peter 97     assert(validation_target_.size()==N);
612 30 Aug 06 peter 98   }
612 30 Aug 06 peter 99
680 11 Oct 06 jari 100 }}} // of namespace classifier, yat, and theplu