How can I speed up my estimation?

Here is my Stan code. I am currently using a custom distribution, so that may limit what I can do to speed up the estimation. When I = 5 and J = 150, estimation takes only ~40 seconds. However, when I increase it to I = 5 and J = 200, it takes much, much longer, seemingly to the point of being unusable. Is there anything I can do to speed things up, short of using reduce_sum? And are there any mistakes or things I have overlooked that could be problematic?

After some further testing, the slowdown seems to be at least partly related to the initial values. If the starting values are "bad", the sampler may get stuck; my guess is that it gets stuck in the summation function and never converges. I am not sure what to do about this: setting initial values helps, but it does not completely solve the problem. Is there a way I can control things so the estimation does not get stuck?
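For reference, here is a sketch of how I have been supplying per-chain starting values through the init argument of sampling(); the specific numbers are just placeholders I have experimented with, not recommendations:

# returns one named list of starting values; rstan calls it once per chain
init_fun <- function() {
  list(theta = rnorm(J, 0, 0.1),
       beta = rnorm(I, 2, 0.1),
       nu = rep(1, I),
       mu_beta = 2,
       sigma_beta = 0.5)
}
# then passed to sampling() via init = init_fun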

mod_baseline <- stan_model(model_code = '
functions{

  real approximation(real lambda, real nu){
  
      real log_mu = log(lambda^(1/nu));
      real nu_mu = nu * exp(log_mu);
      real nu2 = nu^2;
      // first 4 terms of the residual series
      real log_sum_resid = log1p(
        nu_mu^(-1) * (nu2 - 1) / 24 +
        nu_mu^(-2) * (nu2 - 1) / 1152 * (nu2 + 23) +
        nu_mu^(-3) * (nu2 - 1) / 414720 * (5 * nu2^2 - 298 * nu2 + 11237)
      );
      return nu_mu + log_sum_resid  -
            ((log(2 * pi()) + log_mu) * (nu - 1) / 2 + log(nu) / 2);  
  
  }
  real summation(real lambda, real nu){
  
    real z = negative_infinity();
    real z_last = 0;
    real count = 0;
    
    for (j in 0:60) {
      z_last = z;
      z = log_sum_exp(z, j * log(lambda) - nu * lgamma(j+1));
      count = count + 1;
      if (abs(z - z_last) < .001) {
        break;
      }
    }
    return z;
  }
}

data {
  int<lower=0> I;
  int<lower=0> J;
  int<lower=1> N;
  int<lower=1,upper=I> ii[N];
  int<lower=1,upper=J> jj[N];
  int<lower=0> y[N];
}

parameters {
  real theta[J] ;
  real beta [I];
  real<lower=0> nu[I];
  
  real mu_beta;
  real<lower=0> sigma_beta;
}


model {
  
  nu ~ lognormal(0, 1);
  
  mu_beta ~ normal(0,5);
  sigma_beta ~ cauchy(0,2);
  beta ~ normal(mu_beta, sigma_beta); 
  
  theta ~ normal(0, .3); 
  
    real lambda [N];
    for(n in 1:N){
      lambda[n] = exp(nu[ii[n]]*(theta[jj[n]] + beta[ii[n]]));
          
      if(log(lambda[n]^(1/nu[ii[n]])) * nu[ii[n]] > log(1.5) && log(lambda[n]^(1/nu[ii[n]])) > log(1.5)){
  
        target += y[n] * log(lambda[n]) - nu[ii[n]] * lgamma(y[n] + 1) - approximation(lambda[n], nu[ii[n]]);
      } else {
        target += y[n] * log(lambda[n]) - nu[ii[n]] * lgamma(y[n] + 1) - summation(lambda[n], nu[ii[n]]);
      }
    }
} 

')

Here is my data simulation if needed.

library(COMPoissonReg)
library(rstan)
options(mc.cores = parallel::detectCores())
rstan_options(auto_write = TRUE)

I <- 5
J <- 60

theta <- rnorm(J,0,.3)
beta <- rnorm(I, 2, .5)
nu <- rlnorm(I,0,.5)

ii <- rep(1:I, times = J)
jj <- rep(1:J, each = I)

N <- J*I
y <- numeric(N)
for(n in 1:N){
  mu_n <- exp(nu[ii[n]] * (theta[jj[n]] + beta[ii[n]]))
  y[n] <- rcmp(1, mu_n, nu[ii[n]])
}

mod <- sampling(mod_baseline, data = list(I = I,
                                          J = J,
                                          N = N,
                                          ii = ii,
                                          jj = jj,
                                          y = y),
                iter = 2000, chains = 4, control = list(max_treedepth = 10), thin = 1, warmup = 1000)

Just from a quick glance, the lambdas are going to be very sensitive to the prior.

For instance, the prior on nu is a normal distribution that gets exponentiated twice. As a quick check, compare the effect of the prior in the Stan program

log_nu <- rnorm(1000, 0, 1)
hist(exp(exp(log_nu)))

to the prior in the R simulation

log_nu <- rnorm(1000, 0, 0.5)
hist(exp(exp(log_nu)))

I wouldn’t be surprised if the prior leads to overflow issues.
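As a rough check, assuming the priors in your Stan program, you can simulate how often the exponent nu * (theta + beta) lands above ~709, which is where exp() overflows to Inf in double precision:

nu <- rlnorm(1e5, 0, 1)                        # nu ~ lognormal(0, 1)
sigma_beta <- abs(rcauchy(1e5, 0, 2))          # sigma_beta ~ half-Cauchy(0, 2)
beta <- rnorm(1e5, rnorm(1e5, 0, 5), sigma_beta)
theta <- rnorm(1e5, 0, 0.3)
mean(nu * (theta + beta) > log(.Machine$double.xmax))  # fraction of prior draws where exp() overflows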

A speculative solution would be to write everything in terms of log lambda, or effectively in terms of nu, beta, and theta directly. For instance, I think

\log(\lambda) = \nu (\theta + \beta)
\log(\lambda^{1/\nu}) \times \nu = \nu (\theta + \beta)

and thus

\log(\lambda^{1/\nu}) = \theta + \beta
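A quick numerical check of the identity with arbitrary values:

nu <- 1.7; theta <- 0.2; beta <- 2.1
lambda <- exp(nu * (theta + beta))
log(lambda^(1 / nu))   # theta + beta = 2.3
log(lambda) / nu       # same value, without ever forming lambda^(1/nu)

so the functions can take log(lambda) (or theta + beta directly) and never exponentiate and re-log a large lambda.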

I converted my code to use log(lambda) instead; however, I still run into the same issues. With a small data size the estimation is quite fast, but with any meaningful increase in data size it becomes a huge slog. I checked for overflow with print statements, and it doesn't appear that any variables reach extreme values that would cause problems; the functions are also not outputting any extreme values.

For instance:
With a total N of 600, the runtime is roughly 30 seconds across 4 chains.
With a total N of 1200, the runtime is roughly 53 seconds across 4 chains.
With a total N of 1680, issues arise: one chain finished in 71 seconds, but the other three take a long time; even after 10 minutes, one chain is still at iteration 1, while the last two are 50% through.

This only seems to happen with larger N, though, where some chains get "stuck" or severely slowed down. I tried adding limits on the parameter ranges, and that seems to help a bit, but is that the only realistic solution here?


mod_baseline <- stan_model(model_code = '
functions{

  real approximation(real log_lambda, real nu){
  
      real log_mu = log_lambda / nu;
      real nu_mu = nu * exp(log_mu);
      real nu2 = nu^2;
      // first 4 terms of the residual series
      real log_sum_resid = log1p(
        nu_mu^(-1) * (nu2 - 1) / 24 +
        nu_mu^(-2) * (nu2 - 1) / 1152 * (nu2 + 23) +
        nu_mu^(-3) * (nu2 - 1) / 414720 * (5 * nu2^2 - 298 * nu2 + 11237)
      );
      return nu_mu + log_sum_resid  -
            ((log(2 * pi()) + log_mu) * (nu - 1) / 2 + log(nu) / 2);  
  
  }
  real summation(real log_lambda, real nu){
  
    real z = negative_infinity();
    real z_last = 0;
    real count = 0;
    
    for (j in 0:50) {
      z_last = z;
      z = log_sum_exp(z, j * log_lambda - nu * lgamma(j + 1));
      count = count + 1;
      if (fabs(z - z_last) < .001) {
        break;
      }
    }
    return z;
  }
}

data {
  int<lower=0> I;
  int<lower=0> J;
  int<lower=1> N;
  int<lower=1,upper=I> ii[N];
  int<lower=1,upper=J> jj[N];
  int<lower=0> y[N];
}

parameters {
  vector<lower = -5, upper = 5>[J] theta;
  vector<lower = -8, upper = 8>[I] beta;
  vector<lower=0, upper = 5>[I] nu;
  
  real mu_beta;
  real<lower=0> sigma_beta;
}


model {
  
  nu ~ cauchy(0, 1);
  
  mu_beta ~ normal(0,5);
  sigma_beta ~ cauchy(0,2);
  beta ~ normal(mu_beta, sigma_beta); 
  
  theta ~ normal(0, .3); 

  
    real log_lambda [N];
    for(n in 1:N){
      //lambda[n] = exp(nu[ii[n]]*(theta[jj[n]] + beta[ii[n]]));
      log_lambda[n] = nu[ii[n]]*(theta[jj[n]] + beta[ii[n]]);
          //print("log_lambda[", n, "] = ", log_lambda[n]);
          //print("nu[", ii[n], "] = ", nu[ii[n]]);
          
      if(log_lambda[n] / nu[ii[n]] > log(1.5) && log_lambda[n] > log(1.5)) {
  
        target += y[n] * log_lambda[n] - nu[ii[n]] * lgamma(y[n] + 1) - approximation(log_lambda[n], nu[ii[n]]);
        //print(approximation(log_lambda[n], nu[ii[n]]));
      } else {
        target += y[n] * log_lambda[n] - nu[ii[n]] * lgamma(y[n] + 1) - summation(log_lambda[n], nu[ii[n]]);
        //print(summation(log_lambda[n], nu[ii[n]]));
      }
    }
} 

')