And the error is something like log(0) cannot evaluate likelihood, right?
I’ve been able to get rid of these errors and speed the code up by allowing reals instead of ints. Since these hypergeometric distributions work on ratios, I think it’s OK to allow reals. When I work with large “urn” values, I divide them by some constant (10, 100, etc.) and work on the smaller scale. The parameter values I’m interested in - the “weighted odds” - stay the same, but the code runs much, much faster.
For example, opening up the pmf to a pdf and letting the lbeta functions “interpolate” across the integers gives:
functions {
// Integrand handed to integrate_1d.
//   t     : integration variable
//   xc    : complement of t supplied by integrate_1d; not read here
//   theta : theta[1] is the divisor applied to every weight,
//           theta[2:] are the weights themselves
//   x_r   : real data; x_r[1] is skipped, x_r[2:] are the exponents
//   x_i   : integer data; not read here
// Returns prod_{j >= 2} (1 - t^(theta[j] / theta[1]))^x_r[j].
real multi_wallenius_integral(real t, real xc, real[] theta, real[] x_r,
                              int[] x_i) {
  int n_terms = num_elements(x_r);
  real inv_scale = 1 / theta[1];
  real integrand = 1;
  for (j in 2:n_terms) {
    real exponent = theta[j] * inv_scale;
    integrand *= (1 - pow(t, exponent))^x_r[j];
  }
  return integrand;
}
// Log density of the real-valued ("continuous") multivariate Wallenius
// distribution: log of the 1-d integral plus one generalised log binomial
// coefficient per category.
//   k   : k[1] is not read here; k[2:C + 1] are the amounts drawn per category
//   m   : amount of each category available (length C)
//   p   : category weights (length C)
//   x_i : integer data forwarded unchanged to integrate_1d
//   tol : tolerance forwarded unchanged to integrate_1d
real multi_walleniusc_lpdf(data real[] k, vector m, vector p,
                           data int[] x_i, data real tol) {
  int n_cat = num_elements(m);
  // Weighted amount left behind: sum_c p[c] * (m[c] - k[c + 1]).
  real remaining = dot_product(to_row_vector(p),
                               (m - to_vector(k[2:(n_cat + 1)])));
  real integral = integrate_1d(multi_wallenius_integral, 0, 1,
                               append_array({remaining}, to_array_1d(p)),
                               k, x_i, tol);
  real lp = log(integral);
  for (c in 1:n_cat) {
    real drawn = k[c + 1];
    // log1p(m) + lbeta(m - drawn + 1, drawn + 1) is minus the log of the
    // gamma-function generalisation of choose(m, drawn).
    lp -= log1p(m[c]) + lbeta(m[c] - drawn + 1, drawn + 1);
  }
  return lp;
}
}
// Inputs supplied by the caller.
data {
// Number of observations (rows of y).
int<lower=0> N;
// Number of categories (length of m and of probs).
int<lower=0> C;
// One observation per row. Each row is passed whole to multi_walleniusc as k,
// which reads only columns 2..C + 1 (the per-category amounts drawn);
// column 1 is carried along but not used by the density.
real y[N, C + 1];
// Amount of each category available.
vector[C] m;
// Tolerance forwarded to integrate_1d inside multi_walleniusc.
real tol;
}
// Empty placeholder arrays for the data arguments of the density functions.
transformed data {
// Only referenced by the commented-out multi_wallenius statement below.
real x_r[0];
// Passed as the (unused) integer-data argument of multi_walleniusc.
int x_i[0];
}
parameters {
// Category weights, constrained to be non-negative and sum to one.
simplex[C] probs;
}
model {
// Earlier variant of the likelihood statement, kept for reference:
// for (i in 1:N)
// y[i] ~ multi_wallenius(m, probs, x_r, tol);
// Each row of y contributes one multi_walleniusc log-density term.
for (i in 1:N)
y[i] ~ multi_walleniusc(m, probs, x_i, tol);
}
To test, run the following in R (requires the cmdstanr and BiasedUrn packages):
# Simulate multivariate Wallenius draws and fit the Stan model to them.
# Expects cmdstanr (cmdstan_model) and BiasedUrn (rMWNCHypergeo,
# meanMWNCHypergeo) to be attached already.

fp <- file.path("../multi_walleniusc.stan")
# TRUE, not T: T is an ordinary variable and can be reassigned.
mod <- cmdstan_model(fp, force_recompile = TRUE)

N <- 20                    # number of simulated draws
m <- c(2525, 1333, 888)    # items of each colour in the urn
n <- 4234                  # items taken per draw
odds <- c(0.2, 0.7, 0.1)   # colour weights used for the simulation

y <- rMWNCHypergeo(N, m, n, odds)
# Printed for reference: the expected counts under the simulating weights.
meanMWNCHypergeo(m, n, odds, precision = 1e-7)

# All urn quantities are divided by the same factor before being passed to
# Stan, so the model works on a smaller scale; one constant keeps the three
# uses in sync.
urn_scale <- 100

mod_out <- mod$sample(
  data = list(
    N = N,
    C = length(m),
    # First column is the scaled draw size; t() is applied so that each
    # simulated draw becomes one row, matching the N x (C + 1) data array.
    y = cbind(n / urn_scale, t(y / urn_scale)),
    m = m / urn_scale,
    tol = 0.01
  ),
  chains = 2,
  init = 1,
  adapt_delta = 0.8,
  parallel_chains = 2,
  iter_warmup = 200,
  iter_sampling = 200
)
mod_out