yat/statistics/PoissonMixture.cc

Code
Comments
Other
Rev Date Author Line
4180 09 Jun 22 peter 1 // $Id$
4180 09 Jun 22 peter 2
4180 09 Jun 22 peter 3 /*
4180 09 Jun 22 peter 4   Copyright (C) 2022 Peter Johansson
4180 09 Jun 22 peter 5
4180 09 Jun 22 peter 6   This file is part of the yat library, https://dev.thep.lu.se/yat
4180 09 Jun 22 peter 7
4180 09 Jun 22 peter 8   The yat library is free software; you can redistribute it and/or
4180 09 Jun 22 peter 9   modify it under the terms of the GNU General Public License as
4180 09 Jun 22 peter 10   published by the Free Software Foundation; either version 3 of the
4180 09 Jun 22 peter 11   License, or (at your option) any later version.
4180 09 Jun 22 peter 12
4180 09 Jun 22 peter 13   The yat library is distributed in the hope that it will be useful,
4180 09 Jun 22 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
4180 09 Jun 22 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
4180 09 Jun 22 peter 16   General Public License for more details.
4180 09 Jun 22 peter 17
4180 09 Jun 22 peter 18   You should have received a copy of the GNU General Public License
4180 09 Jun 22 peter 19   along with yat. If not, see <https://www.gnu.org/licenses/>.
4180 09 Jun 22 peter 20 */
4180 09 Jun 22 peter 21
4180 09 Jun 22 peter 22 #include <config.h>
4180 09 Jun 22 peter 23
4180 09 Jun 22 peter 24 #include "PoissonMixture.h"
4180 09 Jun 22 peter 25
4180 09 Jun 22 peter 26 #include "GaussianMixture.h"
4180 09 Jun 22 peter 27
4180 09 Jun 22 peter 28 #include <yat/utility/BFGS2.h>
4180 09 Jun 22 peter 29 #include <yat/utility/Matrix.h>
4180 09 Jun 22 peter 30 #include <yat/utility/version.h>
4180 09 Jun 22 peter 31
4180 09 Jun 22 peter 32 #include <gsl/gsl_randist.h>
4180 09 Jun 22 peter 33 #include <gsl/gsl_sf_gamma.h>
4180 09 Jun 22 peter 34
4180 09 Jun 22 peter 35 #include <cassert>
4180 09 Jun 22 peter 36 #include <cmath>
4180 09 Jun 22 peter 37
4180 09 Jun 22 peter 38 namespace theplu {
4180 09 Jun 22 peter 39 namespace yat {
4180 09 Jun 22 peter 40 namespace statistics {
4180 09 Jun 22 peter 41
4180 09 Jun 22 peter 42   void PoissonMixture::add(unsigned long int k, unsigned long int n)
4180 09 Jun 22 peter 43   {
4180 09 Jun 22 peter 44     auto it = count_.lower_bound(k);
4180 09 Jun 22 peter 45     if (it == count_.end() || k < it->first)
4180 09 Jun 22 peter 46       count_.insert(it, std::make_pair(k, n));
4180 09 Jun 22 peter 47     else
4180 09 Jun 22 peter 48       it->second += n;
4180 09 Jun 22 peter 49   }
4180 09 Jun 22 peter 50
4180 09 Jun 22 peter 51
4180 09 Jun 22 peter 52   void PoissonMixture::clear(void)
4180 09 Jun 22 peter 53   {
4180 09 Jun 22 peter 54     count_.clear();
4180 09 Jun 22 peter 55   }
4180 09 Jun 22 peter 56
4180 09 Jun 22 peter 57
4180 09 Jun 22 peter 58   bool PoissonMixture::expectation_step(yat::utility::Matrix& H) const
4180 09 Jun 22 peter 59   {
4180 09 Jun 22 peter 60     // loop over values
4180 09 Jun 22 peter 61     auto it=count_.begin();
4180 09 Jun 22 peter 62     assert(count_.size() == H.columns());
4180 09 Jun 22 peter 63     for (size_t j=0; j<H.columns(); ++j) {
4180 09 Jun 22 peter 64       yat::utility::VectorView column = H.column_view(j);
4180 09 Jun 22 peter 65       unsigned long int k = it->first;
4180 09 Jun 22 peter 66       // loop over models
4180 09 Jun 22 peter 67       for (size_t i=0; i<H.rows(); ++i) {
4180 09 Jun 22 peter 68         if (m_(i)) {
4180 09 Jun 22 peter 69           assert(m_(i) > 0.0);
4180 09 Jun 22 peter 70           column(i) = tau_(i) * gsl_ran_poisson_pdf(k, m_(i));
4180 09 Jun 22 peter 71         }
4180 09 Jun 22 peter 72         else
4180 09 Jun 22 peter 73           column(i) = tau_(i) * (k ? 0.0 : 1.0);
4180 09 Jun 22 peter 74       }
4180 09 Jun 22 peter 75       assert(sum(column));
4180 09 Jun 22 peter 76       column *= 1.0 / sum(column);
4180 09 Jun 22 peter 77       ++it;
4180 09 Jun 22 peter 78     }
4180 09 Jun 22 peter 79     return false;
4180 09 Jun 22 peter 80   }
4180 09 Jun 22 peter 81
4180 09 Jun 22 peter 82
4180 09 Jun 22 peter 83   void PoissonMixture::fit(const yat::utility::VectorBase& m,
4180 09 Jun 22 peter 84                            const yat::utility::VectorBase& tau)
4180 09 Jun 22 peter 85   {
4180 09 Jun 22 peter 86     assert(m.size() == tau.size());
4180 09 Jun 22 peter 87     m_ = m;
4180 09 Jun 22 peter 88     tau_ = tau;
4180 09 Jun 22 peter 89     optimize();
4180 09 Jun 22 peter 90   }
4180 09 Jun 22 peter 91
4180 09 Jun 22 peter 92
4180 09 Jun 22 peter 93   void PoissonMixture::fit(size_t n)
4180 09 Jun 22 peter 94   {
4180 09 Jun 22 peter 95     init_fit(n);
4180 09 Jun 22 peter 96     optimize();
4180 09 Jun 22 peter 97   }
4180 09 Jun 22 peter 98
4180 09 Jun 22 peter 99
4180 09 Jun 22 peter 100   void PoissonMixture::init_fit(size_t n)
4180 09 Jun 22 peter 101   {
4180 09 Jun 22 peter 102     assert(n>0);
4180 09 Jun 22 peter 103     tau_.resize(n, 1.0/n);
4180 09 Jun 22 peter 104
4180 09 Jun 22 peter 105     statistics::GaussianMixture gm;
4180 09 Jun 22 peter 106     for (const auto& x : count_)
4180 09 Jun 22 peter 107       gm.add(x.first, x.second);
4180 09 Jun 22 peter 108     gm.fit(n);
4180 09 Jun 22 peter 109
4180 09 Jun 22 peter 110     m_.resize(n);
4180 09 Jun 22 peter 111     for (size_t i=0; i<n; ++i)
4180 09 Jun 22 peter 112       m_(i) = gm.mean(i);
4180 09 Jun 22 peter 113   }
4180 09 Jun 22 peter 114
4180 09 Jun 22 peter 115
4180 09 Jun 22 peter 116   double PoissonMixture::logL(void)
4180 09 Jun 22 peter 117   {
4180 09 Jun 22 peter 118     double log_L = 0;
4180 09 Jun 22 peter 119     for (const auto& x : count_) {
4180 09 Jun 22 peter 120       double L = 0;
4180 09 Jun 22 peter 121       for (size_t i=0; i<tau_.size(); ++i) {
4180 09 Jun 22 peter 122         if (mean(i) > 0)
4180 09 Jun 22 peter 123           L += tau_(i) * gsl_ran_poisson_pdf(x.first, mean(i));
4180 09 Jun 22 peter 124         else if (x.first==0)
4180 09 Jun 22 peter 125           L += tau_(i);
4180 09 Jun 22 peter 126       }
4180 09 Jun 22 peter 127       log_L += x.second * std::log(L);
4180 09 Jun 22 peter 128     }
4180 09 Jun 22 peter 129     return log_L;
4180 09 Jun 22 peter 130   }
4180 09 Jun 22 peter 131
4180 09 Jun 22 peter 132
4180 09 Jun 22 peter 133   double PoissonMixture::mean(size_t i) const
4180 09 Jun 22 peter 134   {
4180 09 Jun 22 peter 135     return m_(i);
4180 09 Jun 22 peter 136   }
4180 09 Jun 22 peter 137
4180 09 Jun 22 peter 138
4180 09 Jun 22 peter 139   void PoissonMixture::optimize(void)
4180 09 Jun 22 peter 140   {
4180 09 Jun 22 peter 141     if (tau_.size()==0 || count_.empty())
4180 09 Jun 22 peter 142       return;
4180 09 Jun 22 peter 143     utility::Matrix H(tau_.size(), count_.size());
4180 09 Jun 22 peter 144
4180 09 Jun 22 peter 145     bool changed = true;
4180 09 Jun 22 peter 146     for (size_t i=0; i<100 && changed; ++i) {
4180 09 Jun 22 peter 147       changed = expectation_step(H);
4180 09 Jun 22 peter 148       if (optimize_tau(H))
4180 09 Jun 22 peter 149         changed = true;
4180 09 Jun 22 peter 150       if (optimize_m(H))
4180 09 Jun 22 peter 151         changed=true;
4180 09 Jun 22 peter 152     }
4180 09 Jun 22 peter 153
4180 09 Jun 22 peter 154   }
4180 09 Jun 22 peter 155
4180 09 Jun 22 peter 156
4180 09 Jun 22 peter 157   bool PoissonMixture::optimize_tau(const yat::utility::Matrix& H)
4180 09 Jun 22 peter 158   {
4180 09 Jun 22 peter 159     yat::utility::Vector prev_tau(tau_);
4180 09 Jun 22 peter 160
4180 09 Jun 22 peter 161     tau_.all(0.0);
4180 09 Jun 22 peter 162     auto it = count_.begin();
4180 09 Jun 22 peter 163     // tau is updated as the average
4180 09 Jun 22 peter 164     for (size_t i=0; i<H.columns(); ++i) {
4180 09 Jun 22 peter 165       tau_ += it->second * H.column_const_view(i);
4180 09 Jun 22 peter 166       ++it;
4180 09 Jun 22 peter 167     }
4180 09 Jun 22 peter 168     assert(!std::isnan(sum(tau_)));
4180 09 Jun 22 peter 169     assert(sum(tau_));
4180 09 Jun 22 peter 170     // normalize tau to unity sum
4180 09 Jun 22 peter 171     tau_ *= 1.0 / sum(tau_);
4180 09 Jun 22 peter 172     return norm2_squared(yat::utility::Vector(tau_-prev_tau)) > 1e-9;
4180 09 Jun 22 peter 173   }
4180 09 Jun 22 peter 174
4180 09 Jun 22 peter 175
4180 09 Jun 22 peter 176   /*
4180 09 Jun 22 peter 177     Update m by minimizing
4180 09 Jun 22 peter 178     Q = sum H_ij * (log tau_i + logL(m,k)) =
4180 09 Jun 22 peter 179       = sum H_ij * (log tau_i + k_j*log(m_i) - m_i - ln(Gamma(k_j+1)))
4180 09 Jun 22 peter 180    */
4180 09 Jun 22 peter 181   bool PoissonMixture::optimize_m(const yat::utility::Matrix& H)
4180 09 Jun 22 peter 182   {
4180 09 Jun 22 peter 183     yat::utility::Vector prev_m(m_);
4180 09 Jun 22 peter 184     // We can minimize each m_i independently by minimizing
4180 09 Jun 22 peter 185     // sum_j H_ij * (k_j*log(m_i) - m_i) by solving
4180 09 Jun 22 peter 186     // 0 = sum_j H_ij * (k_j/m_i - 1) i.e.
4180 09 Jun 22 peter 187     // m_i = sum_j H_ij k_j / sum_j H_ij
4180 09 Jun 22 peter 188     for (size_t i=0; i<m_.size(); ++i) {
4180 09 Jun 22 peter 189       double sum_Hk = 0;
4180 09 Jun 22 peter 190       double sum_H = 0;
4180 09 Jun 22 peter 191       size_t j=0;
4180 09 Jun 22 peter 192       assert(H.columns() == count_.size());
4180 09 Jun 22 peter 193       for (const auto& x : count_) {
4180 09 Jun 22 peter 194         sum_H += x.second * H(i,j);
4180 09 Jun 22 peter 195         sum_Hk += x.second * H(i,j) * x.first;
4180 09 Jun 22 peter 196         ++j;
4180 09 Jun 22 peter 197       }
4180 09 Jun 22 peter 198       m_(i) = sum_Hk / sum_H;
4180 09 Jun 22 peter 199       assert(m_(i) >= 0);
4180 09 Jun 22 peter 200     }
4180 09 Jun 22 peter 201
4180 09 Jun 22 peter 202     assert(sum(m_));
4180 09 Jun 22 peter 203     return norm2_squared(yat::utility::Vector(m_ - prev_m)) > 1e-9;
4180 09 Jun 22 peter 204   }
4180 09 Jun 22 peter 205
4180 09 Jun 22 peter 206
4180 09 Jun 22 peter 207   double PoissonMixture::tau(size_t i) const
4180 09 Jun 22 peter 208   {
4180 09 Jun 22 peter 209     return tau_(i);
4180 09 Jun 22 peter 210   }
4180 09 Jun 22 peter 211 }}}