Hello together,
this is my first post here since I found nothing else regarding this problem. Maybe some of you know how to fix my problem.
I wrote a piecewise constant exponential (piecewise constant hazards) survival model for STAN given in the model code below. Basically I recycled the likelihood function given in Bayesian Survival Analysis (Ibrahim, 2001), which is
L(\beta,\lambda|D)=\prod_{i=1}^n \prod_{j=1}^J \left ( \lambda_j \exp(x_i' \beta) \right )^{\delta_{ij} \nu_i} \exp \left ( -\delta_{ij}[\lambda_j (y_i-s_{j-1})+\sum_{g=1}^{j-1} \lambda_g (s_g-s_{g-1})] \exp(x_i' \beta) \right )
where \delta_{ij}=1 if the i^{th} subject failed or was censored in the j^{th} interval and 0 otherwise. The derivation of the likelihood can be found in Ibrahim (2001). If J=1 the model reduces to a parametric exponential model with \lambda \equiv \lambda_1.
My model basically iterates over the observations i=1,...,n and intervals j=1,...,J and given a status vector \nu=(\nu_1,...,\nu_n), a matrix delta which contains the \delta_{ij} for i=1,...,n, j=1,...,J, the vector s=(s_1,...,s_J) creating the intervals, the design matrix X where x_i=(x_1,...,x_p) is a row belonging to an observation, the vector of baseline hazards \lambda=(\lambda_1,...,\lambda_J) and the vector of regression coefficients \beta=(\beta_1,...,\beta_p) calculates the likelihood. More specifically, my likelihood function pch_log
in the functions
block calculates the log-likelihood.
I put a normal prior \mathcal{N}(0,1) on \beta and a Gamma prior \mathcal{G}(1,1) on the piecewise constant baseline hazards \lambda.
This works quite well if
- The number or predictors (x_1,...,x_p) stays small
- The number of intervals J-1 stays small
I have included a reproducible example using the retinopathy
dataset from the survival
package which includes survival data from a trial of laser coagulation as a treatment to delay diabetic retinopathy. The events are loss of sight here and the appended code 1 uses just the treatment (laser vs no treatment) as a predictor for 2 intervals (0,50] and (50,80] with different constant baseline hazards \lambda_1,\lambda_2.
The appended Code 2 then uses 4 intervals and 4 predictors, and STAN produces posteriors including huge R-hat values and effective sample sizes below 10 for all parameters.
My question now is: Has anyone an idea how to improve the model so that STAN does not encounter such difficulties in exploring the posterior distribution? I already centered the non-categorial predictors and tried initial prior values, but I achieved no improvement.
I have no clue why STAN has such difficulties to explore the posterior distribution. Changing the priors on \lambda to uniform priors has not helped much.
Thanks for any help in advance,
Riko
functions{
real pch_log(vector y, vector nu, matrix delta, matrix X, vector lambda, vector s, vector beta){
vector[num_elements(y)] log_prob;
real logprob;
real hazSum;
for(i in 1:num_elements(y)){
for(j in 1:(num_elements(s)-1)){
if(delta[i,j]==1){ // failure or censored event in interval j for observation i
if(nu[i]==1){ // failure occured in interval j for observation i
hazSum=0;
for(g in 2:j){ // is zero when j=1
hazSum += lambda[g-1]*(s[g]-s[g-1]);
}
log_prob[i]=log(lambda[j])+(X[i,]*beta)+(-(lambda[j]*(y[i]-s[j])+hazSum)*exp(X[i,]*beta));
if(log_prob[i]<-100000){ // if likelihood is very close to zero (log-likelihood would become -infty) we set it to -100000
log_prob[i]=-100000;
}
}
if(nu[i]==0){ // censoring occured in interval j for observation i
hazSum=0;
for(g in 2:j){ // is zero when j=1
hazSum += lambda[g-1]*(s[g]-s[g-1]);
}
log_prob[i]=-(lambda[j]*(y[i]-s[j])+hazSum)*exp(X[i,]*beta);
if(log_prob[i]<-100000){ // if likelihood is very close to zero (log-likelihood would become -infty) we set it to -100000
log_prob[i]=-100000;
}
}
}
}
}
logprob = sum(log_prob); // compute sum of log-likelihoods and return
return logprob;
}
}
data{
int<lower=1> N; // number of patients
vector<lower=0>[N] y; // observed times
vector<lower=0,upper=1>[N] nu; // censoring status 1 or 0
int<lower=0> numCovariates; // number of covariates (predictors)
matrix[N,numCovariates] X; // design matrix for observations
int<lower=1> J; // number of timepoints building the intervals
vector<lower=0>[J] s; // vector with interval endpoints
matrix[N,J-1] delta; // matrix with indicators for event or failure in interval j=1,...,J for patients
}
parameters{
vector[numCovariates] beta_raw; // regression coefficients
vector<lower=0>[J-1] lambda; // piecewise constant baseline hazards
}
transformed parameters{
vector[numCovariates] beta;
beta = 1 * beta_raw;
}
model{
lambda ~ gamma(1,1); // prior for piecewise constant baseline hazards
beta_raw ~ std_normal(); // raw prior for regression coefficients
y ~ pch(nu, delta, X, lambda, s, beta);
}
Code 1: Survival regression for predictor treatment using two intervals
require(survival)
require(tidyverse)
data("retinopathy")
head(retinopathy)
retinopathy$age=scale(retinopathy$age)
# Data preparation
y = retinopathy$futime # censoring or failure time
N=length(y)
nu = retinopathy$status # (0 = censored, 1 = visual loss)
X =as.matrix(retinopathy %>% select(trt)) # predictor matrix
# trt: 0 = no treatment, 1 = laser treatment
s = c(0,50,80) # split points for intervals
J=length(s);
interval1=retinopathy %>% filter(retinopathy$futime <= 80) %>% mutate(interval=1)
interval2=retinopathy %>% filter(retinopathy$futime > 40) %>% mutate(interval=2)
retinopathyNew=rbind(interval1,interval2)
retinopathyNew
j=retinopathyNew$interval
# Create matrix with delta_ij's
delta=matrix(nrow=length(y),ncol=J-1,0)
head(delta)
for(i in 1:length(y)){
for(k in 1:J-1){
if(j[i]==k){
delta[i,k]=1
}
}
}
# Put data in list for STAN
stan_data <- list(N=N,
y=y,
nu=nu,
numCovariates=1,
X=X,
J=J,
s=s,
delta=delta
)
# Run STAN
set.seed(12);
require(rstan)
options(mc.cores = parallel::detectCores()) # use multiple cores if available, speeds up simulations
rstan_options(auto_write = TRUE) # avoids recompiling unchanged models, speed up simulations
pch_surv_model_fit__rethinopathy_Treatment <- stan(model_code = pch_model, data=stan_data, control=list(adapt_delta=0.999, max_treedepth=10), chains=2, iter=5000, init=list(list(beta=c(0,0),lambda=c(1,1)),list(beta=c(0,0),lambda=c(1,1))))
I end up with a posterior mean of -0.8 for \beta_1 and means of 0.02 and 0.98 for \lambda_1,\lambda_2. The resulting R-hat is 1 for all parameters.
Code 2: Survival regression for four predictors treatment, age, type and laser using four intervals
require(survival)
require(tidyverse)
data("retinopathy")
head(retinopathy)
retinopathy$age=scale(retinopathy$age)
# Data preparation
y = retinopathy$futime # censoring or failure time
N=length(y)
nu = retinopathy$status # (0 = censored, 1 = visual loss)
X =as.matrix(retinopathy %>% select(trt,age,laser,type) %>% mutate(laser=as.numeric(retinopathy$laser)-1,type=as.numeric(retinopathy$type)-1)) # predictor matrix
# trt: 0 = no treatment, 1 = laser treatment
# age: age scaled to mean zero and sd one (centered)
max(retinopathy$futime)
hist(retinopathy$futime)
s = c(0,20,40,60,80)
J=length(s);
interval1=retinopathy %>% filter(retinopathy$futime <= 20) %>% mutate(interval=1)
interval2=retinopathy %>% filter(retinopathy$futime > 20) %>% filter(futime <= 40) %>% mutate(interval=2)
interval3=retinopathy %>% filter(retinopathy$futime > 40) %>% filter(futime <= 60) %>% mutate(interval=3)
interval4=retinopathy %>% filter(retinopathy$futime > 60) %>% filter(futime <= 80) %>% mutate(interval=4)
retinopathyNew=rbind(interval1,interval2,interval3,interval4)
retinopathyNew
j=retinopathyNew$interval
# Create matrix with delta_{ij}'s
delta=matrix(nrow=length(y),ncol=J-1,0)
head(delta)
for(i in 1:length(y)){
for(k in 1:J-1){
if(j[i]==k){
delta[i,k]=1
}
}
}
# Put data in list for STAN
stan_data <- list(N=N,
y=y,
nu=nu,
numCovariates=4,
X=X,
J=J,
s=s,
delta=delta
)
# Run STAN
set.seed(12);
require(rstan)
options(mc.cores = parallel::detectCores()) # use multiple cores if available, speeds up simulations
rstan_options(auto_write = TRUE) # avoids recompiling unchanged models, speed up simulations
pch_surv_model_fit_rethinopathy_TreatmentAgeLaserType <- stan(model_code = pch_model, data=stan_data, control=list(adapt_delta=0.8, max_treedepth=15), chains=2, iter=50000, init=list(list(beta=c(0,0,0,0),lambda=c(1,1,1,1)),list(beta=c(0,0,0,0),lambda=c(1,1,1,1)),list(beta=c(0,0,0,0),lambda=c(1,1,1,1)),list(beta=c(0,0,0,0),lambda=c(1,1,1,1))))
In this case, I end up with huge R-hat values, 1-digit effective sample size, and tweaking the hyper parameters of the priors has not helped at all.