Hi everyone, I am trying to implement a survival model in rstan. The code below is a slightly modified version of the leuk example. The problem is that the program takes more than 8 hours to run with 4 chains of 1000 iterations each. The data set is relatively big: 4000 observations and 7 predictors. Is there any way to optimize the code and make it faster?
data {
  int<lower=0> N;          // number of observations
  int<lower=0> NT;         // number of death times
  int<lower=0> obs_t[N];   // survival months of each patient
  int<lower=0> S[NT + 1];  // unique death times
  int<lower=0> fail[N];    // dead or alive
  int K;
  matrix[N, K] z;
  real mu[N];              // Lnr
}
transformed data {
  int Y[N, NT];
  int dN[N, NT];
  real c;
  real r;
  for (i in 1:N) {
    for (j in 1:NT) {
      Y[i, j] = int_step(obs_t[i] - S[j] + .000000001);
      dN[i, j] = Y[i, j] * fail[i] * int_step(S[j + 1] - obs_t[i] - .000000001);
    }
  }
  c = 1.5;
  r = 0.1;
}
parameters {
  real betaLnr;
  vector[K] beta;
  real<lower=0> dL0[NT];
}
model {
  betaLnr ~ normal(0, 1000);
  beta ~ normal(0, 1000);
  for (j in 1:NT) {
    dL0[j] ~ gamma(r * (S[j + 1] - S[j]) * c, c);
    for (i in 1:N) {
      if (Y[i, j] != 0)
        dN[i, j] ~ poisson(Y[i, j] * exp(mu[i] * betaLnr + dot_product(z[i], beta)));
    }
  }
}
I reduced the run time by setting the “cores” argument in the stan() function. I did not know about it
until I saw it in your code. With cores=6 the run was faster and took around 1.5 hours.
Don’t worry about efficiency in transformed data. But the rest of the model is tricky to make more efficient.
You can precompute a vector S_diff_times_c in transformed data, with S_diff_times_c[j] = (S[j + 1] - S[j]) * c, to use in the dL0 distribution in the model block. Then vectorize by dropping the loop and the index:
dL0 ~ gamma(r * S_diff_times_c, c);
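Here is a minimal sketch of that precomputation, showing only the pieces that touch dL0. It keeps the array declaration syntax used in the original post and assumes S_diff_times_c is declared as a vector so that it can be scaled by r. The rest of the model (the likelihood for dN) is left out on purpose.

```stan
data {
  int<lower=0> NT;          // number of death times
  int<lower=0> S[NT + 1];   // unique death times
}
transformed data {
  real c;
  real r;
  vector[NT] S_diff_times_c;  // (S[j + 1] - S[j]) * c, computed once
  c = 1.5;
  r = 0.1;
  for (j in 1:NT) {
    S_diff_times_c[j] = (S[j + 1] - S[j]) * c;
  }
}
parameters {
  real<lower=0> dL0[NT];
}
model {
  // One vectorized statement replaces the per-j sampling statements.
  dL0 ~ gamma(r * S_diff_times_c, c);
}
```

Since r is also a constant, it could be folded into the precomputed vector as well. Either way, this is probably a small saving compared with the N-by-NT loop over dN.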
Then you need to vectorize the dN, but there’s the nasty Y[i, j] > 0 condition, which seems to imply the data with Y[i, j] == 0 is not being modeled. Is that the intention?
To vectorize, you need to replace Y with a precomputed array of values where Y[i, j] > 0, but that will turn out to be ragged, which is a headache with our current rectangular data structures. Then vectorize by row. And use the poisson_log distribution, which takes a parameter on the log scale, e.g.,