yat/regression/MultiCox.cc

Code
Comments
Other
Rev Date Author Line
4198 19 Aug 22 peter 1 // $Id$
4198 19 Aug 22 peter 2
4198 19 Aug 22 peter 3 /*
4198 19 Aug 22 peter 4   Copyright (C) 2022 Peter Johansson
4198 19 Aug 22 peter 5
4198 19 Aug 22 peter 6   This file is part of the yat library, https://dev.thep.lu.se/yat
4198 19 Aug 22 peter 7
4198 19 Aug 22 peter 8   The yat library is free software; you can redistribute it and/or
4198 19 Aug 22 peter 9   modify it under the terms of the GNU General Public License as
4198 19 Aug 22 peter 10   published by the Free Software Foundation; either version 3 of the
4198 19 Aug 22 peter 11   License, or (at your option) any later version.
4198 19 Aug 22 peter 12
4198 19 Aug 22 peter 13   The yat library is distributed in the hope that it will be useful,
4198 19 Aug 22 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
4198 19 Aug 22 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
4198 19 Aug 22 peter 16   General Public License for more details.
4198 19 Aug 22 peter 17
4198 19 Aug 22 peter 18   You should have received a copy of the GNU General Public License
4198 19 Aug 22 peter 19   along with yat. If not, see <https://www.gnu.org/licenses/>.
4198 19 Aug 22 peter 20 */
4198 19 Aug 22 peter 21
4198 19 Aug 22 peter 22 #include <config.h>
4198 19 Aug 22 peter 23
4198 19 Aug 22 peter 24 #include "MultiCox.h"
4198 19 Aug 22 peter 25
4198 19 Aug 22 peter 26 #include "detail/Cox.h"
4198 19 Aug 22 peter 27
4198 19 Aug 22 peter 28 #include <yat/utility/BFGS2.h>
4198 19 Aug 22 peter 29 #include <yat/utility/Matrix.h>
4198 19 Aug 22 peter 30 #include <yat/utility/MatrixBase.h>
4198 19 Aug 22 peter 31 #include <yat/utility/Vector.h>
4198 19 Aug 22 peter 32 #include <yat/utility/VectorBase.h>
4198 19 Aug 22 peter 33
4198 19 Aug 22 peter 34 #include <gsl/gsl_cdf.h>
4198 19 Aug 22 peter 35
4198 19 Aug 22 peter 36 #include <cassert>
4198 19 Aug 22 peter 37 #include <cmath>
4198 19 Aug 22 peter 38
4198 19 Aug 22 peter 39 namespace theplu {
4198 19 Aug 22 peter 40 namespace yat {
4198 19 Aug 22 peter 41 namespace regression {
4198 19 Aug 22 peter 42
4198 19 Aug 22 peter 43   class MultiCox::Impl : public cox::Implementation<utility::Vector>
4198 19 Aug 22 peter 44   {
4198 19 Aug 22 peter 45   public:
4198 19 Aug 22 peter 46     using cox::Implementation<utility::Vector>::add;
4198 19 Aug 22 peter 47     void add(const utility::MatrixBase& X,
4198 19 Aug 22 peter 48              const utility::VectorBase& times,
4198 19 Aug 22 peter 49              const std::vector<char>& events)
4198 19 Aug 22 peter 50     {
4198 19 Aug 22 peter 51       assert(times.size() == events.size());
4198 19 Aug 22 peter 52       assert(X.columns() == events.size());
4198 19 Aug 22 peter 53       for (size_t i=0; i<times.size(); ++i)
4198 19 Aug 22 peter 54         add(X.column_const_view(i), times(i), events[i]);
4198 19 Aug 22 peter 55     }
4198 19 Aug 22 peter 56
4198 19 Aug 22 peter 57
4198 19 Aug 22 peter 58     double b(size_t i) const
4198 19 Aug 22 peter 59     {
4198 19 Aug 22 peter 60       assert(i < beta_.size());
4198 19 Aug 22 peter 61       return beta_(i);
4198 19 Aug 22 peter 62     }
4198 19 Aug 22 peter 63
4198 19 Aug 22 peter 64
4198 19 Aug 22 peter 65     const utility::Matrix& covariance(void) const
4198 19 Aug 22 peter 66     {
4198 19 Aug 22 peter 67       return beta_cov_;
4198 19 Aug 22 peter 68     }
4198 19 Aug 22 peter 69
4198 19 Aug 22 peter 70
4198 19 Aug 22 peter 71     double hazard_ratio(size_t i) const
4198 19 Aug 22 peter 72     {
4198 19 Aug 22 peter 73       return exp(b(i));
4198 19 Aug 22 peter 74     }
4198 19 Aug 22 peter 75
4198 19 Aug 22 peter 76
4198 19 Aug 22 peter 77     double hazard_ratio_lower_CI(size_t i, double alpha) const
4198 19 Aug 22 peter 78     {
4198 19 Aug 22 peter 79       return exp(b(i) - hazard_ratio_CI(i, alpha));
4198 19 Aug 22 peter 80     }
4198 19 Aug 22 peter 81
4198 19 Aug 22 peter 82
4198 19 Aug 22 peter 83     double hazard_ratio_upper_CI(size_t i, double alpha) const
4198 19 Aug 22 peter 84     {
4198 19 Aug 22 peter 85       return exp(b(i) + hazard_ratio_CI(i, alpha));
4198 19 Aug 22 peter 86     }
4198 19 Aug 22 peter 87
4198 19 Aug 22 peter 88
4198 19 Aug 22 peter 89     double hazard_ratio_CI(size_t i, double alpha) const
4198 19 Aug 22 peter 90     {
4198 19 Aug 22 peter 91       assert(i < beta_cov_.rows());
4198 19 Aug 22 peter 92       assert(i < beta_cov_.columns());
4198 19 Aug 22 peter 93       double z = gsl_cdf_ugaussian_Qinv(0.5 * (1.0 - alpha));
4198 19 Aug 22 peter 94       return z * std::sqrt(beta_cov_(i,i));
4198 19 Aug 22 peter 95     }
4198 19 Aug 22 peter 96
4198 19 Aug 22 peter 97
4198 19 Aug 22 peter 98     double p(size_t i) const
4198 19 Aug 22 peter 99     {
4198 19 Aug 22 peter 100       if (b(i) > 0)
4198 19 Aug 22 peter 101         return 2 * gsl_cdf_ugaussian_Q(z(i));
4198 19 Aug 22 peter 102       return 2 * gsl_cdf_ugaussian_P(z(i));
4198 19 Aug 22 peter 103     }
4198 19 Aug 22 peter 104
4198 19 Aug 22 peter 105
4198 19 Aug 22 peter 106     void train(void)
4198 19 Aug 22 peter 107     {
4198 19 Aug 22 peter 108       if (data_.empty())
4198 19 Aug 22 peter 109         return;
4198 19 Aug 22 peter 110
4198 19 Aug 22 peter 111       size_t n = data_.front().x.size();
4198 19 Aug 22 peter 112       prepare_times();
4198 19 Aug 22 peter 113       logL func(times_);
4198 19 Aug 22 peter 114       beta_.resize(n, 0.0);
4198 19 Aug 22 peter 115       utility::BFGS2 solver(n);
4252 18 Nov 22 peter 116       solver(beta_, func, utility::MultiMinimizerDerivative::Gradient(1e-3));
4198 19 Aug 22 peter 117
4198 19 Aug 22 peter 118       // Calculate 2nd deriviate at beta_;
4198 19 Aug 22 peter 119       utility::Matrix H(n, n);
4198 19 Aug 22 peter 120       func.hessian(beta_, H);
4198 19 Aug 22 peter 121       inverse_svd(H, beta_cov_);
4198 19 Aug 22 peter 122     }
4198 19 Aug 22 peter 123
4198 19 Aug 22 peter 124
4198 19 Aug 22 peter 125     double z(size_t i) const
4198 19 Aug 22 peter 126     {
4198 19 Aug 22 peter 127       assert(i < beta_cov_.rows());
4198 19 Aug 22 peter 128       assert(i < beta_cov_.columns());
4198 19 Aug 22 peter 129       return b(i) / std::sqrt(beta_cov_(i,i));
4198 19 Aug 22 peter 130     }
4198 19 Aug 22 peter 131
4198 19 Aug 22 peter 132
4198 19 Aug 22 peter 133   private:
4198 19 Aug 22 peter 134     utility::Vector beta_;
4198 19 Aug 22 peter 135     utility::Matrix beta_cov_;
4198 19 Aug 22 peter 136
4198 19 Aug 22 peter 137     class logL
4198 19 Aug 22 peter 138     {
4198 19 Aug 22 peter 139       const std::vector<TimePoint>& times_;
4198 19 Aug 22 peter 140     public:
4198 19 Aug 22 peter 141       logL(const std::vector<TimePoint>& times)
4198 19 Aug 22 peter 142         : times_(times)
4198 19 Aug 22 peter 143       {}
4198 19 Aug 22 peter 144
4198 19 Aug 22 peter 145
4198 19 Aug 22 peter 146       double operator()(const utility::VectorBase& beta)
4198 19 Aug 22 peter 147       {
4198 19 Aug 22 peter 148         double f = 0;
4198 19 Aug 22 peter 149         double theta_Q = 0;
4198 19 Aug 22 peter 150         for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
4198 19 Aug 22 peter 151           double sum_event_theta = 0;
4198 19 Aug 22 peter 152           for (auto it = time->events_begin(); it!=time->events_end(); ++it)
4198 19 Aug 22 peter 153             sum_event_theta += it->theta(beta);
4198 19 Aug 22 peter 154           theta_Q += sum_event_theta;
4198 19 Aug 22 peter 155           for (auto it=time->censored_begin(); it!=time->censored_end(); ++it)
4198 19 Aug 22 peter 156             theta_Q += it->theta(beta);
4198 19 Aug 22 peter 157
4198 19 Aug 22 peter 158           for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4198 19 Aug 22 peter 159             const size_t k = it - time->events_begin();
4198 19 Aug 22 peter 160             double r = static_cast<double>(k) / time->size();
4198 19 Aug 22 peter 161             f -= it->x * beta;
4198 19 Aug 22 peter 162             f += log(theta_Q - r * sum_event_theta);
4198 19 Aug 22 peter 163           }
4198 19 Aug 22 peter 164         }
4198 19 Aug 22 peter 165         return f;
4198 19 Aug 22 peter 166       }
4198 19 Aug 22 peter 167
4198 19 Aug 22 peter 168
4198 19 Aug 22 peter 169       void operator()(const utility::VectorBase& beta,
4198 19 Aug 22 peter 170                       utility::VectorMutable& gradient)
4198 19 Aug 22 peter 171       {
4198 19 Aug 22 peter 172         assert(beta.size() == gradient.size());
4198 19 Aug 22 peter 173         gradient.all(0.0);
4198 19 Aug 22 peter 174         size_t n = beta.size();
4198 19 Aug 22 peter 175
4198 19 Aug 22 peter 176         double theta_Q = 0;
4198 19 Aug 22 peter 177         utility::Vector thetaX_Q(n, 0.0);
4198 19 Aug 22 peter 178         for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
4198 19 Aug 22 peter 179           double sum_event_theta = 0;
4198 19 Aug 22 peter 180           utility::Vector sum_event_thetaX(n, 0);
4198 19 Aug 22 peter 181           for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4198 19 Aug 22 peter 182             double theta = it->theta(beta);
4198 19 Aug 22 peter 183             sum_event_theta += theta;
4198 19 Aug 22 peter 184             sum_event_thetaX += theta * it->x;
4198 19 Aug 22 peter 185           }
4198 19 Aug 22 peter 186           theta_Q += sum_event_theta;
4198 19 Aug 22 peter 187           thetaX_Q += sum_event_thetaX;
4198 19 Aug 22 peter 188
4198 19 Aug 22 peter 189           for (auto it=time->censored_begin(); it!=time->censored_end(); ++it){
4198 19 Aug 22 peter 190             double theta = it->theta(beta);
4198 19 Aug 22 peter 191             theta_Q += theta;
4198 19 Aug 22 peter 192             thetaX_Q += theta * it->x;
4198 19 Aug 22 peter 193           }
4198 19 Aug 22 peter 194
4198 19 Aug 22 peter 195           // loop over events at time point t
4198 19 Aug 22 peter 196           for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4198 19 Aug 22 peter 197             const size_t k = it - time->events_begin();
4198 19 Aug 22 peter 198             double r = static_cast<double>(k) / time->size();
4198 19 Aug 22 peter 199
4198 19 Aug 22 peter 200             gradient -= it->x;
4198 19 Aug 22 peter 201             gradient += (1.0 / (theta_Q - r * sum_event_theta)) *
4198 19 Aug 22 peter 202               (thetaX_Q - r * sum_event_thetaX);
4198 19 Aug 22 peter 203           }
4198 19 Aug 22 peter 204         }
4198 19 Aug 22 peter 205       }
4198 19 Aug 22 peter 206
4198 19 Aug 22 peter 207
4198 19 Aug 22 peter 208       void hessian(const utility::VectorBase& beta, utility::Matrix& H) const
4198 19 Aug 22 peter 209       {
4198 19 Aug 22 peter 210         size_t n = beta.size();
4198 19 Aug 22 peter 211         H.all(0.0);
4198 19 Aug 22 peter 212
4198 19 Aug 22 peter 213         double sum_theta = 0;
4198 19 Aug 22 peter 214         utility::Vector sum_thetaX(n,0);
4198 19 Aug 22 peter 215         utility::Matrix sum_thetaXX(n, n, 0);
4198 19 Aug 22 peter 216         for (const auto& t : times_) {
4198 19 Aug 22 peter 217           for (auto it = t.begin; it!=t.end; ++it) {
4198 19 Aug 22 peter 218             double theta = it->theta(beta);
4198 19 Aug 22 peter 219             sum_theta += theta;
4198 19 Aug 22 peter 220             sum_thetaX += theta * it->x;
4198 19 Aug 22 peter 221             for (size_t i=0; i<n; ++i)
4198 19 Aug 22 peter 222               sum_thetaXX.row_view(i) += theta * it->x(i) * it->x;
4198 19 Aug 22 peter 223           }
4198 19 Aug 22 peter 224         }
4198 19 Aug 22 peter 225
4198 19 Aug 22 peter 226         // loop over unique time points
4198 19 Aug 22 peter 227         for (const auto& t : times_) {
4198 19 Aug 22 peter 228
4198 19 Aug 22 peter 229           if (t.events_begin() != t.events_end()) {
4198 19 Aug 22 peter 230             // sum over all events in H_j
4198 19 Aug 22 peter 231             double part_sum_theta = 0;
4198 19 Aug 22 peter 232             utility::Vector part_sum_thetaX(n, 0);
4198 19 Aug 22 peter 233             utility::Matrix part_sum_thetaXX(n, n,0);
4198 19 Aug 22 peter 234             for (auto it = t.events_begin(); it!=t.events_end(); ++it) {
4198 19 Aug 22 peter 235               double theta = it->theta(beta);
4198 19 Aug 22 peter 236               part_sum_theta += theta;
4198 19 Aug 22 peter 237               part_sum_thetaX += theta * it->x;
4198 19 Aug 22 peter 238               for (size_t i=0; i<n; ++i)
4198 19 Aug 22 peter 239                 part_sum_thetaXX.row_view(i) += theta * it->x(i) * it->x;
4198 19 Aug 22 peter 240             }
4198 19 Aug 22 peter 241
4198 19 Aug 22 peter 242             // loop over events at time point t
4198 19 Aug 22 peter 243             for (auto it = t.events_begin(); it!=t.events_end(); ++it) {
4198 19 Aug 22 peter 244               const size_t k = it - t.events_begin();
4198 19 Aug 22 peter 245               double r = static_cast<double>(k) / t.size();
4198 19 Aug 22 peter 246               double S_theta   = sum_theta   - r * part_sum_theta;
4198 19 Aug 22 peter 247               utility::Vector S_thetaX  = sum_thetaX  - r * part_sum_thetaX;
4198 19 Aug 22 peter 248               utility::Matrix S_thetaXX = sum_thetaXX - r * part_sum_thetaXX;
4198 19 Aug 22 peter 249
4198 19 Aug 22 peter 250               H += (1.0 / S_theta) * S_thetaXX;
4198 19 Aug 22 peter 251               utility::Vector ratio = (1.0 / S_theta) * S_thetaX;
4198 19 Aug 22 peter 252               for (size_t i=0; i<n; ++i)
4198 19 Aug 22 peter 253                 H.row_view(i) -= ratio(i) * ratio;
4198 19 Aug 22 peter 254             }
4198 19 Aug 22 peter 255           }
4198 19 Aug 22 peter 256
4198 19 Aug 22 peter 257           // update the cumulative sums
4198 19 Aug 22 peter 258           for (auto it = t.begin; it!=t.end; ++it) {
4198 19 Aug 22 peter 259             double theta = it->theta(beta);
4198 19 Aug 22 peter 260             sum_theta -= theta;
4198 19 Aug 22 peter 261             sum_thetaX -= theta * it->x;
4198 19 Aug 22 peter 262             for (size_t i=0; i<n; ++i)
4198 19 Aug 22 peter 263               sum_thetaXX.row_view(i) -= theta * it->x(i) * it->x;
4198 19 Aug 22 peter 264           }
4198 19 Aug 22 peter 265         }
4198 19 Aug 22 peter 266       }
4198 19 Aug 22 peter 267     };
4198 19 Aug 22 peter 268   };
4198 19 Aug 22 peter 269
4198 19 Aug 22 peter 270
4198 19 Aug 22 peter 271   MultiCox::MultiCox(void)
4198 19 Aug 22 peter 272     : pimpl_(new Impl)
4198 19 Aug 22 peter 273   {
4198 19 Aug 22 peter 274   }
4198 19 Aug 22 peter 275
4198 19 Aug 22 peter 276
4198 19 Aug 22 peter 277   MultiCox::MultiCox(const MultiCox& other)
4198 19 Aug 22 peter 278     : pimpl_(new Impl(*other.pimpl_))
4198 19 Aug 22 peter 279   {
4198 19 Aug 22 peter 280   }
4198 19 Aug 22 peter 281
4198 19 Aug 22 peter 282
4198 19 Aug 22 peter 283   MultiCox::MultiCox(MultiCox&& other)
4198 19 Aug 22 peter 284   {
4198 19 Aug 22 peter 285     std::swap(pimpl_, other.pimpl_);
4198 19 Aug 22 peter 286   }
4198 19 Aug 22 peter 287
4198 19 Aug 22 peter 288
4198 19 Aug 22 peter 289   MultiCox::~MultiCox(void)
4198 19 Aug 22 peter 290   {
4198 19 Aug 22 peter 291   }
4198 19 Aug 22 peter 292
4198 19 Aug 22 peter 293
4198 19 Aug 22 peter 294   MultiCox& MultiCox::operator=(const MultiCox& other)
4198 19 Aug 22 peter 295   {
4198 19 Aug 22 peter 296     assert(other.pimpl_);
4198 19 Aug 22 peter 297     pimpl_.reset(new Impl(*other.pimpl_));
4198 19 Aug 22 peter 298     return *this;
4198 19 Aug 22 peter 299   }
4198 19 Aug 22 peter 300
4198 19 Aug 22 peter 301
4198 19 Aug 22 peter 302   MultiCox& MultiCox::operator=(MultiCox&& other)
4198 19 Aug 22 peter 303   {
4198 19 Aug 22 peter 304     std::swap(pimpl_, other.pimpl_);
4198 19 Aug 22 peter 305     return *this;
4198 19 Aug 22 peter 306   }
4198 19 Aug 22 peter 307
4198 19 Aug 22 peter 308
4198 19 Aug 22 peter 309   void MultiCox::add(const utility::VectorBase& x, double time,
4198 19 Aug 22 peter 310                      bool event)
4198 19 Aug 22 peter 311   {
4198 19 Aug 22 peter 312     pimpl_->add(x, time, event);
4198 19 Aug 22 peter 313   }
4198 19 Aug 22 peter 314
4198 19 Aug 22 peter 315
4198 19 Aug 22 peter 316   void MultiCox::add(const utility::MatrixBase& X,
4198 19 Aug 22 peter 317                      const utility::VectorBase& times,
4198 19 Aug 22 peter 318                      const std::vector<char>& event)
4198 19 Aug 22 peter 319   {
4198 19 Aug 22 peter 320     pimpl_->add(X, times, event);
4198 19 Aug 22 peter 321   }
4198 19 Aug 22 peter 322
4198 19 Aug 22 peter 323
4198 19 Aug 22 peter 324   double MultiCox::b(size_t i) const
4198 19 Aug 22 peter 325   {
4198 19 Aug 22 peter 326     return pimpl_->b(i);
4198 19 Aug 22 peter 327   }
4198 19 Aug 22 peter 328
4198 19 Aug 22 peter 329
4198 19 Aug 22 peter 330   void MultiCox::clear(void)
4198 19 Aug 22 peter 331   {
4198 19 Aug 22 peter 332     pimpl_->clear();
4198 19 Aug 22 peter 333   }
4198 19 Aug 22 peter 334
4198 19 Aug 22 peter 335
4198 19 Aug 22 peter 336   const utility::MatrixBase& MultiCox::covariance(void) const
4198 19 Aug 22 peter 337   {
4198 19 Aug 22 peter 338     return pimpl_->covariance();
4198 19 Aug 22 peter 339   }
4198 19 Aug 22 peter 340
4198 19 Aug 22 peter 341
4198 19 Aug 22 peter 342   double MultiCox::hazard_ratio(size_t i) const
4198 19 Aug 22 peter 343   {
4198 19 Aug 22 peter 344     return pimpl_->hazard_ratio(i);
4198 19 Aug 22 peter 345   }
4198 19 Aug 22 peter 346
4198 19 Aug 22 peter 347
4198 19 Aug 22 peter 348   double MultiCox::hazard_ratio_lower_CI(size_t i, double alpha) const
4198 19 Aug 22 peter 349   {
4198 19 Aug 22 peter 350     return pimpl_->hazard_ratio_lower_CI(i, alpha);
4198 19 Aug 22 peter 351   }
4198 19 Aug 22 peter 352
4198 19 Aug 22 peter 353
4198 19 Aug 22 peter 354   double MultiCox::hazard_ratio_upper_CI(size_t i, double alpha) const
4198 19 Aug 22 peter 355   {
4198 19 Aug 22 peter 356     return pimpl_->hazard_ratio_upper_CI(i, alpha);
4198 19 Aug 22 peter 357   }
4198 19 Aug 22 peter 358
4198 19 Aug 22 peter 359
4198 19 Aug 22 peter 360   double MultiCox::p(size_t i) const
4198 19 Aug 22 peter 361   {
4198 19 Aug 22 peter 362     return pimpl_->p(i);
4198 19 Aug 22 peter 363   }
4198 19 Aug 22 peter 364
4198 19 Aug 22 peter 365
4198 19 Aug 22 peter 366   void MultiCox::train(void)
4198 19 Aug 22 peter 367   {
4198 19 Aug 22 peter 368     pimpl_->train();
4198 19 Aug 22 peter 369   }
4198 19 Aug 22 peter 370
4198 19 Aug 22 peter 371
4198 19 Aug 22 peter 372   double MultiCox::z(size_t i) const
4198 19 Aug 22 peter 373   {
4198 19 Aug 22 peter 374     return pimpl_->z(i);
4198 19 Aug 22 peter 375   }
4198 19 Aug 22 peter 376
4198 19 Aug 22 peter 377
4198 19 Aug 22 peter 378 }}}