yat/statistics/GaussianMixture.cc

Code
Comments
Other
Rev Date Author Line
3644 11 May 17 peter 1 // $Id$
3644 11 May 17 peter 2
3644 11 May 17 peter 3 /*
4207 26 Aug 22 peter 4   Copyright (C) 2017, 2022 Peter Johansson
3644 11 May 17 peter 5
3644 11 May 17 peter 6   This file is part of the yat library, http://dev.thep.lu.se/yat
3644 11 May 17 peter 7
3644 11 May 17 peter 8   The yat library is free software; you can redistribute it and/or
3644 11 May 17 peter 9   modify it under the terms of the GNU General Public License as
3644 11 May 17 peter 10   published by the Free Software Foundation; either version 3 of the
3644 11 May 17 peter 11   License, or (at your option) any later version.
3644 11 May 17 peter 12
3644 11 May 17 peter 13   The yat library is distributed in the hope that it will be useful,
3644 11 May 17 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
3644 11 May 17 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
3644 11 May 17 peter 16   General Public License for more details.
3644 11 May 17 peter 17
3644 11 May 17 peter 18   You should have received a copy of the GNU General Public License
3644 11 May 17 peter 19   along with yat. If not, see <http://www.gnu.org/licenses/>.
3644 11 May 17 peter 20 */
3644 11 May 17 peter 21
3644 11 May 17 peter 22 #include <config.h>
3644 11 May 17 peter 23
3644 11 May 17 peter 24 #include "GaussianMixture.h"
3644 11 May 17 peter 25
3644 11 May 17 peter 26 #include "Averager.h"
3644 11 May 17 peter 27 #include <yat/utility/Matrix.h>
3644 11 May 17 peter 28
3644 11 May 17 peter 29 #include <gsl/gsl_randist.h>
3644 11 May 17 peter 30 #include <gsl/gsl_cdf.h>
3644 11 May 17 peter 31
3644 11 May 17 peter 32 #include <cassert>
3644 11 May 17 peter 33 #include <cmath>
3644 11 May 17 peter 34
3644 11 May 17 peter 35 namespace theplu {
3644 11 May 17 peter 36 namespace yat {
3644 11 May 17 peter 37 namespace statistics {
3644 11 May 17 peter 38
3644 11 May 17 peter 39   void GaussianMixture::add(double x, unsigned long int n)
3644 11 May 17 peter 40   {
3644 11 May 17 peter 41     data_.push_back(data_type(x, n));
3644 11 May 17 peter 42   }
3644 11 May 17 peter 43
3644 11 May 17 peter 44
3644 11 May 17 peter 45   void GaussianMixture::clear(void)
3644 11 May 17 peter 46   {
3644 11 May 17 peter 47     data_.clear();
3644 11 May 17 peter 48   }
3644 11 May 17 peter 49
3644 11 May 17 peter 50
4150 07 Mar 22 peter 51   void GaussianMixture::fit(const utility::Vector& alpha,
4150 07 Mar 22 peter 52                             const utility::Vector& mean,
4150 07 Mar 22 peter 53                             const utility::Vector& sigma)
4150 07 Mar 22 peter 54   {
4150 07 Mar 22 peter 55     alpha_ = alpha;
4150 07 Mar 22 peter 56     alpha_ *= 1.0/sum(alpha);
4150 07 Mar 22 peter 57     mean_ = mean;
4150 07 Mar 22 peter 58     sigma_ = sigma;
4150 07 Mar 22 peter 59     fit();
4150 07 Mar 22 peter 60   }
4150 07 Mar 22 peter 61
4150 07 Mar 22 peter 62
3644 11 May 17 peter 63   void GaussianMixture::fit(size_t n)
3644 11 May 17 peter 64   {
3644 11 May 17 peter 65     // init
3644 11 May 17 peter 66     Averager averager;
3644 11 May 17 peter 67     for (size_t i=0; i<data_.size(); ++i)
3644 11 May 17 peter 68       averager.add(data_[i].x, data_[i].n);
3644 11 May 17 peter 69
3644 11 May 17 peter 70     alpha_.resize(n, 1.0/n);
3644 11 May 17 peter 71     mean_.resize(n);
3644 11 May 17 peter 72     // put centroid evenly between m-2s and m+2s
3644 11 May 17 peter 73     for (size_t i=0; i<n; ++i) {
3644 11 May 17 peter 74       double fraction = static_cast<double>(i+1) / (n+1);
3644 11 May 17 peter 75       mean_(i) = averager.mean() + (4*fraction-2.0) * averager.std();
3644 11 May 17 peter 76     }
3644 11 May 17 peter 77
3644 11 May 17 peter 78     sigma_.resize(n, averager.std());
4150 07 Mar 22 peter 79     fit();
4150 07 Mar 22 peter 80   }
3644 11 May 17 peter 81
4150 07 Mar 22 peter 82
4150 07 Mar 22 peter 83   void GaussianMixture::fit(void)
4150 07 Mar 22 peter 84   {
4150 07 Mar 22 peter 85     assert(alpha_.size() == mean_.size());
4150 07 Mar 22 peter 86     assert(alpha_.size() == sigma_.size());
4150 07 Mar 22 peter 87     yat::utility::Matrix h(alpha_.size(), data_.size());
3644 11 May 17 peter 88     for (size_t i=0; i<100; ++i) {
3644 11 May 17 peter 89       calculate_posterior(h);
3644 11 May 17 peter 90       update_model(h);
3644 11 May 17 peter 91     }
3644 11 May 17 peter 92   }
3644 11 May 17 peter 93
3644 11 May 17 peter 94
3644 11 May 17 peter 95   double GaussianMixture::alpha(size_t i) const
3644 11 May 17 peter 96   {
3644 11 May 17 peter 97     assert(i < alpha_.size());
3644 11 May 17 peter 98     return alpha_(i);
3644 11 May 17 peter 99   }
3644 11 May 17 peter 100
3644 11 May 17 peter 101
3644 11 May 17 peter 102   double GaussianMixture::cdf_P(double x) const
3644 11 May 17 peter 103   {
3644 11 May 17 peter 104     double val = 0;
3644 11 May 17 peter 105     for (size_t i=0; i<alpha_.size(); ++i)
3644 11 May 17 peter 106       val += alpha_(i) * gsl_cdf_gaussian_P(x-mean_(i), sigma_(i));
3644 11 May 17 peter 107     return val;
3644 11 May 17 peter 108   }
3644 11 May 17 peter 109
3644 11 May 17 peter 110
3644 11 May 17 peter 111   double GaussianMixture::mean(void) const
3644 11 May 17 peter 112   {
3644 11 May 17 peter 113     double val = 0;
3644 11 May 17 peter 114     for (size_t i=0; i<alpha_.size(); ++i)
3644 11 May 17 peter 115       val += alpha_(i) * mean(i);
3644 11 May 17 peter 116     return val;
3644 11 May 17 peter 117   }
3644 11 May 17 peter 118
3644 11 May 17 peter 119
3644 11 May 17 peter 120   double GaussianMixture::mean(size_t i) const
3644 11 May 17 peter 121   {
3644 11 May 17 peter 122     assert(i < mean_.size());
3644 11 May 17 peter 123     return mean_(i);
3644 11 May 17 peter 124   }
3644 11 May 17 peter 125
3644 11 May 17 peter 126
3644 11 May 17 peter 127   double GaussianMixture::pdf(double x) const
3644 11 May 17 peter 128   {
3644 11 May 17 peter 129     double val = 0;
3644 11 May 17 peter 130     for (size_t i=0; i<alpha_.size(); ++i)
3644 11 May 17 peter 131       val += alpha_(i) * gsl_ran_gaussian_pdf(x - mean_(i), sigma_(i));
3644 11 May 17 peter 132     return val;
3644 11 May 17 peter 133   }
3644 11 May 17 peter 134
3644 11 May 17 peter 135
3644 11 May 17 peter 136   double GaussianMixture::std(size_t i) const
3644 11 May 17 peter 137   {
3644 11 May 17 peter 138     assert(i < sigma_.size());
3644 11 May 17 peter 139     return sigma_(i);
3644 11 May 17 peter 140   }
3644 11 May 17 peter 141
3644 11 May 17 peter 142
3644 11 May 17 peter 143   void GaussianMixture::calculate_posterior(utility::Matrix& h) const
3644 11 May 17 peter 144   {
3644 11 May 17 peter 145     assert(data_.size() == h.columns());
3644 11 May 17 peter 146
3644 11 May 17 peter 147     for (size_t model=0; model<h.rows(); ++model)
3644 11 May 17 peter 148       for (size_t sample=0; sample<h.columns(); ++sample)
3644 11 May 17 peter 149         h(model, sample) =
3644 11 May 17 peter 150           alpha_(model) * gsl_ran_gaussian_pdf(data_[sample].x - mean_(model),
3644 11 May 17 peter 151                                                sigma_(model));
3644 11 May 17 peter 152
3644 11 May 17 peter 153     // normalize to unity for each sample (column)
3644 11 May 17 peter 154     for (size_t sample=0; sample<data_.size(); ++sample) {
3644 11 May 17 peter 155       double s = sum(h.column_const_view(sample));
3644 11 May 17 peter 156       h.column_view(sample) *= 1.0/s;
3644 11 May 17 peter 157     }
3644 11 May 17 peter 158   }
3644 11 May 17 peter 159
3644 11 May 17 peter 160
3644 11 May 17 peter 161   void GaussianMixture::update_model(const utility::Matrix& h)
3644 11 May 17 peter 162   {
3644 11 May 17 peter 163     // calculate number of samples
3644 11 May 17 peter 164     unsigned long int N = 0;
3644 11 May 17 peter 165     for (size_t i=0; i<data_.size(); ++i)
3644 11 May 17 peter 166       N += data_[i].n;
3644 11 May 17 peter 167
3644 11 May 17 peter 168     utility::Vector sum_h(h.rows());
3644 11 May 17 peter 169     for (size_t i=0; i<h.rows(); ++i)
3644 11 May 17 peter 170       for (size_t j=0; j<h.columns(); ++j)
3644 11 May 17 peter 171         sum_h(i) += h(i,j) * data_[j].n;
3644 11 May 17 peter 172
3644 11 May 17 peter 173     // update std
3644 11 May 17 peter 174     // std(i) = sum h(i, j) * (x(j) - mean(i))^2 / sum h(i, j)
3644 11 May 17 peter 175     for (size_t i=0; i<sigma_.size(); ++i) {
3644 11 May 17 peter 176       sigma_(i) = 0;
3644 11 May 17 peter 177       for (size_t j=0; j<h.columns(); ++j) {
3644 11 May 17 peter 178         sigma_(i) += data_[j].n * h(i, j) * std::pow(data_[j].x - mean_(i), 2);
3644 11 May 17 peter 179       }
3644 11 May 17 peter 180       sigma_(i) = std::sqrt(sigma_(i) / sum_h(i));
3644 11 May 17 peter 181     }
3644 11 May 17 peter 182
3644 11 May 17 peter 183     // update mean
3644 11 May 17 peter 184     // mean(i) = sum h(i, j) * x(j) / sum h(i, j)
3644 11 May 17 peter 185     for (size_t i=0; i<mean_.size(); ++i) {
3644 11 May 17 peter 186       mean_(i) = 0;
3644 11 May 17 peter 187       for (size_t j=0; j<h.columns(); ++j)
3644 11 May 17 peter 188         mean_(i) += data_[j].n * h(i, j) * data_[j].x;
3644 11 May 17 peter 189       mean_(i) /= sum_h(i);
3644 11 May 17 peter 190     }
3644 11 May 17 peter 191
3644 11 May 17 peter 192     // update alpha
3644 11 May 17 peter 193     // alpha(i) = \sum h(i, j) n(j) / sum n(j)
3644 11 May 17 peter 194     for (size_t model=0; model<h.rows(); ++model) {
3644 11 May 17 peter 195       double sum = 0;
3644 11 May 17 peter 196       for (size_t sample=0; sample<h.columns(); ++sample)
3644 11 May 17 peter 197         sum += h(model, sample) * data_[sample].n;
3644 11 May 17 peter 198       alpha_(model) = sum / N;
3644 11 May 17 peter 199     }
3644 11 May 17 peter 200   }
3644 11 May 17 peter 201
3644 11 May 17 peter 202
3644 11 May 17 peter 203   GaussianMixture::data_type::data_type(double value, unsigned long int number)
3644 11 May 17 peter 204     : x(value), n(number)
3644 11 May 17 peter 205   {}
3644 11 May 17 peter 206
3644 11 May 17 peter 207
3644 11 May 17 peter 208 }}} // of namespace statistics, yat, and theplu