I am implementing the Latent Class Model for Capture-Recapture estimation described in this paper.
The analogy with capture-recapture is as follows: members of a population are observed on one or more “lists”, and the goal is to recover the total population size from the amount of overlap across the lists. Unlike vanilla capture-recapture, we do not assume that the lists are independent. To model the rich heterogeneity in real-life data, we assume that members of the population belong to unobserved latent classes that determine their probabilities of appearing on each list.
To that end, we have the hierarchical generative process
x_j | z \sim \text{Bernoulli}(\lambda_{jz})
z \sim \text{Discrete}(\lbrace 1, 2, ...\rbrace, (\pi_1, \pi_2, ...))
\lambda_{jk} \sim \text{Beta}(1, 1)
(\pi_1, \pi_2, ...) \sim \text{SB}(\alpha)
\alpha \sim \text{Gamma}(a,b)
where \text{SB}(\alpha) is the stick-breaking prior with concentration parameter \alpha, j = 1, \ldots, J indexes the lists, k indexes the latent classes, and z is an individual’s latent class. The observed data are summarized as “cells”, i.e. patterns of observation across the lists: with three lists, the cell (0, 1, 0) indicates someone who was observed only by the second list.
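As a sanity check of my understanding, here is a rough forward simulation of this process in R, truncated at K classes as in the Stan code below (simulate_lcm and its default arguments are just for illustration):

# forward-simulate the truncated stick-breaking latent-class model
simulate_lcm <- function(n, J = 5, K = 10, a = 1/4, b = 1/4) {
  alpha <- rgamma(1, shape = a, rate = b)
  v <- c(rbeta(K - 1, 1, alpha), 1)                # break proportions; last stick takes the rest
  class_probs <- v * cumprod(c(1, 1 - v[-K]))      # pi_k = v_k * prod_{l<k} (1 - v_l)
  lambda <- matrix(rbeta(J * K, 1, 1), nrow = J)   # list-inclusion probabilities per class
  z <- sample.int(K, n, replace = TRUE, prob = class_probs)   # latent class per individual
  x <- t(sapply(z, function(k) rbinom(J, 1, lambda[, k])))    # 0/1 list indicators
  list(x = x, z = z, pi = class_probs, lambda = lambda, alpha = alpha)
}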
From theoretical considerations we fix a = b = 1/4, and since Stan cannot sample discrete parameters, we marginalize out z. In a naive, direct implementation of this model on simulated data, about 30% of transitions are divergent!
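Concretely, marginalizing z gives the cell probability
P(x_1, \ldots, x_J) = \sum_{k=1}^{K} \pi_k \prod_{j=1}^{J} \lambda_{jk}^{x_j} (1 - \lambda_{jk})^{1 - x_j}
(truncating the stick-breaking at K classes, as in the Stan code below); the model block only needs this on the log scale for every observed cell, plus the same expression evaluated at the all-zero cell for the unobserved individuals.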
One possible culprit is label switching (asked about in this thread). We address it by enforcing an ordering: the transformed parameters block builds permuted versions of \lambda and \pi with the classes sorted by decreasing \pi. This reduces the divergences to about 20% of transitions.
We find that divergences tend to occur when \alpha is small, which induces very unequal splits in the stick-breaking. Suspecting numerical issues in our stick-breaking implementation, I switch to a logarithmic implementation, following this reply in the above thread. This reduces the divergences further, to roughly 15% of transitions.
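Writing v_k for the break proportions (breaks in the code), the logarithmic version simply accumulates
\log \pi_1 = \log v_1, \qquad \log \pi_k = \log v_k + \sum_{l < k} \log(1 - v_l) \text{ for } 1 < k < K, \qquad \log \pi_K = \sum_{l < K} \log(1 - v_l),
using log1m for the \log(1 - v_l) terms; this is what the loop in the transformed parameters block below implements.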
Lastly, I saw that \alpha \sim \text{Gamma}(\epsilon, \epsilon) can be a difficult prior for Stan due to the asymmetry of the log density and its derivative (for shape < 1 the density spikes at zero). I therefore switched to \alpha \sim \text{Exponential}(1), which has the same mean as \text{Gamma}(1/4, 1/4) but is slightly less asymmetric in log space. This further reduced the divergences to under 10%.
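To see the difference, one can just overlay the two log densities near zero (a quick base-R sketch, nothing model-specific):

# Gamma(1/4, 1/4) has shape < 1, so its log density diverges to +Inf as alpha -> 0;
# the Exponential(1) log density stays finite at zero
curve(dgamma(x, shape = 1/4, rate = 1/4, log = TRUE), from = 0.001, to = 5,
      xlab = expression(alpha), ylab = "log density")
curve(dexp(x, rate = 1, log = TRUE), from = 0.001, to = 5, add = TRUE, lty = 2)
legend("topright", legend = c("Gamma(1/4, 1/4)", "Exponential(1)"), lty = c(1, 2))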
However, 10% of transitions ending in divergences is still a lot. I can eliminate the divergences entirely by fixing \alpha. For example, \alpha = 1 implies that, a priori, two randomly selected individuals have a 50-50 chance of being in the same latent class, and it is apparently a “common choice” when we expect a small number of clusters relative to the sample size (BDA3, p. 553). But I’m reluctant to fix \alpha and give up letting the data inform how many latent classes are effectively used.
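(The 50-50 figure is the Dirichlet-process co-clustering probability: under the prior, P(z_1 = z_2 \mid \alpha) = 1 / (1 + \alpha), which equals 1/2 when \alpha = 1.)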
Throughout these iterations, all of the models give estimates roughly similar to one another and to the Gibbs sampler provided in the R package LCMCR, which is reassuring from a “multiverse analysis” point of view. Still, it feels wrong to either accept a model with divergences or to fix \alpha (and hence the prior on the number of occupied classes) ahead of time.
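For concreteness, the Gibbs comparison uses the simulated data defined below, along these lines (a sketch; the argument names are as I recall them from the LCMCR documentation, so please check ?lcmCR):

library(LCMCR)
# LCMCR expects the list-membership columns as factors
data.factor <- data.frame(lapply(data.sim.hetero1, factor))
sampler <- lcmCR(captures = data.factor, tabular = FALSE, K = 10,
                 a_alpha = 0.25, b_alpha = 0.25)
N_gibbs <- lcmCR_PostSampl(sampler, burnin = 10000, samples = 5000, thinning = 10)
quantile(N_gibbs, c(0.025, 0.5, 0.975))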
At this point, I have exhausted my bag of tricks and am not certain why there continue to be so many divergences (unless we fix \alpha). Do readers have any insights into why this is happening, and how to reduce or eliminate the remaining divergences?
I am using variations of the simulated data from the original paper:
library(dplyr)
# 90% of a population of 2000 with low per-list capture probabilities, 10% with high;
# rows never captured by any list are dropped (they are unobservable)
data.sim.hetero1 <- rbind(
  simulate_mce(2000 * 0.9, c(0.033, 0.033, 0.099, 0.132, 0.033)),
  simulate_mce(2000 * 0.1, c(0.660, 0.825, 0.759, 0.990, 0.693))
) %>% filter(rowSums(across(everything())) > 0)
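Here simulate_mce draws independent Bernoulli list-membership indicators with the given per-list probabilities; a minimal stand-in (my actual helper may differ in details) would be:

# n individuals, one 0/1 column per list, drawn independently within a subpopulation
simulate_mce <- function(n, list_probs) {
  as.data.frame(sapply(list_probs, function(p) rbinom(n, size = 1, prob = p)))
}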
and fitting the model with cmdstanr using the following wrapper:
fit_stan <- function(model, data, K = 10, num.iter = 2000, seed = 19481210,
                     chains = 4, warmup = 2000, adapt.delta = 0.8) {
  data.factor <- data.frame(lapply(data, factor))  # factor version of the lists (not used below)
  # tabulate the observed capture patterns and their counts
  stan_data_tabular <- data %>%
    group_by_all() %>%
    summarize(cell_count = n())
  stan_data <- stan_data_tabular %>%
    select(-cell_count) %>%
    as.matrix()
  stan_data_list <- list(J = ncol(stan_data),
                         C = nrow(stan_data_tabular),
                         list_indicators = stan_data,
                         cell_count = stan_data_tabular$cell_count,
                         K = K,
                         alpha = 1)  # ignored by the model below; only used by the fixed-alpha variant
  fit <- model$sample(data = stan_data_list,
                      seed = seed,
                      chains = chains,
                      parallel_chains = chains,
                      iter_warmup = warmup,
                      iter_sampling = num.iter,
                      adapt_delta = adapt.delta)
  fit
}
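A typical call compiles the Stan program with cmdstanr and passes it in (the file name here is just an example):

library(cmdstanr)
model <- cmdstan_model("lcmcr_stickbreaking.stan")  # the Stan program listed below
fit <- fit_stan(model, data.sim.hetero1)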
and the model (with the label-switching fix, the exponential prior on \alpha, and the logarithmic stick-breaking) looks as follows:
data {
  int<lower=1> J;                 // number of lists
  int<lower=1> C;                 // number of observed cells in the dataset, up to 2^J - 1
  int list_indicators[C, J];      // indicators of being in lists
  vector<lower=0>[C] cell_count;  // cell count for each capture pattern
  int<lower=1> K;                 // number of latent classes
}
transformed data {
  real<lower=0> observed = sum(cell_count);
  int zeros[J] = rep_array(0, J);
}
parameters {
  matrix<lower=0,upper=1>[J, K] lambda;  // list inclusion probabilities for each latent class
  vector<lower=0,upper=1>[K-1] breaks;   // break proportions for stick-breaking prior on pi
  real<lower=observed> N;
  real<lower=0> alpha;                   // stick-breaking prior parameter
}
transformed parameters {
  matrix<lower=0,upper=1>[K, J] lambda_T;  // list inclusion probabilities for each latent class
  vector[C] log_cell_probability;          // log cell probability for each observed capture pattern
  real log_unobserved_cell_probability;
  // https://mc-stan.org/docs/2_26/stan-users-guide/arithmetic-precision.html#underflow-and-the-log-scale
  vector<lower=0,upper=1>[K] pi;
  vector<upper=0>[K] log_pi;
  vector[K] lps_unobserved;
  log_pi[1] = log(breaks[1]);
  {
    for (k in 2:(K-1)) {
      log_pi[k] = log(breaks[k]) + log1m(breaks[k-1]) - log(breaks[k-1]) + log_pi[k-1];
    }
    log_pi[K] = log1m(breaks[K-1]) - log(breaks[K-1]) + log_pi[K - 1];
  }
  // reorder latent classes by pi
  for (i in 1:K) {
    lambda_T[i] = col(lambda, sort_indices_desc(log_pi)[i])';
  }
  log_pi = log_pi[sort_indices_desc(log_pi)];
  pi = exp(log_pi);
  // continue computation
  lps_unobserved = log_pi;
  for (c in 1:C) {
    vector[K] lps = log_pi;
    for (k in 1:K) {
      lps[k] += bernoulli_lpmf(list_indicators[c] | lambda_T[k]);  // https://mc-stan.org/docs/2_26/functions-reference/vectorization.html#evaluating-vectorized-log-probability-functions
    }
    log_cell_probability[c] = log_sum_exp(lps);
  }
  for (k in 1:K) {
    lps_unobserved[k] += bernoulli_lpmf(zeros | lambda_T[k]);
  }
  log_unobserved_cell_probability = log_sum_exp(lps_unobserved);
}
model {
  target += lchoose(N, observed) + (N - observed) * log_unobserved_cell_probability + cell_count' * log_cell_probability;
  target += -log(N);
  breaks ~ beta(1, alpha);
  alpha ~ exponential(1);
}
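For reference, the divergence counts and the population-size posterior quoted above can be pulled with cmdstanr’s built-in helpers:

fit$diagnostic_summary()                  # per-chain counts of divergent transitions
fit$summary(variables = c("N", "alpha"))  # posterior summaries for N and alpha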