Hi Stan community,
I am new to Stan and need to fit a hierarchical model to a small simulated data set via rstan, but sampling is very slow: it takes more than 12 hours even for 2000 iterations with 500 warmup iterations. I am not an expert in computing, so any tips would be greatly appreciated!
The model has a few notable features: (1) the responses are two-dimensional; (2) they follow a special distribution, so I have to write the likelihood function myself; (3) it has many parameters (overfitting is not a concern for us).
To speed up sampling, I have already tried to improve the Stan model in the following ways:
1. Use vectorized operations and avoid loops where possible.
2. Remove any parts that are not needed, for example the generated quantities block.
3. Run the chains in parallel, by simply setting cores=4 in the stan(…) call in RStudio.
Besides these, I am also considering other possible strategies:
4. Adjusting adapt_delta and max_treedepth, although I think this will affect the quality of the posterior samples. Currently I use adapt_delta=0.95 and max_treedepth=15.
5. Writing the Stan code in a more efficient way, for example by making use of the transformed parameters block. However, I do not know how to do this so that it actually speeds things up.
6. Adjusting the priors, but I don't have any good ideas here.
7. Adjusting the initial values. To my knowledge, they may affect the quality of the posterior samples, but not the sampling time very much.
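For items 5 and 6, one commonly recommended option (not part of the model in this post) is a non-centered parameterization of the group-level effects, written with the transformed parameters block, combined with a weakly informative half-normal prior on the group-level standard deviation instead of the implicit uniform(0, 100). Hierarchical effects like `ar ~ normal(0, sd_ar)` often form a "funnel" that forces small step sizes and deep trees, and the non-centered form usually removes it. A minimal sketch for one pair (`ar`, `sd_ar`), assuming a hypothetical new parameter `ar_raw`; the prior scale 5 is a placeholder to adapt to the scale of the data:

```stan
data {
  int<lower=1> N_Rj;                 // number of groups, as in the original model
}
parameters {
  real<lower=0> sd_ar;               // group-level SD (the upper bound of 100 is dropped)
  vector[N_Rj] ar_raw;               // hypothetical standardized group effects
}
transformed parameters {
  // Non-centered: ar still has a normal(0, sd_ar) distribution, but the
  // sampler moves in ar_raw, which is a priori independent of sd_ar.
  vector[N_Rj] ar = sd_ar * ar_raw;
}
model {
  ar_raw ~ std_normal();             // implies ar ~ normal(0, sd_ar)
  sd_ar ~ normal(0, 5);              // half-normal because of <lower=0>; scale 5 is a placeholder
}
```

The same pattern would apply to `br`, `ar2`, …, `br6`; the linear predictors in the model block can keep using `ar[Rj]` unchanged. Stan 2.19 and later also accept `vector<multiplier=sd_ar>[N_Rj] ar;` with `ar ~ normal(0, sd_ar);`, which gives the same non-centered geometry without a separate `ar_raw`.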
My Stan code is below:
functions{
real two_d_st_lpdf(matrix XY, vector xi1, vector xi2, vector sigma1, vector sigma2, vector rho12, vector alpha1, vector alpha2){
vector [cols(XY)] A;
vector [cols(XY)] B;
vector [cols(XY)] C1;
vector [cols(XY)] C2;
vector [cols(XY)] singlelogprob;
real lprob;
A=sigma1 .* sigma1 .* sigma2 .* sigma2 - rho12 .* rho12 .* sigma1 .* sigma1 .* sigma2 .* sigma2;
B=(to_vector(XY[1,])- xi1 ).*(to_vector(XY[1,])- xi1) .* sigma2 .* sigma2-
(to_vector(XY[1,])- xi1 ).*(to_vector(XY[2,])- xi2) .* rho12 .* sigma1 .* sigma2+
(to_vector(XY[2,])- xi2 ).*(to_vector(XY[2,])- xi2) .* sigma1 .* sigma1-
(to_vector(XY[1,])- xi1 ).*(to_vector(XY[2,])- xi2) .* rho12 .* sigma1 .* sigma2;
C1=alpha1 .* ( to_vector(XY[1,])-xi1 ) ./ sigma1+alpha2 .* ( to_vector(XY[2,])-xi2 ) ./ sigma2;
C2=(to_vector(XY[1,])-xi1).*(to_vector(XY[1,])-xi1) ./ sigma1+(to_vector(XY[2,])-xi2).*(to_vector(XY[2,])-xi2) ./ sigma2;
for(i in 1:cols(XY)){
if(sigma1[i]<0){singlelogprob[i]=-1000000; }
else{
if(sigma2[i]<0){singlelogprob[i]=-1000000;}
else{
singlelogprob[i]=log(23)-0.5*log(pi()*23)-0.5*log(A[i])-12.5*log( 1+(B[i]/A[i])/23 )+student_t_lcdf(C1[i]*sqrt(25/(C2[i]+2))|25, 0, 1);
}
}
}
lprob = sum(singlelogprob);
return lprob;
}
}
data{
int<lower=1> N;
int<lower=1> N_Rj;
int<lower=1> N_j;
matrix[2, N] BH; //responses
int Rj[N]; //index for groups
vector[N] Is;
vector[N] Ic;
vector[N] mid_year_corrected;
}
parameters{
real ag;
real<lower=0,upper=100> sd_ar;
real As;
real Ac;
real bg;
real<lower=0,upper=100> sd_br;
real Bs;
real Bc;
real<lower=0,upper=300> ag2;///////////////////////////////////////////
real<lower=0,upper=100> sd_ar2;
real As2;
real Ac2;
real bg2;
real<lower=0,upper=100> sd_br2;
real Bs2;
real Bc2;
real<lower=0,upper=300> ag3;///////////////////////////////////////////
real<lower=0,upper=100> sd_ar3;
real As3;
real Ac3;
real bg3;
real<lower=0,upper=100> sd_br3;
real Bs3;
real Bc3;
real<lower=0,upper=300> ag4;///////////////////////////////////////////
real<lower=0,upper=100> sd_ar4;
real As4;
real Ac4;
real bg4;
real<lower=0,upper=100> sd_br4;
real Bs4;
real Bc4;
real<lower=0,upper=300> ag5;///////////////////////////////////////////
real<lower=0,upper=100> sd_ar5;
real<lower=0,upper=100> sd_ac5;
real As5;
real Ac5;
real bg5;
real<lower=0,upper=100> sd_br5;
real<lower=0,upper=100> sd_bc5;
real Bs5;
real Bc5;
real<lower=0,upper=300> ag6;///////////////////////////////////////////
real<lower=0,upper=100> sd_ar6;
real As6;
real Ac6;
real bg6;
real<lower=0,upper=100> sd_br6;
real Bs6;
real Bc6;
vector[N_Rj] ar;
vector[N_Rj] br;
vector[N_Rj] ar2;//////
vector[N_Rj] br2;
vector[N_Rj] ar3;//////
vector[N_Rj] br3;
vector[N_Rj] ar4;//////
vector[N_Rj] br4;
vector[N_Rj] ar5;//////
vector[N_Rj] br5;
vector[N_Rj] ar6;//////
vector[N_Rj] br6;
real <lower=-1,upper=1> rho12parameter;////////////////
}
model{
vector[N] sigma1;
vector[N] xi1;
vector[N] sigma2;
vector[N] xi2;
vector[N] rho12;
vector[N] alpha1;
vector[N] alpha2;
rho12parameter ~ uniform(-1,1);
Bc ~ normal( 0 , 10 );
Bs ~ normal( 0 , 10 );
br ~ normal( 0 , sd_br );
bg ~ normal( 0 , 10 );
Ac ~ normal( 0 , 10 );
As ~ normal( 0 , 10 );
ar ~ normal( 0 , sd_ar );
ag ~ normal( 0 , 100 );
Bc2 ~ normal( 0 , 10 );
Bs2 ~ normal( 0 , 10 );
br2 ~ normal( 0 , sd_br2 );
bg2 ~ normal( 0 , 10 );
Ac2 ~ normal( 0 , 10 );
As2 ~ normal( 0 , 10 );
ar2 ~ normal( 0 , sd_ar2 );
ag2 ~ uniform( 0 , 300 );//
Bc3 ~ normal( 0 , 10 );
Bs3 ~ normal( 0 , 10 );
br3 ~ normal( 0 , sd_br3 );
bg3 ~ normal( 0 , 10 );
Ac3 ~ normal( 0 , 10 );
As3 ~ normal( 0 , 10 );
ar3 ~ normal( 0 , sd_ar3 );
ag3 ~ uniform( 0 , 300 );
Bc4 ~ normal( 0 , 10 );
Bs4 ~ normal( 0 , 10 );
br4 ~ normal( 0 , sd_br4 );
bg4 ~ normal( 0 , 10 );
Ac4 ~ normal( 0 , 10 );
As4 ~ normal( 0 , 10 );
ar4 ~ normal( 0 , sd_ar4 );
ag4 ~ uniform( 0 , 300 );//
Bc5 ~ normal( 0 , 10 );
Bs5 ~ normal( 0 , 10 );
br5 ~ normal( 0 , sd_br5 );
bg5 ~ normal( 0 , 10 );
Ac5 ~ normal( 0 , 10 );
As5 ~ normal( 0 , 10 );
ar5 ~ normal( 0 , sd_ar5 );
ag5 ~ uniform( 0 , 300 );
Bc6 ~ normal( 0 , 10 );
Bs6 ~ normal( 0 , 10 );
br6 ~ normal( 0 , sd_br6 );
bg6 ~ normal( 0 , 10 );
Ac6 ~ normal( 0 , 10 );
As6 ~ normal( 0 , 10 );
ar6 ~ normal( 0 , sd_ar6 );
ag6 ~ uniform( 0 , 300 );
xi1 = ag + ar[Rj] + As * Is + Ac * Ic + (bg + br[Rj] +Bs * Is + Bc * Ic).* mid_year_corrected ;
xi2 = ag2 + ar2[Rj] + As2 * Is + Ac2 * Ic + (bg2 + br2[Rj] +Bs2 * Is + Bc2 * Ic).* mid_year_corrected ;
sigma1 = ag3 + ar3[Rj] + As3 * Is + Ac3 * Ic + (bg3 + br3[Rj] +Bs3 * Is + Bc3 * Ic).* mid_year_corrected ;
sigma2 = ag4 + ar4[Rj] + As4 * Is + Ac4 * Ic + (bg4 + br4[Rj] +Bs4 * Is + Bc4 * Ic).* mid_year_corrected ;
rho12 = rep_vector(rho12parameter, N);
alpha1 = ag5 + ar5[Rj] + As5 * Is + Ac5 * Ic + (bg5 + br5[Rj] +Bs5 * Is + Bc5 * Ic).* mid_year_corrected ;
alpha2 = ag6 + ar6[Rj] + As6 * Is + Ac6 * Ic + (bg6 + br6[Rj] +Bs6 * Is + Bc6 * Ic).* mid_year_corrected ;
BH~ two_d_st(xi1,xi2,sigma1,sigma2,rho12,alpha1,alpha2);
}
In the functions block, A, B, C1 and C2 are the four components of the special likelihood function. I have tried to compute them without loops, but I think I still need a loop in the user-defined function.
I also use if-else because some parameters (sigma1 and sigma2) must be strictly positive for the density to be well defined, so I set the log-likelihood to a very small value (-1000000) when they are negative.
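Because `student_t_lcdf` accepts a vector argument and returns the sum of the log CDFs, the loop can in principle be removed entirely. The sketch below keeps the same per-observation formula as `two_d_st_lpdf` (constants included), computes each coordinate difference once, and uses the algebraic simplifications A = σ1²σ2²(1 − ρ²) and B = d1²σ2² − 2ρ·d1·d2·σ1σ2 + d2²σ1², which match the original expressions term by term. Assumptions not in the original: the name `two_d_st_vec_lpdf`, passing `rho12` as a single `real`, and returning `negative_infinity()` (rejecting the proposal) when any scale is non-positive instead of adding -1000000.

```stan
functions {
  // Loop-free variant of two_d_st_lpdf (hypothetical name).
  // Same per-observation density as the original, summed over columns of XY.
  real two_d_st_vec_lpdf(matrix XY, vector xi1, vector xi2,
                         vector sigma1, vector sigma2, real rho12,
                         vector alpha1, vector alpha2) {
    int n = cols(XY);
    vector[n] d1 = XY[1]' - xi1;     // centered first response, computed once
    vector[n] d2 = XY[2]' - xi2;     // centered second response, computed once
    vector[n] A;
    vector[n] B;
    vector[n] C1;
    vector[n] C2;
    // Assumption: reject outright if any scale is non-positive.
    if (min(sigma1) <= 0 || min(sigma2) <= 0)
      return negative_infinity();
    A = square(sigma1 .* sigma2) * (1 - square(rho12));
    B = square(d1 .* sigma2) - 2 * rho12 * d1 .* d2 .* sigma1 .* sigma2
        + square(d2 .* sigma1);
    C1 = alpha1 .* d1 ./ sigma1 + alpha2 .* d2 ./ sigma2;
    C2 = square(d1) ./ sigma1 + square(d2) ./ sigma2;
    // With a vector first argument, student_t_lcdf returns the summed log CDF.
    return n * (log(23) - 0.5 * log(pi() * 23))
           - 0.5 * sum(log(A))
           - 12.5 * sum(log1p(B ./ A / 23))
           + student_t_lcdf(C1 .* sqrt(25.0 ./ (C2 + 2)) | 25, 0, 1);
  }
}
```

In the model block the call would become `BH ~ two_d_st_vec(xi1, xi2, sigma1, sigma2, rho12parameter, alpha1, alpha2);`, and the `rho12` vector would no longer be needed.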
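A common alternative to penalizing negative scales, sketched here as an assumption because it changes the model, is to model sigma1 and sigma2 on the log scale so they are positive by construction. A hard jump to -1000000 makes the log density discontinuous, which gradient-based HMC handles poorly; with a log link that branch is never reached. The helper name `positive_scale` is hypothetical; it reproduces the linear-predictor pattern used for `sigma1` and `sigma2`, and receives the group effects already indexed (e.g. `ar3[Rj]`). Under this link, `ag3` and `ag4` become log-scale intercepts, so their `<lower=0,upper=300>` bounds and uniform(0, 300) priors would need rethinking.

```stan
functions {
  // Hypothetical helper: log-link version of the linear predictor used for
  // sigma1 and sigma2. exp() makes every element strictly positive, so the
  // likelihood no longer needs the if-else / -1000000 branch for these scales.
  vector positive_scale(real a, vector a_grp, real a_s, real a_c,
                        real b, vector b_grp, real b_s, real b_c,
                        vector Is, vector Ic, vector t) {
    return exp(a + a_grp + a_s * Is + a_c * Ic
               + (b + b_grp + b_s * Is + b_c * Ic) .* t);
  }
}
```

In the model block this would be used as `sigma1 = positive_scale(ag3, ar3[Rj], As3, Ac3, bg3, br3[Rj], Bs3, Bc3, Is, Ic, mid_year_corrected);`, and likewise for `sigma2` with the `...4` parameters.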
Any tips and comments are welcome. Thanks in advance!