Pesky divergences in latent class model

I am implementing the Latent Class Model for Capture-Recapture estimation described in this paper.

The analogy with capture-recapture is as follows: members of a population are observed on one or more “lists”, and the goal is to recover the total population size from the amount of overlap across the lists. Unlike vanilla capture-recapture, we do not assume that the lists are independent; to model the rich heterogeneity in real-life data, we assume that members of the population belong to unobserved latent classes that determine their probability of appearing on each list.

To that end, we have the hierarchical generative process

x_j | z \sim \text{Bernoulli}(\lambda_{jz})

z \sim \text{Discrete}(\lbrace 1, 2, ...\rbrace, (\pi_1, \pi_2, ...))

\lambda_{jk} \sim \text{Beta}(1, 1)

(\pi_1, \pi_2, ...) \sim \text{SB}(\alpha)

\alpha \sim \text{Gamma}(a,b)

where \text{SB}(\alpha) is the stick-breaking prior with concentration parameter \alpha, j = 1, ..., J indexes the lists, and z is the latent class. A “cell” is a pattern of observation across the J lists (e.g., with three lists, (0, 1, 0) is the cell indicating that someone was observed only by the second list).
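To make the generative process concrete, here is a small forward-simulation sketch in R (my own function and argument names, truncated at K classes; this is not code from the paper):

# Sketch: forward-simulate N individuals from the truncated stick-breaking latent class model
simulate_lcm <- function(N = 2000, J = 5, K = 10, alpha = 1) {
  breaks <- rbeta(K - 1, 1, alpha)                     # stick-breaking proportions v_k
  pi <- c(breaks, 1) * cumprod(c(1, 1 - breaks))       # class weights pi_1, ..., pi_K
  lambda <- matrix(runif(J * K), nrow = J)             # lambda[j, k] ~ Beta(1, 1)
  z <- sample.int(K, N, replace = TRUE, prob = pi)     # latent class of each individual
  t(sapply(z, function(k) rbinom(J, 1, lambda[, k])))  # 0/1 list-inclusion indicators
}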

From some theoretical considerations, we fix a = b = 1/4, and for computational purposes we marginalize out z (Stan cannot sample discrete parameters). In a naive, direct implementation of this model on simulated data, we find that 30% of transitions are divergent!
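Concretely, marginalizing over z gives the probability of a cell (capture pattern) x = (x_1, ..., x_J) under the model truncated at K classes:

\Pr(x \mid \pi, \lambda) = \sum_{k=1}^{K} \pi_k \prod_{j=1}^{J} \lambda_{jk}^{x_j} (1 - \lambda_{jk})^{1 - x_j}

This is what the Stan code below evaluates on the log scale with log_sum_exp, once for each observed cell and once for the all-zeros (unobserved) cell.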

One possible culprit is label switching (asked about in this thread), and we address it by enforcing an ordering: in the transformed parameters block, we permute \lambda and \pi so that the latent classes are sorted by decreasing \pi. This reduces the divergences to about 20% of transitions.

We find that divergences tend to occur when \alpha is small, which induces very unequal splits in the stick-breaking. Suspecting numerical issues in our stick-breaking implementation, we switched to a logarithmic implementation, following this reply in the above thread. This reduces the divergences further, to roughly 15% of transitions.
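For reference, with break proportions v_1, ..., v_{K-1} (breaks in the Stan code), the log-scale stick-breaking is

\log \pi_1 = \log v_1, \qquad \log \pi_k = \log v_k + \sum_{i=1}^{k-1} \log(1 - v_i) \quad (1 < k < K), \qquad \log \pi_K = \sum_{i=1}^{K-1} \log(1 - v_i)

so the small weights that arise when the splits are unequal are represented as sums of logs rather than products that can underflow.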

Lastly, I saw that \alpha \sim \text{Gamma}(\epsilon, \epsilon) can be a difficult prior for Stan due to the asymmetry in the log density and its derivative, so I switched to \alpha \sim \text{Exponential}(1), which has the same mean as \text{Gamma}(1/4, 1/4) but is less asymmetric in log space. This further reduced the divergences to under 10%.
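To see why, compare the two log densities (up to additive constants):

\log p_{\text{Gamma}(1/4,\,1/4)}(\alpha) = -\tfrac{3}{4}\log \alpha - \tfrac{1}{4}\alpha, \qquad \log p_{\text{Exponential}(1)}(\alpha) = -\alpha

The Gamma(1/4, 1/4) log density and its gradient diverge as \alpha \to 0, while the Exponential(1) log density is linear in \alpha.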

However, 10% of transitions ending in divergences is still a lot. I can eliminate the divergences entirely by fixing \alpha. For example, fixing \alpha = 1 implies that, a priori, two randomly selected individuals have a 50-50 chance of belonging to the same latent class, and it is apparently a “common choice” when we want a small number of clusters relative to the sample size (BDA3, p. 553). But I’m reluctant to give up the model’s ability to learn the effective number of latent classes from the data.
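(The 50-50 figure is the Dirichlet-process prior co-clustering probability,

\Pr(z_i = z_{i'}) = \frac{1}{1 + \alpha},

evaluated at \alpha = 1; smaller \alpha favors fewer, larger classes.)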

Throughout all of these iterations, the models perform roughly similarly to one another and to the Gibbs sampler provided in the R LCMCR package, which is reassuring from a “multiverse analysis” point of view. But it still feels strange to either use a model with divergences or to fix \alpha ahead of time.

At this point, I have exhausted my bag of tricks and am not certain why there continue to be so many divergences (unless I fix \alpha). Do readers have any insights into why this is happening, and how to reduce the divergences further?

I am using variations of the simulated data from the original paper:

library(dplyr)

data.sim.hetero1 <- rbind(
  simulate_mce(2000 * 0.9, c(0.033, 0.033, 0.099, 0.132, 0.033)),
  simulate_mce(2000 * 0.1, c(0.660, 0.825, 0.759, 0.990, 0.693))
) %>% filter(rowSums(across(everything())) > 0)  # drop the never-observed (all-zeros) pattern
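simulate_mce is not shown here; my assumption is that it simulates N individuals with independent per-list capture probabilities p, roughly:

simulate_mce <- function(N, p) {
  # assumed behavior (not the paper's code): one 0/1 column per list,
  # with independent captures at probability p[j] for list j
  as.data.frame(sapply(p, function(pj) rbinom(N, 1, pj)))
}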

and fitting the model with the following function:

fit_stan <- function(model, data, K=10, num.iter=2000, seed=19481210, chains=4, warmup=2000, adapt.delta=0.8) {
  # aggregate the individual-level 0/1 data into observed cells (capture patterns) and counts
  stan_data_tabular <- data %>%
    group_by_all() %>%
    summarize(cell_count = n(), .groups = "drop")

  stan_data <- stan_data_tabular %>%
    select(-cell_count) %>%
    as.matrix()

  stan_data_list <- list(J = ncol(stan_data),
                         C = nrow(stan_data_tabular),
                         list_indicators = stan_data,
                         cell_count = stan_data_tabular$cell_count,
                         K = K,
                         alpha = 1) # only used by the fixed-alpha variant; ignored by the model below

  fit <- model$sample(data = stan_data_list,
                      seed = seed,
                      chains = chains,
                      parallel_chains = chains,
                      iter_warmup = warmup,
                      iter_sampling = num.iter,
                      adapt_delta = adapt.delta)

  fit
}
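Usage is then, assuming cmdstanr and that the Stan program below is saved as lcmcr.stan (the file name is my own):

library(cmdstanr)
model <- cmdstan_model("lcmcr.stan")       # compile the Stan program shown below
fit <- fit_stan(model, data.sim.hetero1)   # defaults: K = 10, adapt_delta = 0.8
fit$summary(variables = c("N", "alpha"))   # posterior summaries of N and alpha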

and the model (with the label-switching fix, the exponential prior on \alpha, and the log-scale stick-breaking) looks as follows:


data {
  int<lower=1> J; // number of lists
  int<lower=1> C; // number of observed cells in the dataset, up to 2^J-1
  int list_indicators[C, J]; // indicators of being in lists
  vector<lower=0>[C] cell_count; // cell count for each capture pattern
  int<lower=1> K; // number of latent classes
}


transformed data {
  real<lower=0> observed = sum(cell_count);
  int zeros[J] = rep_array(0,J);
}


parameters {
  matrix<lower=0,upper=1>[J, K] lambda; // list inclusion probabilities for each latent class
  vector<lower=0,upper=1>[K-1] breaks; // break proportions for stick-breaking prior on pi

  real<lower=observed> N; // total population size (at least the number observed)
  real<lower=0> alpha; // stick-breaking prior parameter
}


transformed parameters {
  matrix<lower=0,upper=1>[K, J] lambda_T; // lambda transposed, with classes reordered by decreasing pi
  vector[C] log_cell_probability; // log cell probability for each observed capture pattern
  real log_unobserved_cell_probability;
  // https://mc-stan.org/docs/2_26/stan-users-guide/arithmetic-precision.html#underflow-and-the-log-scale
  vector<lower=0,upper=1>[K] pi;
  vector<upper=0>[K] log_pi; 
  vector[K] lps_unobserved;
  
  // log-scale stick-breaking: log_pi[k] = log(breaks[k]) + sum_{i<k} log1m(breaks[i])
  log_pi[1] = log(breaks[1]);
  for (k in 2:(K-1)) {
    log_pi[k] = log(breaks[k]) + log1m(breaks[k-1]) - log(breaks[k-1]) + log_pi[k-1];
  }
  log_pi[K] = log1m(breaks[K-1]) - log(breaks[K-1]) + log_pi[K-1];

  // reorder latent classes by decreasing pi to break label switching
  {
    int class_order[K] = sort_indices_desc(log_pi);
    for (i in 1:K) {
      lambda_T[i] = col(lambda, class_order[i])';
    }
    log_pi = log_pi[class_order];
  }
  pi = exp(log_pi);
  
  // continue computation
  lps_unobserved = log_pi;
  for (c in 1:C) {
    vector[K] lps = log_pi;
    for (k in 1:K) {
      lps[k] += bernoulli_lpmf(list_indicators[c] | lambda_T[k]); // https://mc-stan.org/docs/2_26/functions-reference/vectorization.html#evaluating-vectorized-log-probability-functions
    }
    log_cell_probability[c] = log_sum_exp(lps);
  }
  for (k in 1:K) {
    lps_unobserved[k] += bernoulli_lpmf(zeros | lambda_T[k]);
  }
  log_unobserved_cell_probability = log_sum_exp(lps_unobserved);
}

model {
  // marginal likelihood over cells: binomial coefficient for the unobserved count,
  // plus contributions from the unobserved (all-zeros) cell and each observed cell
  target += lchoose(N, observed) + (N - observed)*log_unobserved_cell_probability + cell_count' * log_cell_probability;
  target += -log(N); // (improper) 1/N prior on the population size
  
  breaks ~ beta(1, alpha);
  alpha ~ exponential(1);
}

I don’t think you’ve done anything wrong, and I don’t know of any tricks to reduce divergences further. The implementation I posted in the other thread was as optimized as I could make it for best practices in Stan at the time. Michael Betancourt has written a lot about how mixture models are very hard to actually fit (see e.g. Identifying Bayesian Mixture Models), and the same goes for the infinite latent class model.

I have a lot of experience with this particular model, and my big takeaway is that it’s ill-posed (The folk theorem of statistical computing | Statistical Modeling, Causal Inference, and Social Science). See section 3.4.3 of this paper (https://arxiv.org/pdf/2112.01594.pdf) and section 3.2 / Appendix D of this paper (https://arxiv.org/pdf/2101.09304.pdf). The first paper points out issues with the implementation in LCMCR (which I’ve seen myself), and the second points out issues with identifiability (in the strict statistical sense).

If you’re using this model in an actual application and not just an academic paper, I’d seriously reconsider using it.
