yat/regression/Cox.cc

Code
Comments
Other
Rev Date Author Line
4198 19 Aug 22 peter 1 // $Id$
4198 19 Aug 22 peter 2
4198 19 Aug 22 peter 3 /*
4198 19 Aug 22 peter 4   Copyright (C) 2022 Peter Johansson
4198 19 Aug 22 peter 5
4198 19 Aug 22 peter 6   This file is part of the yat library, https://dev.thep.lu.se/yat
4198 19 Aug 22 peter 7
4198 19 Aug 22 peter 8   The yat library is free software; you can redistribute it and/or
4198 19 Aug 22 peter 9   modify it under the terms of the GNU General Public License as
4198 19 Aug 22 peter 10   published by the Free Software Foundation; either version 3 of the
4198 19 Aug 22 peter 11   License, or (at your option) any later version.
4198 19 Aug 22 peter 12
4198 19 Aug 22 peter 13   The yat library is distributed in the hope that it will be useful,
4198 19 Aug 22 peter 14   but WITHOUT ANY WARRANTY; without even the implied warranty of
4198 19 Aug 22 peter 15   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
4198 19 Aug 22 peter 16   General Public License for more details.
4198 19 Aug 22 peter 17
4198 19 Aug 22 peter 18   You should have received a copy of the GNU General Public License
4198 19 Aug 22 peter 19   along with yat. If not, see <https://www.gnu.org/licenses/>.
4198 19 Aug 22 peter 20 */
4198 19 Aug 22 peter 21
4198 19 Aug 22 peter 22 #include <config.h>
4198 19 Aug 22 peter 23
4198 19 Aug 22 peter 24 #include "Cox.h"
4198 19 Aug 22 peter 25
4198 19 Aug 22 peter 26 #include "detail/Cox.h"
4198 19 Aug 22 peter 27
4252 18 Nov 22 peter 28 #include <yat/utility/Steffenson.h>
4198 19 Aug 22 peter 29 #include <yat/utility/VectorBase.h>
4198 19 Aug 22 peter 30
4198 19 Aug 22 peter 31 #include <gsl/gsl_cdf.h>
4198 19 Aug 22 peter 32
4198 19 Aug 22 peter 33 #include <algorithm>
4198 19 Aug 22 peter 34 #include <memory>
4198 19 Aug 22 peter 35
4198 19 Aug 22 peter 36 namespace theplu {
4198 19 Aug 22 peter 37 namespace yat {
4198 19 Aug 22 peter 38 namespace regression {
4198 19 Aug 22 peter 39
4198 19 Aug 22 peter 40   class Cox::Impl : public cox::Implementation<double>
4198 19 Aug 22 peter 41   {
4198 19 Aug 22 peter 42   public:
4198 19 Aug 22 peter 43     using cox::Implementation<double>::add;
4198 19 Aug 22 peter 44     void add(const yat::utility::VectorBase& x,
4198 19 Aug 22 peter 45              const yat::utility::VectorBase& time,
4198 19 Aug 22 peter 46              const std::vector<char>& event)
4198 19 Aug 22 peter 47     {
4198 19 Aug 22 peter 48       assert(x.size() == time.size());
4198 19 Aug 22 peter 49       assert(x.size() == event.size());
4198 19 Aug 22 peter 50       for (size_t i=0; i<x.size(); ++i)
4198 19 Aug 22 peter 51         add(x(i), time(i), event[i]);
4198 19 Aug 22 peter 52     }
4198 19 Aug 22 peter 53
4198 19 Aug 22 peter 54     double b(void) const { return beta_ ; }
4198 19 Aug 22 peter 55     double hazard_ratio(void) const
4198 19 Aug 22 peter 56     { return exp(beta_); }
4198 19 Aug 22 peter 57
4198 19 Aug 22 peter 58     double hazard_ratio_lower_CI(double alpha) const
4198 19 Aug 22 peter 59     { return exp(beta_ - hazard_ratio_CI(alpha)); }
4198 19 Aug 22 peter 60
4198 19 Aug 22 peter 61     double hazard_ratio_upper_CI(double alpha) const
4198 19 Aug 22 peter 62     { return exp(beta_ + hazard_ratio_CI(alpha)); }
4198 19 Aug 22 peter 63
4198 19 Aug 22 peter 64     double p(void) const
4198 19 Aug 22 peter 65     { return 2 * gsl_cdf_ugaussian_Q(std::abs(z())); }
4198 19 Aug 22 peter 66
4198 19 Aug 22 peter 67     void train(void);
4198 19 Aug 22 peter 68
4198 19 Aug 22 peter 69     double z(void) const { return beta_ / beta_std_error_; }
4198 19 Aug 22 peter 70
4198 19 Aug 22 peter 71   private:
4198 19 Aug 22 peter 72     double hazard_ratio_CI(double alpha) const
4198 19 Aug 22 peter 73     {
4198 19 Aug 22 peter 74       double z = gsl_cdf_ugaussian_Qinv(0.5 * (1.0 - alpha));
4198 19 Aug 22 peter 75       return z * beta_std_error_;
4198 19 Aug 22 peter 76     }
4198 19 Aug 22 peter 77     double beta_;
4198 19 Aug 22 peter 78     double beta_std_error_;
4198 19 Aug 22 peter 79
4252 18 Nov 22 peter 80     class Score
4198 19 Aug 22 peter 81     {
4198 19 Aug 22 peter 82     public:
4252 18 Nov 22 peter 83       Score(const std::vector<TimePoint>& times);
4252 18 Nov 22 peter 84       double operator()(double beta) const;
4252 18 Nov 22 peter 85       double derivative(double beta) const;
4198 19 Aug 22 peter 86     private:
4198 19 Aug 22 peter 87       const std::vector<TimePoint>& times_;
4198 19 Aug 22 peter 88     };
4198 19 Aug 22 peter 89   };
4198 19 Aug 22 peter 90
4198 19 Aug 22 peter 91
4198 19 Aug 22 peter 92   void Cox::Impl::train(void)
4198 19 Aug 22 peter 93   {
4198 19 Aug 22 peter 94     if (data_.empty())
4198 19 Aug 22 peter 95       return;
4198 19 Aug 22 peter 96
4198 19 Aug 22 peter 97     prepare_times();
4252 18 Nov 22 peter 98     // score is derivative of logL
4252 18 Nov 22 peter 99     Score score(times_);
4252 18 Nov 22 peter 100     for (double b=0; b<1.1; b+=0.1)
4252 18 Nov 22 peter 101       score(b);
4252 18 Nov 22 peter 102     utility::Steffenson solver;
4252 18 Nov 22 peter 103     beta_ = solver(score, 0.0,
4252 18 Nov 22 peter 104                    utility::RootFinderDerivative::Delta(0.0, 1e-5));
4198 19 Aug 22 peter 105     if (std::isnan(beta_))
4198 19 Aug 22 peter 106       throw std::runtime_error("beta is NaN");
4252 18 Nov 22 peter 107     // 2nd derivative of logL is 1st derivative of score
4252 18 Nov 22 peter 108     double hessian = score.derivative(beta_);
4252 18 Nov 22 peter 109     beta_std_error_ = 1.0 / std::sqrt(-hessian);
4198 19 Aug 22 peter 110   }
4198 19 Aug 22 peter 111
4198 19 Aug 22 peter 112
4252 18 Nov 22 peter 113   /*
4252 18 Nov 22 peter 114     Without ties the log-likelihood is
4252 18 Nov 22 peter 115     logL = sum_i (x_i * beta - log sum_j x_j * beta)
4252 18 Nov 22 peter 116     where i runs over all events and j runs over all data points j
4252 18 Nov 22 peter 117     such that t_j >= t_i
4198 19 Aug 22 peter 118
4252 18 Nov 22 peter 119     Setting theta = x * beta we have
4252 18 Nov 22 peter 120     logL = sum_i (theta_i - log sum_j theta_j) =
4252 18 Nov 22 peter 121          = sum_i (theta_i - log theta_Q_i) =
4198 19 Aug 22 peter 122
4252 18 Nov 22 peter 123     theta = beta * x -> dtheta/dbeta = x
4198 19 Aug 22 peter 124
4252 18 Nov 22 peter 125     We handle ties using Efron's method. Let denote m_i number of data
4252 18 Nov 22 peter 126     points at t_i, H_i the indices of events at time t_i.
4198 19 Aug 22 peter 127
4252 18 Nov 22 peter 128     logL = sum_i (theta_H_i - sum_k^m_i-1 log(theta_Q_i - k/m_i theta_H_i))
4198 19 Aug 22 peter 129
4252 18 Nov 22 peter 130     where theta_H_i is a sum of theta running over H_i.
4198 19 Aug 22 peter 131
4252 18 Nov 22 peter 132     The derivative (wrt beta)
4252 18 Nov 22 peter 133     l' = sum_i(x_H_i - sum_k^m_i-1 (theta*x_Q_i - k/m_j theta+x_H_i) / (theta_Q_i - k/m_j theta_H_i))
4198 19 Aug 22 peter 134
4252 18 Nov 22 peter 135   */
4252 18 Nov 22 peter 136   Cox::Impl::Score::Score(const std::vector<TimePoint>& times)
4252 18 Nov 22 peter 137     : times_(times)
4252 18 Nov 22 peter 138   {
4252 18 Nov 22 peter 139   }
4198 19 Aug 22 peter 140
4252 18 Nov 22 peter 141
4252 18 Nov 22 peter 142   double Cox::Impl::Score::operator()(double beta) const
4252 18 Nov 22 peter 143   {
4252 18 Nov 22 peter 144     double score = 0;
4252 18 Nov 22 peter 145     // variables with suffix _Q denote sums running over data points
4252 18 Nov 22 peter 146     // (including events and censored data points) at current time and
4252 18 Nov 22 peter 147     // future
4198 19 Aug 22 peter 148     double theta_Q = 0;
4198 19 Aug 22 peter 149     double thetaX_Q = 0;
4252 18 Nov 22 peter 150
4198 19 Aug 22 peter 151     for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
4252 18 Nov 22 peter 152       // variables with suffix _H denote sums running over events at
4252 18 Nov 22 peter 153       // the current time.
4252 18 Nov 22 peter 154       double theta_H = 0;
4252 18 Nov 22 peter 155       double thetaX_H = 0;
4198 19 Aug 22 peter 156       for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4198 19 Aug 22 peter 157         double theta = it->theta(beta);
4252 18 Nov 22 peter 158         theta_H += theta;
4252 18 Nov 22 peter 159         thetaX_H += theta * it->x;
4198 19 Aug 22 peter 160       }
4252 18 Nov 22 peter 161       theta_Q += theta_H;
4252 18 Nov 22 peter 162       thetaX_Q += thetaX_H;
4198 19 Aug 22 peter 163
4198 19 Aug 22 peter 164       for (auto it = time->censored_begin(); it!=time->censored_end(); ++it) {
4198 19 Aug 22 peter 165         double theta = it->theta(beta);
4198 19 Aug 22 peter 166         theta_Q += theta;
4198 19 Aug 22 peter 167         thetaX_Q += theta * it->x;
4198 19 Aug 22 peter 168       }
4198 19 Aug 22 peter 169
4198 19 Aug 22 peter 170       // loop over events at time point t
4198 19 Aug 22 peter 171       for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4252 18 Nov 22 peter 172         score += it->x;
4252 18 Nov 22 peter 173
4198 19 Aug 22 peter 174         const size_t k = it - time->events_begin();
4198 19 Aug 22 peter 175         double r = static_cast<double>(k) / time->size();
4198 19 Aug 22 peter 176
4252 18 Nov 22 peter 177         assert(theta_Q > r * theta_H);
4198 19 Aug 22 peter 178
4252 18 Nov 22 peter 179         score -= (thetaX_Q - r * thetaX_H) / (theta_Q - r * theta_H);
4198 19 Aug 22 peter 180       }
4198 19 Aug 22 peter 181     }
4198 19 Aug 22 peter 182
4252 18 Nov 22 peter 183     assert(!std::isnan(score));
4252 18 Nov 22 peter 184     return score;
4198 19 Aug 22 peter 185   }
4198 19 Aug 22 peter 186
4198 19 Aug 22 peter 187
4252 18 Nov 22 peter 188   double Cox::Impl::Score::derivative(double beta) const
4198 19 Aug 22 peter 189   {
4252 18 Nov 22 peter 190     double deriv = 0;
4252 18 Nov 22 peter 191     // variables with suffix _Q denote sums running over data points
4252 18 Nov 22 peter 192     // (including events and censored data points) at current time and
4252 18 Nov 22 peter 193     // future
4252 18 Nov 22 peter 194     double theta_Q = 0;
4252 18 Nov 22 peter 195     double thetaX_Q = 0;
4252 18 Nov 22 peter 196     double thetaXX_Q = 0;
4252 18 Nov 22 peter 197     double XX_Q = 0;
4198 19 Aug 22 peter 198
4252 18 Nov 22 peter 199     for (auto time = times_.rbegin(); time!=times_.rend(); ++time) {
4252 18 Nov 22 peter 200       // variables with suffix _H denote sums running over events at
4252 18 Nov 22 peter 201       // the current time.
4252 18 Nov 22 peter 202       double theta_H = 0;
4252 18 Nov 22 peter 203       double thetaX_H = 0;
4252 18 Nov 22 peter 204       double thetaXX_H = 0;
4252 18 Nov 22 peter 205       double XX_H = 0;
4252 18 Nov 22 peter 206       for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4198 19 Aug 22 peter 207         double theta = it->theta(beta);
4252 18 Nov 22 peter 208         theta_H += theta;
4252 18 Nov 22 peter 209         thetaX_H += theta * it->x;
4252 18 Nov 22 peter 210         thetaXX_H += theta * it->x * it->x;
4252 18 Nov 22 peter 211         XX_H += it->x * it->x;
4198 19 Aug 22 peter 212       }
4252 18 Nov 22 peter 213       theta_Q += theta_H;
4252 18 Nov 22 peter 214       thetaX_Q += thetaX_H;
4252 18 Nov 22 peter 215       thetaXX_Q += thetaXX_H;
4252 18 Nov 22 peter 216       XX_Q += XX_H;
4198 19 Aug 22 peter 217
4252 18 Nov 22 peter 218       for (auto it = time->censored_begin(); it!=time->censored_end(); ++it) {
4198 19 Aug 22 peter 219         double theta = it->theta(beta);
4252 18 Nov 22 peter 220         theta_Q += theta;
4252 18 Nov 22 peter 221         thetaX_Q += theta * it->x;
4252 18 Nov 22 peter 222         thetaXX_Q += theta * it->x * it->x;
4252 18 Nov 22 peter 223         XX_Q += it->x * it->x;
4198 19 Aug 22 peter 224       }
4198 19 Aug 22 peter 225
4252 18 Nov 22 peter 226       // f = g/h
4252 18 Nov 22 peter 227       // g = - (thetaX_Q - r*thetaX_H)
4252 18 Nov 22 peter 228       // g'= - (thetaXX_Q - r*thetaXX_H)
4252 18 Nov 22 peter 229       //
4252 18 Nov 22 peter 230       // h = theta_Q - r*theta_H
4252 18 Nov 22 peter 231       // h'= thetaX_Q - r*thetaX_H =
4252 18 Nov 22 peter 232
4252 18 Nov 22 peter 233       // f' = (g/h)' = (g'h - gh')/h^2 =
4252 18 Nov 22 peter 234
4198 19 Aug 22 peter 235       // loop over events at time point t
4252 18 Nov 22 peter 236       for (auto it = time->events_begin(); it!=time->events_end(); ++it) {
4252 18 Nov 22 peter 237         const size_t k = it - time->events_begin();
4252 18 Nov 22 peter 238         double r = static_cast<double>(k) / time->size();
4252 18 Nov 22 peter 239         double g = - (thetaX_Q - r*thetaX_H);
4252 18 Nov 22 peter 240         double dg = - (thetaXX_Q - r * thetaXX_H);
4198 19 Aug 22 peter 241
4252 18 Nov 22 peter 242         double h = (theta_Q - r*theta_H);
4252 18 Nov 22 peter 243         double dh= (thetaX_Q - r*thetaX_H);
4252 18 Nov 22 peter 244         deriv += (dg*h - g*dh) / (h*h);
4198 19 Aug 22 peter 245       }
4252 18 Nov 22 peter 246     }
4198 19 Aug 22 peter 247
4252 18 Nov 22 peter 248     assert(!std::isnan(deriv));
4252 18 Nov 22 peter 249     return deriv;
4198 19 Aug 22 peter 250   }
4198 19 Aug 22 peter 251
4198 19 Aug 22 peter 252   // class Cox
4198 19 Aug 22 peter 253
4198 19 Aug 22 peter 254   Cox::Cox(void)
4198 19 Aug 22 peter 255     : pimpl_(new Impl)
4198 19 Aug 22 peter 256   {
4198 19 Aug 22 peter 257   }
4198 19 Aug 22 peter 258
4198 19 Aug 22 peter 259
4198 19 Aug 22 peter 260   Cox::Cox(const Cox& other)
4198 19 Aug 22 peter 261     : pimpl_(new Impl(*other.pimpl_))
4198 19 Aug 22 peter 262   {
4198 19 Aug 22 peter 263   }
4198 19 Aug 22 peter 264
4198 19 Aug 22 peter 265
4198 19 Aug 22 peter 266   Cox::Cox(Cox&& other)
4198 19 Aug 22 peter 267   {
4198 19 Aug 22 peter 268     std::swap(pimpl_, other.pimpl_);
4198 19 Aug 22 peter 269   }
4198 19 Aug 22 peter 270
4198 19 Aug 22 peter 271
4198 19 Aug 22 peter 272   Cox::~Cox(void)
4198 19 Aug 22 peter 273   {
4198 19 Aug 22 peter 274   }
4198 19 Aug 22 peter 275
4198 19 Aug 22 peter 276
4198 19 Aug 22 peter 277   Cox& Cox::operator=(const Cox& other)
4198 19 Aug 22 peter 278   {
4198 19 Aug 22 peter 279     assert(other.pimpl_);
4198 19 Aug 22 peter 280     pimpl_.reset(new Impl(*other.pimpl_));
4198 19 Aug 22 peter 281     return *this;
4198 19 Aug 22 peter 282   }
4198 19 Aug 22 peter 283
4198 19 Aug 22 peter 284
4198 19 Aug 22 peter 285   Cox& Cox::operator=(Cox&& other)
4198 19 Aug 22 peter 286   {
4198 19 Aug 22 peter 287     std::swap(pimpl_, other.pimpl_);
4198 19 Aug 22 peter 288     return *this;
4198 19 Aug 22 peter 289   }
4198 19 Aug 22 peter 290
4198 19 Aug 22 peter 291
4198 19 Aug 22 peter 292   void Cox::Cox::add(double x, double time, bool event)
4198 19 Aug 22 peter 293   {
4198 19 Aug 22 peter 294     pimpl_->add(x, time, event);
4198 19 Aug 22 peter 295   }
4198 19 Aug 22 peter 296
4198 19 Aug 22 peter 297
4198 19 Aug 22 peter 298   void Cox::add(const yat::utility::VectorBase& x,
4198 19 Aug 22 peter 299                 const yat::utility::VectorBase& time,
4198 19 Aug 22 peter 300                 const std::vector<char>& event)
4198 19 Aug 22 peter 301   {
4198 19 Aug 22 peter 302     pimpl_->add(x, time, event);
4198 19 Aug 22 peter 303   }
4198 19 Aug 22 peter 304
4198 19 Aug 22 peter 305
4198 19 Aug 22 peter 306   double Cox::b(void) const
4198 19 Aug 22 peter 307   {
4198 19 Aug 22 peter 308     return pimpl_->b();
4198 19 Aug 22 peter 309   }
4198 19 Aug 22 peter 310
4198 19 Aug 22 peter 311
4198 19 Aug 22 peter 312   void Cox::clear(void)
4198 19 Aug 22 peter 313   {
4198 19 Aug 22 peter 314     pimpl_->clear();
4198 19 Aug 22 peter 315   }
4198 19 Aug 22 peter 316
4198 19 Aug 22 peter 317
4198 19 Aug 22 peter 318   double Cox::hazard_ratio(void) const
4198 19 Aug 22 peter 319   {
4198 19 Aug 22 peter 320     return pimpl_->hazard_ratio();
4198 19 Aug 22 peter 321   }
4198 19 Aug 22 peter 322
4198 19 Aug 22 peter 323
4198 19 Aug 22 peter 324   double Cox::hazard_ratio_lower_CI(double alpha) const
4198 19 Aug 22 peter 325   {
4198 19 Aug 22 peter 326     return pimpl_->hazard_ratio_lower_CI(alpha);
4198 19 Aug 22 peter 327   }
4198 19 Aug 22 peter 328
4198 19 Aug 22 peter 329
4198 19 Aug 22 peter 330   double Cox::hazard_ratio_upper_CI(double alpha) const
4198 19 Aug 22 peter 331   {
4198 19 Aug 22 peter 332     return pimpl_->hazard_ratio_upper_CI(alpha);
4198 19 Aug 22 peter 333   }
4198 19 Aug 22 peter 334
4198 19 Aug 22 peter 335
4198 19 Aug 22 peter 336   double Cox::p(void) const
4198 19 Aug 22 peter 337   {
4198 19 Aug 22 peter 338     return pimpl_->p();
4198 19 Aug 22 peter 339   }
4198 19 Aug 22 peter 340
4198 19 Aug 22 peter 341
4198 19 Aug 22 peter 342   void Cox::train(void)
4198 19 Aug 22 peter 343   {
4198 19 Aug 22 peter 344     pimpl_->train();
4198 19 Aug 22 peter 345   }
4198 19 Aug 22 peter 346
4198 19 Aug 22 peter 347
4198 19 Aug 22 peter 348   double Cox::z(void) const
4198 19 Aug 22 peter 349   {
4198 19 Aug 22 peter 350     return pimpl_->z();
4198 19 Aug 22 peter 351   }
4198 19 Aug 22 peter 352
4198 19 Aug 22 peter 353 }}}