Multithreading for a multistate capture-recapture model

I am setting up an individual-based multistate capture-recapture model in Stan. Multistate models are notoriously slow, and there are many individuals (more than 48k!) in this particular dataset. To speed it up, I would like to set up multithreading using reduce_sum, but I cannot find good examples of the partial_sum code for the functions block when the latent states are marginalized out, as is common for latent-state models in Stan. I need help with the partial_sum function for the following likelihood:

for (i in 1 : nind) {
  for (k in 1 : 4) {
    gamma[first[i], k] = y[i, first[i]] == k;
  }
  for (t in (first[i] + 1) : n_occasions) {
    for (k in 1 : 4) {
      for (j in 1 : 4) {
        acc[j] = gamma[t - 1, j] * ps[j, i, t - 1, k]
                 * po[k, i, t - 1, y[i, t]];
      }
      gamma[t, k] = sum(acc);
    }
  }
  target += log(sum(gamma[n_occasions]));
}
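
In equation form, this is the standard forward recursion with the latent state marginalized out, initialized at the first capture occasion $f_i$:

$$\gamma_{f_i}(k) = \mathbb{1}\{y_{i,f_i} = k\}, \qquad \gamma_t(k) = \sum_{j=1}^{4} \gamma_{t-1}(j)\,\mathrm{ps}[j,i,t-1,k]\,\mathrm{po}[k,i,t-1,y_{i,t}], \qquad \log L_i = \log \sum_{k=1}^{4} \gamma_{n\_occasions}(k)$$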

Example model found here: https://github.com/stan-dev/example-models/blob/master/BPA/Ch.09/ms3_multinomlogit.stan

Hi @wesmith2 , I’ve used multi-threading with multievent models before (e.g. here) but that may not be the simplest model to begin with. Below is a Stan program with a reduce_sum implementation of the example model you linked to:

// Modification of code at 
// https://github.com/stan-dev/example-models/blob/master/BPA/Ch.09/ms3_multinomlogit.stan

// -------------------------------------------------
// States (S):
// 1 alive at A
// 2 alive at B
// 3 alive at C
// 4 dead
// Observations (O):
// 1 seen at A
// 2 seen at B
// 3 seen at C
// 4 not seen
// -------------------------------------------------

functions {
  /**
   * Return an integer value denoting occasion of first capture.
   * This function is derived from Stan Modeling Language
   * User's Guide and Reference Manual.
   *
   * @param y_i       Observed values
   * @return Occasion of first capture
   */
  int first_capture(array[] int y_i) {
    for (k in 1 : size(y_i)) {
      if (y_i[k] != 4) {
        return k;
      }
    }
    return 0;
  }
  
  /**
   * Return a simplex such as follows (thanks to Bob Carpenter):
   * p[1] <- exp(lp[1]) / (1.0 + exp(lp[1]) + exp(lp[2]));
   * p[2] <- exp(lp[2]) / (1.0 + exp(lp[1]) + exp(lp[2]));
   * p[3] <- 1.0 - p[1] - p[2];
   *
   * @param lp   N-dimension vector
   * @return (N+1)-simplex of given vector and 0
   */
  vector softmax_0(vector lp) {
    vector[num_elements(lp) + 1] lp_temp;
    
    lp_temp[1 : num_elements(lp)] = lp;
    lp_temp[num_elements(lp) + 1] = 0;
    return softmax(lp_temp);
  }
  
  /**
   * Return the log-likelihood for a subset of capture histories.
   *
   * @param y_slice  the subset of capture histories
   * @param start  index (in y) of the first capture history in the subset
   * @param end  index (in y) of the last capture history in the subset
   * @param n_occasions  number of occasions
   * @param first  first capture occasions
   * @param ps  array of state-transition probabilities
   * @param po  array of state-dependent detection probabilities
   * @return log-likelihood for the given subset of capture histories
   */
  real partial_sum(
    array[,] int y_slice,
    int start,
    int end,
    int n_occasions,
    array[] int first,
    array[,,] vector ps,
    array[,,] vector po
  ) {
    real ll_term = 0;
    // Likelihood using forward algorithm derived from Stan Modeling Language
    // User's Guide and Reference Manual
    for (i in start:end) {
      int iy = i - start + 1;
      array[4] real acc;
      array[n_occasions] vector[4] gamma;
      if (first[i] > 0) {
        for (k in 1:4) {
          gamma[first[i], k] = y_slice[iy, first[i]] == k;
        }
        for (t in (first[i] + 1):n_occasions) {
          for (k in 1:4) {
            for (j in 1:4) {
              acc[j] = gamma[t - 1, j] * ps[j, i, t - 1, k]
                     * po[k, i, t - 1, y_slice[iy, t]];  
            }
            gamma[t, k] = sum(acc);
          }
        }
        ll_term = ll_term + log(sum(gamma[n_occasions])); 
      }
    }
    return ll_term;
  }
}
data {
  int<lower=0> nind;
  int<lower=0> n_occasions;
  array[nind, n_occasions] int<lower=1, upper=4> y; 
}
transformed data {
  int n_occ_minus_1 = n_occasions - 1;
  array[nind] int<lower=0, upper=n_occasions> first;
  
  for (i in 1 : nind) {
    first[i] = first_capture(y[i]);
  }
}
parameters {
  real<lower=0, upper=1> phiA; // Survival probability at site A
  real<lower=0, upper=1> phiB; // Survival probability at site B
  real<lower=0, upper=1> phiC; // Survival probability at site C
  real<lower=0, upper=1> pA; // Recapture probability at site A
  real<lower=0, upper=1> pB; // Recapture probability at site B
  real<lower=0, upper=1> pC; // Recapture probability at site C
  vector[2] lpsiA; // Logit of movement probability from site A
  vector[2] lpsiB; // Logit of movement probability from site B
  vector[2] lpsiC; // Logit of movement probability from site C
}
transformed parameters {
  simplex[3] psiA; // Movement probability from site A
  simplex[3] psiB; // Movement probability from site B
  simplex[3] psiC; // Movement probability from site C
  array[4, nind, n_occ_minus_1] simplex[4] ps;
  array[4, nind, n_occ_minus_1] simplex[4] po;
  
  // Constrain the transitions such that their sum is < 1
  psiA = softmax_0(lpsiA);
  psiB = softmax_0(lpsiB);
  psiC = softmax_0(lpsiC);
  
  // Define state-transition and observation matrices
  for (i in 1 : nind) {
    // Define probabilities of state S(t+1) given S(t)
    for (t in 1 : n_occ_minus_1) {
      ps[1, i, t, 1] = phiA * psiA[1];
      ps[1, i, t, 2] = phiA * psiA[2];
      ps[1, i, t, 3] = phiA * psiA[3];
      ps[1, i, t, 4] = 1.0 - phiA;
      ps[2, i, t, 1] = phiB * psiB[1];
      ps[2, i, t, 2] = phiB * psiB[2];
      ps[2, i, t, 3] = phiB * psiB[3];
      ps[2, i, t, 4] = 1.0 - phiB;
      ps[3, i, t, 1] = phiC * psiC[1];
      ps[3, i, t, 2] = phiC * psiC[2];
      ps[3, i, t, 3] = phiC * psiC[3];
      ps[3, i, t, 4] = 1.0 - phiC;
      ps[4, i, t, 1] = 0.0;
      ps[4, i, t, 2] = 0.0;
      ps[4, i, t, 3] = 0.0;
      ps[4, i, t, 4] = 1.0;
      
      // Define probabilities of O(t) given S(t)
      po[1, i, t, 1] = pA;
      po[1, i, t, 2] = 0.0;
      po[1, i, t, 3] = 0.0;
      po[1, i, t, 4] = 1.0 - pA;
      po[2, i, t, 1] = 0.0;
      po[2, i, t, 2] = pB;
      po[2, i, t, 3] = 0.0;
      po[2, i, t, 4] = 1.0 - pB;
      po[3, i, t, 1] = 0.0;
      po[3, i, t, 2] = 0.0;
      po[3, i, t, 3] = pC;
      po[3, i, t, 4] = 1.0 - pC;
      po[4, i, t, 1] = 0.0;
      po[4, i, t, 2] = 0.0;
      po[4, i, t, 3] = 0.0;
      po[4, i, t, 4] = 1.0;
    }
  }
}
model {
  
  // Priors
  // Survival and recapture: uniform
  // Uniform priors are implicitly defined.
  //  phiA ~ uniform(0, 1);
  //  phiB ~ uniform(0, 1);
  //  phiC ~ uniform(0, 1);
  //  pA ~ uniform(0, 1);
  //  pB ~ uniform(0, 1);
  //  pC ~ uniform(0, 1);
  
  // Normal priors on logit of all but one transition probs
  lpsiA ~ normal(0, sqrt(1000));
  lpsiB ~ normal(0, sqrt(1000));
  lpsiC ~ normal(0, sqrt(1000));
  
  // Likelihood using reduce_sum
  int grainsize = 1;
  target += reduce_sum(partial_sum, y, grainsize, n_occasions, first, ps, po);
}

The likelihood calculation using the forward algorithm gets moved into the partial_sum function. If I understand correctly (based on material here and here), partial_sum gets fed a subset of capture histories (something like y[101:200, ]) as the argument y_slice and returns the log-likelihood for that whole subset. One has to be careful that indices into y_slice and into the other objects (first, ps, po) correspond to the same individual. In the loop over individuals, for (i in start:end) {...}, the intention is that i indexes into first, ps and po, whilst iy picks out the corresponding individual in y_slice.
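
To make that correspondence concrete, here is a hypothetical slice (the numbers are purely illustrative):

// Suppose reduce_sum hands partial_sum the slice y[101:200, ], so that
// start = 101 and end = 200. For the individual with global index i = 150:
//   int iy = i - start + 1;  // iy = 50
// y_slice[iy] is that individual's capture history, while first[i], ps and
// po are indexed with the global i.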

I tested the above with the following R code to simulate data under the model and fit it (assuming you have the above Stan program saved as multistate_parallel.stan). I just ran one chain with a few threads to check that it was working; of course you’ll want several chains for serious work.

# Testing the Stan code
library(cmdstanr)

# Simulate some data

# numbers captured
nind_per_occasion <- 50
n_occasions <- 5
nind <- nind_per_occasion * (n_occasions - 1)
first <- rep(1:(n_occasions - 1), each = nind_per_occasion)
first_state <- sample(1:3, nind, replace = TRUE) # sample first states randomly

# Set parameter values
phiA <- 0.8
phiB <- 0.5
phiC <- 0.6
pA <- 0.7
pB <- 0.9
pC <- 0.6
psiA <- c(0.8, 0.1, 0.1)
psiB <- c(0.1, 0.8, 0.1)
psiC <- c(0.1, 0.1, 0.8)

# Calculate state transition and observation probabilities
ps <- array(dim = c(4, nind, n_occasions-1, 4))
po <- array(dim = c(4, nind, n_occasions-1, 4))
for (i in 1 : nind) {
  # Define probabilities of state S(t+1) given S(t)
  for (t in 1 : (n_occasions - 1)) {
    ps[1, i, t, 1] = phiA * psiA[1];
    ps[1, i, t, 2] = phiA * psiA[2];
    ps[1, i, t, 3] = phiA * psiA[3];
    ps[1, i, t, 4] = 1.0 - phiA;
    ps[2, i, t, 1] = phiB * psiB[1];
    ps[2, i, t, 2] = phiB * psiB[2];
    ps[2, i, t, 3] = phiB * psiB[3];
    ps[2, i, t, 4] = 1.0 - phiB;
    ps[3, i, t, 1] = phiC * psiC[1];
    ps[3, i, t, 2] = phiC * psiC[2];
    ps[3, i, t, 3] = phiC * psiC[3];
    ps[3, i, t, 4] = 1.0 - phiC;
    ps[4, i, t, 1] = 0.0;
    ps[4, i, t, 2] = 0.0;
    ps[4, i, t, 3] = 0.0;
    ps[4, i, t, 4] = 1.0;
    
    # Define probabilities of O(t) given S(t)
    po[1, i, t, 1] = pA;
    po[1, i, t, 2] = 0.0;
    po[1, i, t, 3] = 0.0;
    po[1, i, t, 4] = 1.0 - pA;
    po[2, i, t, 1] = 0.0;
    po[2, i, t, 2] = pB;
    po[2, i, t, 3] = 0.0;
    po[2, i, t, 4] = 1.0 - pB;
    po[3, i, t, 1] = 0.0;
    po[3, i, t, 2] = 0.0;
    po[3, i, t, 3] = pC;
    po[3, i, t, 4] = 1.0 - pC;
    po[4, i, t, 1] = 0.0;
    po[4, i, t, 2] = 0.0;
    po[4, i, t, 3] = 0.0;
    po[4, i, t, 4] = 1.0;
  }
}

# Simulate true latent states
z <- matrix(NA, nrow = nind, ncol = n_occasions)
for (i in 1:nind) {
  z[i, first[i]] <- first_state[i]
  for (t in first[i]:(n_occasions - 1)) {
    z[i,t+1] <- sample(1:4, 1, prob = ps[z[i,t], i, t, ])
  }
}

# Simulate the observed states
y <- matrix(4, nrow = nind, ncol = n_occasions)  # '4' is not-captured code
for (i in 1:nind) {
  y[i, first[i]] <- z[i, first[i]]  # captured in true state at first capture
  for (t in (first[i]+1):n_occasions) {
    y[i,t] <- sample(1:4, 1, prob = po[z[i,t], i, t-1, ])  # time index for 'po' is one behind y
  }
}

# Fit the model using within-chain parallelisation
file <- "multistate_parallel.stan"
mod <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE))
stan_data <- list(nind=nind, n_occasions=n_occasions, y=y)
fit <- mod$sample(stan_data, chains = 1, threads_per_chain = 4)

BEWARE! On the small dataset generated by the above R code, the reduce_sum() implementation was slower than the original code without it. Multi-threading does carry some overhead, which might explain that (and for a dataset as large as yours, within-chain parallelization may still pay off). The other explanation is that I’ve done something silly - so use with caution!
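
If you do try it on the full dataset, grainsize is worth experimenting with: I used grainsize = 1 above, which leaves the partitioning to the scheduler, but a coarser value sometimes cuts the overhead. A sketch only (the divisor 8 is an arbitrary assumption; tune it for your hardware and thread count):

  // Coarser grainsize: bigger chunks per thread, less scheduling overhead.
  // The divisor 8 is an arbitrary choice, not a recommendation.
  int grainsize = max(1, nind %/% 8);
  target += reduce_sum(partial_sum, y, grainsize, n_occasions, first, ps, po);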

Multistate models need not be slow in Stan. The example code you sent can be dramatically sped up by aggregating capture histories in the case where individuals share common survival and detection parameters. That is the approach my co-authors and I took for the time-integrated migration survival model where temporal strata are treated as states: https://doi.org/10.1111/biom.13171

We’ve been fitting these models in reasonable time for datasets involving 500K+ individuals and 200+ states.
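
To sketch the aggregation idea (all names here are hypothetical; cap_hist_log_lik stands in for a forward-algorithm function over a single history):

// Collapse the nind rows of y to n_unique distinct capture histories with
// counts, then weight each unique history's log-likelihood by its count,
// so the forward algorithm runs once per unique history, not per individual.
for (h in 1:n_unique)
  target += count[h] * cap_hist_log_lik(y_unique[h]);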

See the Temporally Stratified Space-for-Time CJS repository (Columbia River Research Lab, Quantitative Fisheries Ecology Section) on GitLab for code that could be adapted.

However, if individuals do not share common parameters (for example, if survival, transition and/or detection parameters are functions of individual covariates) then you’ll need to write functions that take in the regression coefficients and covariate matrices (or random effects) and compute the survival and transition matrices on the fly for each individual. In any case, the forward algorithm, and in particular the approach of explicitly including the death state, is not necessarily the most efficient way to fit these models. Because mark-recapture likelihoods factorize at each positive detection, all you need to do is find the likelihood for each pair of consecutive detections and include a term for the probability of never seeing the individual again after its last detection. The following set of functions should be a bit more efficient for calculating individual-specific likelihoods because they rely on vectorization of the forward algorithm. You’d have to come up with another function to slice the capture histories among individuals and calculate the arrays of transition matrices and detection probabilities (see the sketch after the functions).
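
In symbols: with detections at occasions $t_1 < \dots < t_D$ in states $s_1, \dots, s_D$, the log-likelihood for one individual factorizes as

$$\log L_i = \sum_{d=1}^{D-1} \log \Pr\left(\text{next detected at } t_{d+1} \text{ in state } s_{d+1} \mid \text{at } t_d \text{ in state } s_d\right) + \log \chi_{t_D, s_D},$$

where $\chi_{t_D, s_D}$ is the probability of never being detected again after the final detection.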

  /**
   * Return the number of non-zero elements of an
   * individual capture history.
   *
   * @param x_i Integer array representing a capture history
   * @return number of non-zero entries
   */
  int n_detections(array[] int x_i) {
    int n_det = 0;
    
    for (k in 1:size(x_i))
      if (x_i[k] > 0)
        n_det += 1;
    
    return n_det;
  }
  
  /**
   * Return an integer array with the index of each non-zero
   * element of an individual capture history.
   *
   * @param n Number of detections (non-zero entries) in x_i
   * @param x_i Integer array representing a capture history
   * @return indices of non-zero entries
   */
  array[] int detection_indices(int n, array[] int x_i) {
    array[n] int det_idxs;
    int occ_counter = 1;
    int det_counter = 1;
    
    while (det_counter < n + 1) {
      if (x_i[occ_counter] > 0) {
        det_idxs[det_counter] = occ_counter;
        det_counter += 1;
      }
      occ_counter += 1;
    }
    
    return det_idxs;
  }
  /**
   * Return the log-likelihood for a single capture history with
   * individual-specific survival, transition and/or detection probabilities.
   * Assumes at least one detection, with 0 coding "not detected" in x_i,
   * gamma[t] the transition matrix from occasion t to t + 1, and p[t] the
   * detection probabilities for occasion t + 1 (indexed one behind the
   * occasion, as with po in the model above).
   *
   * @param T number of capture occasions
   * @param K number of non-death states
   * @param x_i capture history for individual i
   * @param gamma array of occasion-dependent state-transition
   *              probability matrices (phi * psi)
   * @param p array of occasion- and state-dependent detection
   *          probability vectors
   * @return log-likelihood for the given capture history
   */
  real cap_hist_logp_i(
    data int T, data int K, data array[] int x_i,
    array[] matrix gamma, array[] vector p) {
    
    int n_d = n_detections(x_i);
    array[n_d] int d_i = detection_indices(n_d, x_i);
    vector[n_d] partial_log_p;
    int occ_now = d_i[1];       // occasion of the current detection
    int st_now = x_i[occ_now];  // state observed at the current detection
    row_vector[K] omega = gamma[occ_now][st_now];
    vector[T - d_i[n_d]] chi1m;
    
    // One likelihood term per pair of consecutive detections
    for (det in 1:(n_d - 1)) {
      int nxt = det + 1;
      int occ_nxt = d_i[nxt];
      int st_nxt = x_i[occ_nxt];
      // zero out every state except the one observed at the next detection
      vector[K] p_nxt = p[occ_nxt - 1] .* one_hot_vector(K, st_nxt);
      
      // propagate through the occasions with no detection
      for (t in (occ_now + 1):(occ_nxt - 1))
        omega *= diag_pre_multiply(1 - p[t - 1], gamma[t]);
      
      partial_log_p[det] = log(omega * p_nxt);
      occ_now = occ_nxt;
      st_now = st_nxt;
      omega = gamma[occ_now][st_now];
    }
    
    // Probability of never being detected again after the final detection:
    // chi1m accumulates the first-redetection probabilities for each
    // remaining occasion, and one minus their sum is the "never seen again"
    // probability.
    if (T > d_i[n_d]) {
      chi1m[1] = omega * p[occ_now];
      for (t in (occ_now + 1):(T - 1)) {
        omega *= diag_pre_multiply(1 - p[t - 1], gamma[t]);
        chi1m[t - occ_now + 1] = omega * p[t];
      }
      partial_log_p[n_d] = log1m(sum(chi1m));
    } else {
      partial_log_p[n_d] = 0;  // final detection on the last occasion
    }
    
    return sum(partial_log_p);
  }
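
As a rough illustration, the slicing function might look something like this (an untested sketch; the per-individual gamma and p arrays are placeholders for whatever covariate-driven structure you build, with gamma[i] and p[i] holding individual i's transition matrices and detection vectors):

  real partial_sum_pairs(array[,] int x_slice, int start, int end,
                         data int T, data int K,
                         array[,] matrix gamma, array[,] vector p) {
    real lp = 0;
    for (i in start:end)
      lp += cap_hist_logp_i(T, K, x_slice[i - start + 1], gamma[i], p[i]);
    return lp;
  }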

Thank you @mbc! I was so close. I just needed to see how you did it to find where I was making mistakes. This helps very much.

Thank you Dalton! I wish I could simplify from an individual based model and aggregate to groups so that I could just use a multinomial distribution rather than categorical, but there are several individual-level covariates that must be addressed for this particular analysis. I admit that your solution goes beyond my meager coding skills, so I will need to take some time to soak in your reply. Thanks for the link to the paper. I just met your co-author Russ Perry a few months ago when he was in Sacramento for a conference.

Thanks @wesmith2! I realize my reply was a bit terse by way of explanation. I’m working on something very much related to this right now, so I had some of the math and code already worked out, but the explanation isn’t there by any means. Hoping to find some time to come back and flesh it out.

Anyway, if I can leave you with just a bit of advice: I think you’ll find that there are greater efficiencies in thinking through how to build up the likelihood for your particular problem than in a more general solution that just uses more horses to pull it. For example, depending on where in the model the individual covariates apply (e.g. survival-only), there may be ways to speed up sampling by eliminating redundant computations. Or if you have a relatively sparse transition matrix, it can be faster to hard-code the likelihood accounting for only the non-zero transitions than to step through a bunch of terms that you know will evaluate to zero (see the sketch below). In general, my approach in Stan with mark-recapture models has been to avoid introducing zeros as much as possible (and not just because I’m inherently freaked out by zero).
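
For instance, in the example model above (a sketch reusing its gamma/ps/po arrays, not code from your model): whenever the animal is seen at occasion t, the emission probability is zero for every state except the observed one, so the loop over k collapses to a single term:

  if (y[i, t] != 4) {  // seen, so the state is known and k = y[i, t]
    int k = y[i, t];
    real acc_k = 0;
    for (j in 1:4)
      acc_k += gamma[t - 1, j] * ps[j, i, t - 1, k];
    gamma[t] = rep_vector(0, 4);
    gamma[t, k] = acc_k * po[k, i, t - 1, k];
  }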