Appropriate count distribution for false detection in occupancy models

In case anyone else comes across this looking for a solution: after some more soul searching and bug fixing, I have a working solution. What my code above was missing (apart from the many bugs) was that the contaminant element is not actually part of the occupancy problem. It doesn’t make sense to think of the ‘unused’ combinations as samples; they are measures of the additional noise in the reads. So we need to marginalise the contaminant count across the whole occupancy likelihood rather than just the Z = 1 part. That also lets the occupancy marginalisation sit properly inside the loop: the Z = 0 term only contributes when no reads are left over for the true signal, i.e. when Y minus the contaminant count is zero.
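
To spell out what the function below computes for a single sample, here is a sketch of the marginal likelihood. The symbols are my own shorthand rather than anything in the code: ψ = inv_logit(psi), λ = exp(p), κ = exp(q), and c is the unobserved number of contaminant reads.

```latex
\begin{aligned}
P(Y = y)
  &= \sum_{c=0}^{y} \mathrm{Pois}(c \mid \kappa)\,
     \Big[ \psi\, \mathrm{Pois}(y - c \mid \lambda)
           + (1 - \psi)\, \mathbb{1}[y - c = 0] \Big] \\
  &= \psi \sum_{c=0}^{y} \mathrm{Pois}(c \mid \kappa)\, \mathrm{Pois}(y - c \mid \lambda)
     + (1 - \psi)\, \mathrm{Pois}(y \mid \kappa)
\end{aligned}
```

The indicator is why the Z = 0 term only appears in the branch where Y_iter is zero. The code caps the sum at max_iter, so the expression is exact only when max_iter is at least y; otherwise it drops the terms with more than max_iter contaminant reads.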

functions {
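  // Log-likelihood of Y total reads, marginalising over occupancy (Z) and over
  // the number of contaminant reads i = 0..min(Y, max_iter).
  // psi: logit occupancy probability; p: log true read rate; q: log contaminant rate.
  real pois_occ_diff_lpmf(int Y, real psi, real p, real q, int max_iter) {
    int N_iter = min(Y, max_iter);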
    
    vector[N_iter + 1] weighted_probs;
    for(i in 0:N_iter){
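      int Y_iter = Y - i;  // reads left for the true signal if i reads are contaminants
      
      if(Y_iter > 0) {
        // Some true reads remain, so the sample must be occupied (Z = 1)
        weighted_probs[i + 1] = bernoulli_logit_lpmf(1 | psi) + poisson_log_lpmf(Y_iter | p);
      } else {
        // No true reads: either occupied with zero reads, or unoccupied (Z = 0)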
        weighted_probs[i + 1] = log_sum_exp(
          bernoulli_logit_lpmf(1 | psi) + poisson_log_lpmf(Y_iter | p),
          bernoulli_logit_lpmf(0 | psi)
        );
      }
      
      weighted_probs[i + 1] += poisson_log_lpmf(i | q);
      
    }
    
    return log_sum_exp(weighted_probs);
  }
}
data {
  int N_samples;                    // Number of samples
  int N_controls;                   // Number of controls
  array[N_samples] int Y;           // Raw data 
  vector[N_samples] SF;             // Size factors
  array[N_controls] int Y_control;  // Raw data (controls)

  // Site occupancy predictors
  int N_pred_occ;                      // Number of occupancy predictors
  matrix[N_samples, N_pred_occ] X_occ; // Predictor matrix for occupancy
  
  // Read abundance predictors
  int N_pred_reads;                        // Number of read predictors
  matrix[N_samples, N_pred_reads] X_reads; // Predictor matrix for reads
}
parameters {
  // Site occupancy parameters
  vector[N_pred_occ] b_occ;

  // Read abundance parameters
  vector[N_pred_reads] b_reads;

  // Contamination parameters
  real q;
}
model {
  // Priors
  // Occupancy
  target += normal_lpdf(b_occ | 0, 3);

  // Read abundance 
  target += normal_lpdf(b_reads | 0, 3);

  // Contamination
  target += normal_lpdf(q | 0, 3);

  // Calculate predictors
  vector[N_samples] p = (X_reads * b_reads) + SF;
  vector[N_samples] psi = X_occ * b_occ;
  
  // Likelihood
  for(i in 1:N_samples){
    target += pois_occ_diff_lpmf(Y[i] | psi[i], p[i], q, 1000);
  }
  
  for(i in 1:N_controls) {
    target += poisson_log_lpmf(Y_control[i] | q);
  }
}

library(tibble)    # tibble()
library(dplyr)     # mutate()
library(cmdstanr)  # cmdstan_model()

sim_fd_pois_data <- function(
  N_samples = 150,
  N_controls = 50,
  psi,
  a_contam,
  a_det
) {
  
  ## Generate simulated data
  X_samples <- tibble(Z = rbinom(N_samples, 1, psi),
                      SF = rgamma(N_samples, 5, 2)) |> 
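    # Z gates the true signal on the rate scale; exp(Z * ...) would add exp(0) = 1
    # to the rate of unoccupied samples, which the Stan model does not assume
    mutate(reads = rpois(N_samples, lambda = Z * exp(SF + a_det) + exp(a_contam)))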
    
  X_contam <- tibble(
    reads = rpois(N_controls, lambda = exp(a_contam))
  )
  
  stan_data <- list(
    N_samples = N_samples,
    N_controls = N_controls,
    Y = X_samples$reads,
    SF = X_samples$SF,
    Y_control = X_contam$reads,
    N_pred_occ = 1,
    N_pred_reads = 1,
    X_occ = matrix(rep(1, N_samples)),
    X_reads = matrix(rep(1, N_samples))
  )
  
  return(list(
    true_X_samples = X_samples,
    true_X_contam = X_contam,
    stan_data = stan_data
  ))
}

sim_pois <- sim_fd_pois_data(
  N_samples = 150, 
  N_controls = 50, 
  psi = 0.5,
  a_det = 5,
  a_contam = 2
)

## Marginalised poisson
model_pois_marg <- cmdstan_model("stan/sgcp_fd_pois_marg.stan")
fit_pois_marg <- model_pois_marg$sample(data = sim_pois$stan_data, chains = 4, parallel_chains = 4)

Posterior summary, labelled with the simulation parameter names:

variable  mean median      sd     mad    q5   q95  rhat ess_bulk ess_tail
psi      0.499  0.499 0.0401  0.0409  0.434 0.565  1.00    2533.    2602.
a_det    5.00   5.00  0.00190 0.00193 5.00  5.01   1.00    4712.    2776.
a_contam 2.06   2.06  0.0325  0.0333  2.01  2.12   1.00    2568.    2470.

The only issue I have is that the chains seem to have trouble getting to where they need to be, and some take a lot longer than others. In the run above, they took 215.6s, 216.1s, 280.8s, and 288.6s. I am not sure why this is, or whether it is a problem. One possibility is that certain parameter combinations are just inherently slower in the marginalisation loop above, so some chains take longer. The other is that the model has some hard-to-explore geometry and some chains get stuck, in which case I am not sure what to do or whether it’s possible to reparameterise. I am concerned that these issues will get worse when moving from a simple Poisson intercept model to a complex multi-species model with negative binomials…
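
One thing that may be worth checking for the run time, at least in the pure Poisson case: the sum of two independent Poisson counts is itself Poisson, so the loop over contaminant counts should collapse to a two-component mixture, Poisson with rate exp(p) + exp(q) when Z = 1 and Poisson with rate exp(q) when Z = 0. The R sketch below compares a direct port of the looped calculation against that closed form. The helper names (log_sum_exp, pois_occ_loop, pois_occ_closed) and the test values are made up for illustration, and the shortcut relies on the Poisson assumption, so I would not expect it to carry over directly to negative binomials.

```r
# Numerically stable log(sum(exp(x)))
log_sum_exp <- function(x) {
  m <- max(x)
  m + log(sum(exp(x - m)))
}

# Hypothetical R port of the looped marginalisation in pois_occ_diff_lpmf
pois_occ_loop <- function(Y, psi, p, q, max_iter = 1000) {
  i <- 0:min(Y, max_iter)                   # candidate contaminant counts
  y_true <- Y - i                           # reads left for the true signal
  log_occ <- plogis(psi, log.p = TRUE)      # log P(Z = 1), psi on the logit scale
  log_unocc <- plogis(-psi, log.p = TRUE)   # log P(Z = 0)
  lp <- log_occ + dpois(y_true, exp(p), log = TRUE)
  zero <- y_true == 0                       # at most one element is TRUE
  # Z = 0 is only possible when every read is contamination
  if (any(zero)) lp[zero] <- log_sum_exp(c(lp[zero], log_unocc))
  log_sum_exp(lp + dpois(i, exp(q), log = TRUE))
}

# Closed form using Poisson superposition (assumes the sum is not truncated)
pois_occ_closed <- function(Y, psi, p, q) {
  log_sum_exp(c(
    plogis(psi, log.p = TRUE) + dpois(Y, exp(p) + exp(q), log = TRUE),
    plogis(-psi, log.p = TRUE) + dpois(Y, exp(q), log = TRUE)
  ))
}

# Compare the two on a few arbitrary counts
for (Y in c(0, 3, 25)) {
  diff <- pois_occ_loop(Y, 0.2, 1.5, 0.7) - pois_occ_closed(Y, 0.2, 1.5, 0.7)
  cat("Y =", Y, " difference =", diff, "\n")
}
```

If the two agree on your data as well, the closed form could replace the loop in the Poisson model, which would at least separate the cost of the loop from any geometry problem.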

Looking at the trace plots, there seems to be a little bit of autocorrelation. Here’s the plot for psi: