Requesting suggestions for a faster implementation of a model involving a mixture distribution

Our implementation of the model has two issues: (1) the default initialization does not work (the log probability evaluates to log(0)), and (2) even with a small sample size (50 players), it takes 2 days to complete 10,000 MCMC iterations for a single chain.

Let me describe the model first. The data come from a game. For each player we have the number of sessions played per day (counted as 0 on days the player did not log in), and the duration and starting time of each session. I have stacked the numbers of sessions, the durations and the starting times into three long vectors, each accompanied by a partition vector that marks where one player's data end and the next player's begin. We haven't used any covariates in this implementation.
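
To make the layout concrete, here is a toy example of the partition convention (made-up numbers; the real vectors are built the same way):

# toy example: player 1 has 3 daily records, player 2 has 2, player 3 has 4
num_session_vec       <- c(2, 0, 1,  0, 3,  1, 1, 0, 2)
partition_num_session <- c(0, 3, 5, 9)  # length N + 1, starting at 0
# player n's records are num_session_vec[(partition_num_session[n] + 1):partition_num_session[n + 1]]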

The number of sessions for the i-th player is modeled as a zero-inflated Poisson:

n\_session_i \sim (1 - \pi_i)poisson(\lambda_i) + \pi_i \mathbb{1}\{n\_session_i = 0\},
logit(\pi_i) = \beta_0^{logit} + b_{i1},
log(\lambda_i) = \beta_0^{pois} + b_{i2}.
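
For a count n, the per-day log-likelihood contribution is therefore the standard zero-inflated Poisson form,

log(\pi_i + (1 - \pi_i) e^{-\lambda_i}) if n = 0,
log(1 - \pi_i) + n log(\lambda_i) - \lambda_i - log(n!) if n > 0.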

The duration of the j-th session of the i-th player is modeled with a two-way hazard model, as follows:

hazard(dur_{ij} | start_{ij}) = exp(u(start_{ij})*v(dur_{ij}) + \beta_0^{surv} + b_{i3}).

The functions u(.) and v(.) are modelled by basis splines with 3 interior knots. The random intercept vector (b_{i1}, b_{i2}, b_{i3}) \sim N(0, \Sigma).

The likelihood of the zero-inflated Poisson part is very standard (as above), and the log-likelihood contribution of each session under the two-way hazard model is written as

u(start_{ij})*v(dur_{ij}) + \beta_0^{surv} + b_{i3} - \int \limits_{0}^{dur_{ij}}exp(u(start_{ij})*v(x) + \beta_0^{surv} + b_{i3})dx.

We approximate this integral with a 15-point Gauss-Kronrod quadrature rule.
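
Mapping the rule's nodes t_k and weights w_k from [-1, 1] to [0, dur_{ij}] via x = \frac{dur_{ij}}{2}(t_k + 1), the cumulative-hazard term is approximated as

\int \limits_{0}^{dur_{ij}} exp(u(start_{ij})*v(x) + \beta_0^{surv} + b_{i3})dx \approx exp(\beta_0^{surv} + b_{i3}) \frac{dur_{ij}}{2} \sum \limits_{k=1}^{15} w_k exp(u(start_{ij})*v(\tfrac{dur_{ij}}{2}(t_k + 1))),

which is exactly what the log_cum_hazard line in the code computes. Following is my rstan implementation.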

code <- "data {
  int <lower=1> N; // number of players
  int <lower=1> l_n_session; // length of num_session_vec
  int <lower=0> n_session[l_n_session]; // num_session_vec
  int <lower=0> part_n_session[N+1]; // partition_num_session vector
  int <lower=1> l_duration; // length of duration_vec
  vector <lower=0>[l_duration] duration; // duration_vec
  vector <lower=0>[l_duration] starting; // starting_vec
  int <lower=0> part_dur[N+1]; // partition_dur vector (same partition for starting vector)
  int <lower=0> df_splines;
  matrix[l_duration, df_splines] bs_start;
  matrix[l_duration, df_splines] bs_surv;
  matrix[(l_duration*15), df_splines] nodes; // 'Gauss_Kronrod_nodes': survival spline basis evaluated at the 15 mapped Gauss-Kronrod points, stacked 15 rows per duration
  vector <lower=-1, upper=1>[15] wt; // Gauss Kronrod weights
}

parameters {
  real z_beta0_pois;
  real z_beta0_logit;
  real z_beta0_surv;
  real <lower=0> sd_beta0_pois;
  real <lower=0> sd_beta0_logit;
  real <lower=0> sd_beta0_surv;
  matrix[3, N] b_mat;
  //vector <lower=0>[3] b_sd;
  //cholesky_factor_corr[3] b_cholesky;
  vector[df_splines] z_beta_bs_start;
  vector[df_splines] z_beta_bs_surv;
  vector[df_splines] sd_beta_bs_start;
  vector[df_splines] sd_beta_bs_surv;
}

transformed parameters {
  real beta0_pois;
  real beta0_logit;
  real beta0_surv;
  //matrix[3, N] b_mat;
  vector[df_splines] beta_bs_start;
  vector[df_splines] beta_bs_surv;
  beta0_pois = sd_beta0_pois * z_beta0_pois;
  beta0_logit = sd_beta0_logit * z_beta0_logit;
  beta0_surv = sd_beta0_surv * z_beta0_surv;
  //b_mat = diag_pre_multiply(b_sd, b_cholesky) * z_b_mat;
  beta_bs_start = sd_beta_bs_start .* z_beta_bs_start;
  beta_bs_surv = sd_beta_bs_surv .* z_beta_bs_surv;
}

model {
  real p;
  real lambda;
  real log_hazard;
  real log_cum_hazard;
  
  for (n in 1:N) {
    for (i in (part_n_session[n]+1):part_n_session[n+1]) {
      p = inv_logit(beta0_logit + b_mat[2, n]);
      lambda = exp(beta0_pois + b_mat[1, n]);
      if (n_session[i] == 0) {
        target += log_sum_exp(bernoulli_lpmf(1 | p), bernoulli_lpmf(0 | p) + poisson_lpmf(n_session[i] | lambda));
      }
      else {
        target += bernoulli_lpmf(0 | p) + poisson_lpmf(n_session[i] | lambda);
      }
    }
    for (j in (part_dur[n]+1):part_dur[n+1]) {
      log_hazard = dot_product(beta_bs_start, bs_start[j])*dot_product(beta_bs_surv, bs_surv[j]) + beta0_surv + b_mat[3, n];
      log_cum_hazard = exp(beta0_surv + b_mat[3, n])*(duration[j]*(.5))*dot_product(wt, exp(dot_product(beta_bs_start, bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*beta_bs_surv)));
      target += log_hazard - log_cum_hazard;
    }
  }
  
  //priors
  
  target += normal_lpdf(z_beta0_pois | 0, 1);
  target += normal_lpdf(z_beta0_logit | 0, 1);
  target += normal_lpdf(z_beta0_surv | 0, 1);
  target += inv_gamma_lpdf(sd_beta0_pois | .5, .5);
  target += inv_gamma_lpdf(sd_beta0_logit | .5, .5);
  target += inv_gamma_lpdf(sd_beta0_surv | .5, .5);
  target += normal_lpdf(b_mat[1] | 0, 1);
  target += normal_lpdf(b_mat[2] | 0, 1);
  target += normal_lpdf(b_mat[3] | 0, 1);
  //target += inv_gamma_lpdf(b_sd | .5, .5);
  //target += normal_lpdf(to_vector(z_b_mat) | 0, 1);
  //target += lkj_corr_cholesky_lpdf(b_cholesky | 1);
  target += normal_lpdf(z_beta_bs_start | 0, 1);
  target += normal_lpdf(z_beta_bs_surv | 0, 1);
  target += inv_gamma_lpdf(sd_beta_bs_start | .5, .5);
  target += inv_gamma_lpdf(sd_beta_bs_surv | .5, .5);
}
"
rstan_options(auto_write = TRUE)

initf <- function(){
  list(
  z_beta0_logit = 1,
  z_beta0_pois = 1,
  z_beta0_surv = 1,
  sd_beta0_logit = 1,
  sd_beta0_pois = 1,
  sd_beta0_surv = 1,
  b_mat = rbind(rnorm(50), rnorm(50), rnorm(50)),
  z_beta_bs_start = rep(1, 5),
  z_beta_bs_surv = rep(1, 5),
  sd_beta_bs_start = rep(1, 5),
  sd_beta_bs_surv = rep(1, 5)
  )
}


model <- stan(model_code=code, 
              data = list(N = length(mod_num_session),
                          l_n_session = length(num_session_vec),
                          n_session = num_session_vec,
                          part_n_session = partition_num_session,
                          l_duration = length(duration_vec),
                          duration = duration_vec,
                          starting = starting_vec,
                          part_dur = partition_dur,
                          df_splines = 5,
                          bs_start = basis_splines_starting_mat,
                          bs_surv = basis_splines_survival_mat,
                          nodes = Gauss_Kronrod_nodes,
                          wt = c15),
              init = initf,
              chains = 1,
              iter = 10000,
              warmup = 3000,
              control=list(adapt_delta=0.999, stepsize=0.001, max_treedepth=30))


Due to the initial value problem, I am not able to use the Cholesky factorization part, as I don't have much idea what kind of initial values I can supply for those parameters. I am a beginner with Stan, so I have written this code in a very simple way, but it has a huge run time. If you could kindly suggest a faster way to implement this, it would be very helpful to me.

Thanks,
Soumya.


Hi,
it seems like you are missing a lower=0 constraint on sd_beta_bs_start and sd_beta_bs_surv, which could explain both the initialization failure and the slow sampling. Do you still see problems with these bounds added?
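
For reference, that would just change these two declarations in the parameters block:

  vector <lower=0> [df_splines] sd_beta_bs_start;
  vector <lower=0> [df_splines] sd_beta_bs_surv;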

Best of luck with your model!


Thanks a lot for responding. Adding the bounds solves the initialization problem, but sampling is still very slow (it didn't improve much). Do you think removing the second for loop may help? I can give it a try.

Following is my current code:


code <- "data {
  int <lower=1> N; // number of players
  int <lower=1> l_n_session; // length of num_session_vec
  int <lower=0> n_session[l_n_session]; // num_session_vec
  int <lower=0> part_n_session[N+1]; // partition_num_session vector
  int <lower=1> l_duration; // length of duration_vec
  vector <lower=0>[l_duration] duration; // duration_vec
  vector <lower=0>[l_duration] starting; // starting_vec
  int <lower=0> part_dur[N+1]; // partition_dur vector
  int <lower=0> df_splines;
  matrix[l_duration, df_splines] bs_start;
  matrix[l_duration, df_splines] bs_surv;
  matrix[(l_duration*15), df_splines] nodes; // 'Gauss_Kronrod_nodes': survival spline basis evaluated at the 15 mapped Gauss-Kronrod points, stacked 15 rows per duration
  vector <lower=-1, upper=1>[15] wt; // Gauss Kronrod weights
}

parameters {
  real z_beta0_pois;
  real z_beta0_logit;
  real z_beta0_surv;
  real <lower=0> sd_beta0_pois;
  real <lower=0> sd_beta0_logit;
  real <lower=0> sd_beta0_surv;
  matrix[3, N] z_b_mat;
  vector <lower=0>[3] b_sd;
  cholesky_factor_corr[3] b_cholesky;
  vector[df_splines] z_beta_bs_start;
  vector[df_splines] z_beta_bs_surv;
  vector <lower=0> [df_splines] sd_beta_bs_start;
  vector <lower=0> [df_splines] sd_beta_bs_surv;
}

transformed parameters {
  real beta0_pois;
  real beta0_logit;
  real beta0_surv;
  matrix[3, N] b_mat;
  vector[df_splines] beta_bs_start;
  vector[df_splines] beta_bs_surv;
  beta0_pois = sd_beta0_pois * z_beta0_pois;
  beta0_logit = sd_beta0_logit * z_beta0_logit;
  beta0_surv = sd_beta0_surv * z_beta0_surv;
  b_mat = diag_pre_multiply(b_sd, b_cholesky) * z_b_mat;
  beta_bs_start = sd_beta_bs_start .* z_beta_bs_start;
  beta_bs_surv = sd_beta_bs_surv .* z_beta_bs_surv;
}

model {
  real p;
  real lambda;
  real log_hazard;
  real log_cum_hazard;
  
  for (n in 1:N) {
    for (i in (part_n_session[n]+1):part_n_session[n+1]) {
      p = inv_logit(beta0_logit + b_mat[2, n]);
      lambda = exp(beta0_pois + b_mat[1, n]);
      if (n_session[i] == 0) {
        target += log_sum_exp(bernoulli_lpmf(1 | p), bernoulli_lpmf(0 | p) + poisson_lpmf(n_session[i] | lambda));
      }
      else {
        target += bernoulli_lpmf(0 | p) + poisson_lpmf(n_session[i] | lambda);
      }
    }
    for (j in (part_dur[n]+1):part_dur[n+1]) {
      log_hazard = dot_product(beta_bs_start, bs_start[j])*dot_product(beta_bs_surv, bs_surv[j]) + beta0_surv + b_mat[3, n];
      log_cum_hazard = exp(beta0_surv + b_mat[3, n])*(duration[j]*(.5))*dot_product(wt, exp(dot_product(beta_bs_start, bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*beta_bs_surv)));
      target += log_hazard - log_cum_hazard;
    }
  }
  
  //priors
  
  target += normal_lpdf(z_beta0_pois | 0, 1);
  target += normal_lpdf(z_beta0_logit | 0, 1);
  target += normal_lpdf(z_beta0_surv | 0, 1);
  target += inv_gamma_lpdf(sd_beta0_pois*sd_beta0_pois | .5, .5);
  target += inv_gamma_lpdf(sd_beta0_logit*sd_beta0_logit | .5, .5);
  target += inv_gamma_lpdf(sd_beta0_surv*sd_beta0_surv | .5, .5);
  target += student_t_lpdf(b_sd | 1, 0, 1);
  target += normal_lpdf(to_vector(z_b_mat) | 0, 1);
  target += lkj_corr_cholesky_lpdf(b_cholesky | 1);
  target += normal_lpdf(z_beta_bs_start | 0, 1);
  target += normal_lpdf(z_beta_bs_surv | 0, 1);
  target += inv_gamma_lpdf(sd_beta_bs_start | .5, .5);
  target += inv_gamma_lpdf(sd_beta_bs_surv | .5, .5);
}
"
rstan_options(auto_write = TRUE)


model <- stan(model_code=code, 
              data = list(N = length(mod_num_session),
                          l_n_session = length(num_session_vec),
                          n_session = num_session_vec,
                          part_n_session = partition_num_session,
                          l_duration = length(duration_vec),
                          duration = duration_vec,
                          starting = starting_vec,
                          part_dur = partition_dur,
                          df_splines = 8,
                          bs_start = basis_splines_starting_mat,
                          bs_surv = basis_splines_survival_mat,
                          nodes = Gauss_Kronrod_nodes,
                          wt = c15),
              chains = 1,
              iter = 200,
              warmup = 50,
              control=list(adapt_delta=0.999, stepsize=0.001, max_treedepth=30))

I have one more query in this regard. If I use poisson_lpmf(vector1 | vector2), will that return a vector whose i-th component is the lpmf of the i-th value of vector1 under the i-th lambda in vector2? If yes, it will help me remove one for loop.

I am not able to find a simple function for squaring all the elements of a vector; pow() does not accept pow(vector, int). Is there an easy way to do that?

Thanks,
Soumya.

Hi, is there a particular reason that you used
adapt_delta=0.999, max_treedepth=30 in the control for Stan?
How many seconds do the gradient evaluation and 10 leapfrog steps take?

Even if the model itself converges fine, increasing adapt_delta and max_treedepth will reduce the sampling speed.
However, if the gradient evaluation is slow, then we can try to improve the model by vectorising/reparameterising the code or, as you mentioned, removing the second for loop.
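
To illustrate the vectorisation idea, the zero-inflated part could look roughly like this (just a sketch, not tested; player_ns, idx_pos, idx_zero, n_pos and n_zero are hypothetical index data you would have to pass in, and I switched to the bernoulli_logit/poisson_log parameterisations to avoid the explicit inv_logit/exp):

// hypothetical extra data:
//   int player_ns[l_n_session]; // player index of each count observation
//   int idx_pos[n_pos];         // positions with n_session > 0
//   int idx_zero[n_zero];       // positions with n_session == 0
vector[l_n_session] logit_p = beta0_logit + to_vector(b_mat[2, player_ns]);
vector[l_n_session] log_lam = beta0_pois + to_vector(b_mat[1, player_ns]);

// non-zero counts always come from the Poisson component
target += bernoulli_logit_lpmf(0 | logit_p[idx_pos])
          + poisson_log_lpmf(n_session[idx_pos] | log_lam[idx_pos]);

// zero counts still need the explicit mixture
for (i in idx_zero)
  target += log_sum_exp(bernoulli_logit_lpmf(1 | logit_p[i]),
                        bernoulli_logit_lpmf(0 | logit_p[i])
                          + poisson_log_lpmf(0 | log_lam[i]));

The survival loop is harder to vectorise because of the per-observation quadrature, but you can at least compute beta0_surv + b_mat[3, n] once per player instead of twice per session.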

For the pow question, you can use the rows_dot_self function: for a (column) vector it returns the dot product of each row with itself, i.e. the elementwise square. For example:

stanfun <-"
functions{
 vector vec_pow(vector a){    
    return rows_dot_self(a);  
 }
}
"
expose_stan_functions(stanc(model_code = stanfun))
a <- rnorm(10)

vec_pow(a)
a^2

Thank you. :)

Thanks a lot for your kind help.
With the code in my last comment, rstan shows:

Gradient evaluation took 0.17 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 1700 seconds.

I just adjusted the control settings to

control=list(adapt_delta=0.99, stepsize=0.01, max_treedepth=10)

Now it shows:

Gradient evaluation took 0.092 seconds
Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 920 seconds.

I would first check whether the posterior is pathological in some way: improving the posterior geometry can improve speed by a much larger factor than optimizing the computation. What is the distribution of treedepth across your iterations? (You can use nuts_params to get it in R; not sure about Python, but there certainly is a way.)
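
For example, a quick R sketch (assuming your fit object is called model, as in your code):

library(bayesplot)
np <- nuts_params(model)                        # per-iteration NUTS diagnostics
table(np$Value[np$Parameter == "treedepth__"])  # distribution of treedepth
sum(np$Value[np$Parameter == "divergent__"])    # number of divergent transitions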

Most of the hints at Divergent transitions - a primer are also applicable for debugging pathological geometries, especially simplifying the model and trying to understand it better. I admit I have basically no experience with the math of hazard/competing-risk models, so I can't really help that much directly…

Also, if you end up deciding that you want to improve the computation, then I would start by using the profiler (see e.g. Profiling Stan programs with CmdStanR • cmdstanr) to see which parts are actually slow.
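
Just to sketch what that looks like (profiling needs Stan 2.26 or newer, which in practice means cmdstanr rather than the current CRAN rstan): you wrap the expensive parts of the model block in named profile sections,

model {
  profile("zip_part") {
    // the loop over n_session goes here
  }
  profile("hazard_part") {
    // the duration / quadrature loop goes here
  }
  // priors unchanged
}

and after fitting with cmdstanr, fit$profiles() reports the time spent in each labelled section.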