Did I implement reduce_sum correctly?

I have been trying to implement my model using reduce_sum since without it, it can be quite slow. However, I am not 100% sure I implemented things correctly, as the estimation time seems to vary quite heavily (although this could just be due to the data). I feel like maybe I need to experiment with the grainsize more, but this is my first time using reduce_sum.

mod_cmd <- write_stan_file('
functions{

  // asymptotic approximation to the log normalizing constant of the count distribution
  real approximation(real log_lambda, real nu){
  
      real log_mu = log_lambda / nu;
      real nu_mu = nu * exp(log_mu);
      real nu2 = nu^2;
      // first 4 terms of the residual series
      real log_sum_resid = log1p(
        nu_mu^(-1) * (nu2 - 1) / 24 +
        nu_mu^(-2) * (nu2 - 1) / 1152 * (nu2 + 23) +
        nu_mu^(-3) * (nu2 - 1) / 414720 * (5 * nu2^2 - 298 * nu2 + 11237)
      );
      return nu_mu + log_sum_resid  -
            ((log(2 * pi()) + log_mu) * (nu - 1) / 2 + log(nu) / 2);  
  
  }
  // direct evaluation of the log normalizing constant: log-sum-exp over
  // j * log_lambda - nu * lgamma(j + 1), truncated at 130 terms or once the
  // running sum changes by less than .001
  real summation(real log_lambda, real nu){
  
    real z = negative_infinity();
    real z_last = 0;
    
    for (j in 0:130) {
      z_last = z;
      z = log_sum_exp(z, j * log_lambda - nu * lgamma(j + 1));
      if (abs(z - z_last) < .001) {
        break;
      }
    }
    
    return z;
  }
  
  
  // partial sum for reduce_sum: only the first argument (a slice of seq_N)
  // is actually sliced; y_slice, jj_slice, ii_slice, and kk_slice receive the
  // full arrays, so they are indexed with the global index n in start:end
  real partial_sum(array[] int slice_n, int start, int end,
                   array[] int y_slice,
                   array[] int jj_slice, array[] int ii_slice, vector theta,
                   vector beta, vector nu, vector alpha, array[] int kk_slice,
                   vector gamma) {
    real partial_target = 0.0;
    
    for (n in start : end) {
      real log_lambda = nu[ii_slice[n]]
                        * (alpha[ii_slice[n]] * theta[jj_slice[n]]
                           + beta[ii_slice[n]] + gamma[kk_slice[n]]);
      real log_prob;
      
      // switch between the asymptotic approximation and direct summation
      // depending on the size of log_lambda relative to nu
      if (log_lambda / nu[ii_slice[n]] > log(1.5) && log_lambda > log(1.5)) {
        log_prob = y_slice[n] * log_lambda -
                   nu[ii_slice[n]] * lgamma(y_slice[n] + 1) -
                   approximation(log_lambda, nu[ii_slice[n]]);
      } else {
        log_prob = y_slice[n] * log_lambda -
                   nu[ii_slice[n]] * lgamma(y_slice[n] + 1) -
                   summation(log_lambda, nu[ii_slice[n]]);
      }
      
      partial_target += log_prob;
    }
    return partial_target;
  }
}

data {
  int<lower=0> I;
  int<lower=0> J;
  int<lower=1> N;
  int<lower=1> K;
  array[N] int<lower=1, upper=I> ii;
  array[N] int<lower=1, upper=J> jj;
  array[N] int<lower=1, upper=K> kk;
  array[I] int<lower=1,upper=K> item_type_for_beta;
  array[N] int<lower=0> y;
  array[N] int seq_N;
  int<lower=1> grainsize;
}

parameters {

  vector[J] theta;
  vector[I] beta;
  vector[K] gamma;
  vector<lower=0>[I] nu;  //.2
  vector<lower=0>[K] sigma_beta_k;
  
  vector<lower=0>[I] alpha;
  real<lower=0> sigma_alpha;
  
  real mu_gamma;
  real<lower=0> sigma_gamma;
}


model {
  
  nu ~ normal(0, 4);

  for (i in 1:I) {
    beta[i] ~ normal(gamma[item_type_for_beta[i]], sigma_beta_k[item_type_for_beta[i]]);
  }
  
  theta ~ normal(0, .3); 
  
  sigma_alpha ~ cauchy(0,1);
  alpha ~ lognormal(0,sigma_alpha);
  
  mu_gamma ~ normal(0,6);
  sigma_gamma ~ cauchy(0,2);
  gamma ~ normal(mu_gamma, sigma_gamma);

  target += reduce_sum(partial_sum, seq_N, grainsize, y, jj, ii, theta,
                       beta, nu, alpha, kk, gamma);
}

generated quantities {
  array[N] real log_lik;
  
  for (n in 1:N) {
      real log_lambda = nu[ii[n]] * (alpha[ii[n]]*theta[jj[n]] + beta[ii[n]] + gamma[kk[n]]);
      if (log_lambda / nu[ii[n]] > log(1.5) && log_lambda > log(1.5)) {
          log_lik[n] = y[n] * log_lambda - nu[ii[n]] * lgamma(y[n] + 1) - approximation(log_lambda, nu[ii[n]]);
      } else {
          log_lik[n] = y[n] * log_lambda - nu[ii[n]] * lgamma(y[n] + 1) - summation(log_lambda, nu[ii[n]]);
      }
  }
}

')

When I run things in R, I am doing:

mod2 <- cmdstan_model(mod_cmd, compile = TRUE, cpp_options = list(stan_threads = TRUE))

stan_data <- list(I = I,
                  J = J,
                  N = N,
                  K = K,
                  kk = kk,
                  item_type_for_beta = item_type_for_beta,
                  ii = ii,
                  jj = jj,
                  y = y,
                  grainsize =1,
                  seq_N = 1:N)
fit_cmp <- mod2$sample(data = stan_data, chains = 2, parallel_chains = 2,
                       threads_per_chain = 7, iter_warmup = 500,
                       iter_sampling = 2000, refresh = 10, thin = 1,
                       max_treedepth = 10)

First, I’d check that you’re getting similar answers across reduce-sum runs as you do across runs without it. It’s generally considered easier to optimize a correct program than to debug an optimized one, or as Knuth is rumored to have said, “premature optimization is the root of all evil.”
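
If you want a concrete version of that check, something like this would do it (just a sketch: mod_serial is a hypothetical compile of the same model with the reduce_sum call replaced by a plain loop over observations, and the parameter names come from your program):

# rough sketch; mod_serial is a hypothetical non-reduce_sum version of the
# same model, and the parameter names are taken from the program above
fit_serial  <- mod_serial$sample(data = stan_data, chains = 2, seed = 123)
fit_threads <- mod2$sample(data = stan_data, chains = 2, parallel_chains = 2,
                           threads_per_chain = 7, seed = 123)

pars <- c("theta", "beta", "gamma", "nu", "alpha")
summ_serial  <- fit_serial$summary(variables = pars)
summ_threads <- fit_threads$summary(variables = pars)

# posterior means and sds should agree to within MCMC error
max(abs(summ_serial$mean - summ_threads$mean))
max(abs(summ_serial$sd - summ_threads$sd))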

As for the estimation time varying heavily: that's common with our random initializations, which are uniform between -2 and 2 on the unconstrained parameter scale. It's not uncommon to see as much as a factor of 2 or more difference in speed across seeds. If you reduce the range from (-2, 2) to (-0.5, 0.5), it can often stabilize things, at the risk of missing outlying modes, failing to diagnose bad mixing, etc. That's usually OK to do; @andrewgelman, for example, has been urging us to use less diffuse initializations (which seems to be a reversal of decades of advice urging people to use more diffuse initializations to diagnose poor mixing).
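
In cmdstanr you can do that by passing a single positive number as init, which draws initial values uniformly from (-init, init) on the unconstrained scale; for example (keeping the rest of your call as it is):

# shrink the default uniform(-2, 2) initialization to uniform(-0.5, 0.5)
fit_cmp <- mod2$sample(data = stan_data, chains = 2, parallel_chains = 2,
                       threads_per_chain = 7, init = 0.5,
                       iter_warmup = 500, iter_sampling = 2000, refresh = 10)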

Reduce-sum is only likely to provide a big speed boost when the amount of work done on each thread (or process if using MPI) dominates the communication cost. For example, it’s almost never worth doing this for a simple GLM, but it’s almost always worth doing with nested ordinary differential equation models or if you have to do a bunch of matrix solves. All of the code inside your computations seems to be relatively simple arithmetic.
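
That said, since you asked about grainsize: if you want to experiment with it, a quick-and-dirty approach (just a sketch; the candidate values below are guesses, and the short runs are only for timing, not for inference) is to time a few settings and keep whichever is fastest:

# time short runs at a few grainsizes; grainsize = 1 leaves the chunking to
# the scheduler, and N / (2 * threads) is a common starting guess
for (g in c(1, 100, N %/% 14, N %/% 7)) {
  stan_data$grainsize <- g
  fit_g <- mod2$sample(data = stan_data, chains = 1, threads_per_chain = 7,
                       iter_warmup = 200, iter_sampling = 200,
                       refresh = 0, seed = 123)
  cat("grainsize", g, "took", fit_g$time()$total, "seconds\n")
}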