I am having difficulties with one of my models because it requires a custom distribution that does not exist natively in Stan. Because of this, I am forced to loop over the observations when computing the likelihood instead of using a vectorized built-in such as poisson_lpmf. The problem is that the runtime becomes very long as the sample size increases. I am not sure exactly what is causing the long runtime; what can I do to reduce it? I have a version that uses reduce_sum, which helps quite a bit, but I still need a larger reduction in runtime than reduce_sum can provide. I suspect that looping over 1:N is what drives the runtime, but I am not sure whether I can avoid that in my case.
mod_baseline <- stan_model(model_code = "
functions {
// Closed-form approximation of the log normalizing term that the model
// subtracts from each observation's log-likelihood kernel.
//   lambda : rate-like parameter (positive)
//   nu     : exponent applied to the factorial term (positive)
// Returns the leading term nu * lambda^(1/nu), plus the log of a
// four-term residual series (1 + three corrections), minus a term
// built from log(2 * pi), log(mu) and log(nu).
real log_Z_com_poisson_approx(real lambda, real nu) {
  // mu = lambda^(1/nu); everything below is expressed through log(mu).
  real log_m = log(lambda^(1/nu));
  real scaled_m = nu * exp(log_m);
  real nu_sq = nu^2;
  // Corrections of order (nu * mu)^-1, ^-2 and ^-3; each one vanishes
  // when nu == 1 because of the shared (nu^2 - 1) factor.
  real corr1 = scaled_m^(-1) * (nu_sq - 1) / 24;
  real corr2 = scaled_m^(-2) * (nu_sq - 1) / 1152 * (nu_sq + 23);
  real corr3 = scaled_m^(-3) * (nu_sq - 1) / 414720
               * (5 * nu_sq^2 - 298 * nu_sq + 11237);
  real log_resid = log1p(corr1 + corr2 + corr3);
  real offset = (log(2 * pi()) + log_m) * (nu - 1) / 2 + log(nu) / 2;
  return scaled_m + log_resid - offset;
}
// Brute-force log normalizing term: accumulates, on the log scale,
//   sum_{j >= 0} exp(j * log(lambda) - nu * lgamma(j + 1))
// i.e. the series of lambda^j / (j!)^nu.
//   lambda    : rate-like parameter (non-negative)
//   nu        : exponent applied to the factorial term
//   log_error : stop once one more term changes the running log-sum
//               by no more than this amount
// At most the terms j = 0..99 are included, so the result is a
// truncated sum if the tolerance has not been met by then.
real compute_log_z(real lambda, real nu, real log_error) {
  // Loop-invariant: take the log once instead of once per term.
  real log_lambda = log(lambda);
  // The j = 0 term is exactly 1 (log = 0). Seeding with it avoids
  // evaluating 0 * log(lambda), which is NaN when lambda is 0.
  real z = 0;
  real z_last = negative_infinity();
  int j = 1;
  while ((abs(z - z_last) > log_error) && j < 100) {
    z_last = z;
    z = log_sum_exp(z, j * log_lambda - nu * lgamma(j + 1));
    j = j + 1;
  }
  return z;
}
}
// Observed data in long format: one row per observation n = 1..N.
data {
// Length of beta and nu; upper bound for the ii and dd indices.
int<lower=0> I;
// Length of theta; upper bound for the jj index.
int<lower=0> J;
// Number of observations.
int<lower=1> N;
// For each observation, which element of beta applies.
int<lower=1,upper=I> ii[N];
// For each observation, which element of theta applies.
int<lower=1,upper=J> jj[N];
// For each observation, which element of nu applies.
int<lower=1,upper=I> dd[N];
// Non-negative integer outcome for each observation.
int<lower=0> y[N];
}
// Quantities sampled directly.
parameters {
// One effect per jj level; given a normal(0, 1) prior in the model block.
vector[J] theta;
// First I-1 elements of beta; the last element is derived from these
// in transformed parameters.
vector[I-1] b_free;
// One positive value per dd level; given a lognormal(0, .5) prior and
// used to scale both the linear predictor and the lgamma(y + 1) term.
vector<lower=0>[I] nu;
}
// Derived quantities recomputed from the parameters.
transformed parameters{
// Full-length beta: the last element is minus the sum of the free
// elements, so the I elements sum to zero.
vector[I] beta = append_row(b_free, -sum(b_free));
// Convergence tolerance handed to compute_log_z in the model block.
// NOTE(review): this is a constant, not a function of any parameter;
// it could presumably live in transformed data instead, but moving it
// would remove it from the saved output - confirm nothing reads it.
real log_error = .001;
}
// Priors plus a per-observation log-likelihood of the form
//   y * log(lambda) - nu * lgamma(y + 1) - log Z(lambda, nu)
// with log(lambda) = nu * (theta + beta).
model {
  // Branch threshold, computed once rather than twice per observation.
  real log_cut = log(1.5);
  theta ~ normal(0, 1);
  target += normal_lpdf(beta | 0, 1);
  nu ~ lognormal(0, .5);
  for (n in 1:N) {
    real nu_n = nu[dd[n]];
    // Stay on the log scale: log(mu) is the linear predictor itself and
    // log(lambda) = nu * log(mu), so neither needs an exp/log or
    // pow/log round trip to recover it.
    real log_mu = theta[jj[n]] + beta[ii[n]];
    real log_lambda = nu_n * log_mu;
    // The two normalizer functions take lambda on the natural scale.
    real lambda = exp(log_lambda);
    real log_Z;
    // Use the closed-form approximation only when both nu * log(mu) and
    // log(mu) exceed log(1.5); otherwise sum the series directly.
    if (log_lambda > log_cut && log_mu > log_cut) {
      log_Z = log_Z_com_poisson_approx(lambda, nu_n);
    } else {
      log_Z = compute_log_z(lambda, nu_n, log_error);
    }
    target += y[n] * log_lambda - nu_n * lgamma(y[n] + 1) - log_Z;
  }
}
")