Is there a way to vectorize my code that contains a custom distribution?

I am having difficulties with one of my models because it requires a custom distribution that does not exist natively in Stan. Because of that, I am forced to loop over the likelihood instead of using a vectorized call such as poisson_lpmf. The problem is that the runtime becomes very long as the sample size increases. I am not sure exactly what is causing the long runtime, but what can I do about it? I have a reduce_sum version that helps quite a bit, but I still need more of a reduction in runtime than reduce_sum can provide. I suspect the 1:N loop is the cause, but I am not sure whether I can avoid it in my case.

mod_baseline <- stan_model(model_code = "
functions {
    // asymptotic approximation to log(Z(lambda, nu)), the log of the
    // COM-Poisson normalising constant
    real log_Z_com_poisson_approx(real lambda, real nu) {
      
      real log_mu = log(lambda^(1/nu));
      real nu_mu = nu * exp(log_mu);
      real nu2 = nu^2;
      // first 4 terms of the residual series
      real log_sum_resid = log1p(
        nu_mu^(-1) * (nu2 - 1) / 24 +
        nu_mu^(-2) * (nu2 - 1) / 1152 * (nu2 + 23) +
        nu_mu^(-3) * (nu2 - 1) / 414720 * (5 * nu2^2 - 298 * nu2 + 11237)
      );
      return nu_mu + log_sum_resid  -
      ((log(2 * pi()) + log_mu) * (nu - 1) / 2 + log(nu) / 2);
    }

    // log(Z(lambda, nu)) by direct summation: add series terms until
    // successive partial sums differ by less than log_error (max 100 terms)
    real compute_log_z(real lambda, real nu, real log_error) {
  
      real z = negative_infinity();
      real z_last = 0;
      real j = 0;

      while ((abs(z - z_last) > log_error) && j < 100) {
        z_last = z;
        z = log_sum_exp(z, j * log(lambda) - nu * lgamma(j+1));
        j = j + 1;

      }
      return(z);
    }

}
data {
  int<lower=0> I;
  int<lower=0> J;
  int<lower=1> N;
  int<lower=1,upper=I> ii[N];
  int<lower=1,upper=J> jj[N];
  int<lower=1,upper=I> dd[N];
  int<lower=0> y[N];
}
parameters {
  vector[J] theta;
  vector[I-1] b_free;
  vector<lower=0>[I] nu;
}
transformed parameters{
  vector[I] beta = append_row(b_free, -sum(b_free));
  real log_error = .001;
  
}
model {

  theta ~ normal(0, 1);

  target += normal_lpdf(beta | 0, 1);
  
  nu ~ lognormal(0,.5);

  for (n in 1:N) {
  
    real lambda = exp(nu[dd[n]] * (theta[jj[n]] + beta[ii[n]]));

    // use the asymptotic approximation when lambda and lambda^(1/nu) are
    // both large enough; otherwise sum the series directly
    if (log(lambda^(1 / nu[dd[n]])) * nu[dd[n]] > log(1.5) && log(lambda^(1 / nu[dd[n]])) > log(1.5)) {
      target += y[n] * log(lambda) - nu[dd[n]] * lgamma(y[n] + 1) - log_Z_com_poisson_approx(lambda, nu[dd[n]]);
    } else {
      target += y[n] * log(lambda) - nu[dd[n]] * lgamma(y[n] + 1) - compute_log_z(lambda, nu[dd[n]], log_error);
    }
  }

}
 
")

One primary way that vectorized distributions can be faster than non-vectorized versions is when some quantity is constant across observations and particularly time-intensive to calculate. Unfortunately, I don't see any obvious examples of such impactful calculations in your code, since lambda is specific to each of the N observations, and so, by extension, is nu_mu. However, there are a lot of smaller savings that might add up to a big effect.

Some examples…

In compute_log_z(), you calculate log(lambda) every iteration of the while loop. You can instead calculate this once before the loop.
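
For instance, a minimal sketch of the same function with the log hoisted out of the loop:

real compute_log_z(real lambda, real nu, real log_error) {
  real log_lambda = log(lambda); // computed once rather than once per term
  real z = negative_infinity();
  real z_last = 0;
  real j = 0;
  while ((abs(z - z_last) > log_error) && j < 100) {
    z_last = z;
    z = log_sum_exp(z, j * log_lambda - nu * lgamma(j + 1));
    j = j + 1;
  }
  return z;
}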

You could also pre-calculate lgamma(y[n] + 1) in the transformed data block rather than recomputing it for all N observations on every evaluation of the model block. And since the gamma function is the factorial shifted by one for integer arguments, you can keep lgamma(j + 1) as a running sum inside the while loop, adding log(j + 1) each iteration, rather than calling lgamma() every time.
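
A sketch of both ideas together (untested; lgamma_y1 is just a name I made up):

transformed data {
  vector[N] lgamma_y1; // lgamma(y[n] + 1), computed once per model fit
  for (n in 1:N) lgamma_y1[n] = lgamma(y[n] + 1);
}

and, inside compute_log_z(), a running log-factorial in place of the lgamma() call (using log_lambda from the previous sketch):

real log_fac = 0; // lgamma(j + 1), starting from lgamma(1) = 0
while ((abs(z - z_last) > log_error) && j < 100) {
  z_last = z;
  z = log_sum_exp(z, j * log_lambda - nu * log_fac);
  log_fac += log(j + 1); // lgamma(j + 2) = lgamma(j + 1) + log(j + 1)
  j = j + 1;
}

The likelihood terms in the model block would then use lgamma_y1[n] in place of lgamma(y[n] + 1).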

In a vectorized version of log_Z_com_poisson_approx(), you could calculate log(2 * pi()) * (nu - 1) / 2 + log(nu) / 2 (the part of the final term that doesn't involve log_mu) once for each of the I values of nu rather than N times.
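
Since nu is a parameter this can't go in transformed data, but you could compute it once per gradient evaluation, e.g. at the top of the model block, and index into it (log_nu_const is a name I made up, and the function would need it passed in as an extra argument):

vector[I] log_nu_const = log(2 * pi()) * (nu - 1) / 2 + log(nu) / 2; // I values instead of N

The function's return line would then be nu_mu + log_sum_resid - log_mu * (nu - 1) / 2 minus the pre-computed constant for the relevant nu.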

A bit clunkier, but you could also pre-calculate part of the expression below for each value of nu.

real log_sum_resid = log1p(
        nu_mu^(-1) * (nu2 - 1) / 24 +
        nu_mu^(-2) * (nu2 - 1) / 1152 * (nu2 + 23) +
        nu_mu^(-3) * (nu2 - 1) / 414720 * (5 * nu2^2 - 298 * nu2 + 11237)
      );
vector[I] nu2 = nu .^ 2;
matrix[I, 3] nu3;
nu3[, 1] = (nu2 - 1) / 24;
nu3[, 2] = (nu2 - 1) / 1152 .* (nu2 + 23);
nu3[, 3] = (nu2 - 1) / 414720 .* (5 * nu2 .^ 2 - 298 * nu2 + 11237);
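
With d = dd[n] (shorthand just for this snippet), the residual series then becomes:

real log_sum_resid = log1p(nu3[d, 1] / nu_mu + nu3[d, 2] / nu_mu^2 + nu3[d, 3] / nu_mu^3);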

More generally, you could vectorize a lot of your calculations using dot notation (example below). I don’t know if this has any computational benefit, but it might reveal more places where you repeat calculations.

vector[N] lambda = exp( nu[dd] .* (theta[jj] + beta[ii]) );

You might also want to see how much time is being spent in compute_log_z() compared to log_Z_com_poisson_approx(). For that, you could look at Stan's profiling functionality; see Profiling Stan programs with CmdStanR.
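
For example, a sketch of your likelihood loop body with the two branches wrapped in profile statements (the section names are arbitrary, and profiling needs a recent Stan version, so CmdStanR may be easier than CRAN rstan here):

if (log(lambda^(1 / nu[dd[n]])) * nu[dd[n]] > log(1.5) && log(lambda^(1 / nu[dd[n]])) > log(1.5)) {
  profile("log_z_approx") {
    target += y[n] * log(lambda) - nu[dd[n]] * lgamma(y[n] + 1) - log_Z_com_poisson_approx(lambda, nu[dd[n]]);
  }
} else {
  profile("log_z_series") {
    target += y[n] * log(lambda) - nu[dd[n]] * lgamma(y[n] + 1) - compute_log_z(lambda, nu[dd[n]], log_error);
  }
}

Timings for all hits of the same profile name are accumulated, so you would get the total time spent in each branch over the whole fit.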
