Improving speed of multistate mark-recapture hidden Markov model with group effects

I’m hoping to improve the sampling speed of a multistate, capture–mark–recapture model of survival and transitioning between breeding states in a population of individually-marked wild birds. I’ve based the model on this (excellent) post by @mhollanders, which formulates the basic multistate Cormack–Jolly–Seber model as a hidden Markov model. I’ve adapted the code presented there to suit my own problem, mainly by making the demographic and observation parameters (survival, transition probability, and detection) individually- and time-varying, in addition to adding group effects (“random intercepts”).

I have the model up and running, and everything seems to be working smoothly—no divergences, no warnings about ESS, E-BFMI, or treedepth, and sensible estimates for population effects when using relatively simple regressions for survival and transition probability. However, adding in just a single group effect (of year) slows things down substantially, and I’m suspicious it may be because I’ve done something foolish.

Alternatively, this is just the way it is, and if that’s so then I’m quite happy to accept that; however, I would be keen to learn any ways to make things more efficient (or to know where I’ve gone awry and done something ill-advised if that’s the case). I haven’t yet worked out a good intuition for writing efficient Stan code.

I’ll start with a few details about the data and general model structure; look for the full code at the end of the post.

Data

  • 1,318 individuals
  • discrete-time, with 37 intervals over 19 years (38 twice-yearly census occasions)
  • 8,916 observation intervals (potential recaptures of individual birds)
  • 2 observable states (non-breeding, breeding)
  • 1 unobservable state (dead/not detected)

Model

Here’s what (at least as I understand it) I’m trying to do:

My Markov chain tracks the marginal probability {\omega_{i,t}}_S that individual i is in state S at time t. I’ve indexed the states so that S_{\text{non-breeding}}=1, S_{\text{breeding}}=2, and S_{\text{dead/not detected}}=3. Each individual’s initial state \boldsymbol{\omega}_{i, {t_0}_i} is encoded by a one-hot vector (either \langle\ 1\ 0\ 0\ \rangle or \langle\ 0\ 1\ 0\ \rangle, since no individual enters the model dead or undetected). I use the forward algorithm to update these marginal probabilities at each time-step by the transition matrix \boldsymbol{\Delta}_{i,t} and the column of the emission matrix \textbf{P}_{i,t} that corresponds to the individual’s (apparent) state at time t, S_{i,t}:

\boldsymbol{\omega}_{i,t}=\left(\boldsymbol{\omega}_{i,t-1} \boldsymbol{\Delta}_{i,t}\right) \odot {{\textbf{P}_{i,t}}_{[\cdot,S_{i,t}]}}'

(where \odot is the elementwise product). The transition matrix \boldsymbol{\Delta} models the ecological processes of (state-, individual-, and time-varying) survival {\phi_S}_{i,t} and state persistence {\psi_S}_{i,t}. (Note that I’m modelling the probability that an individual in state S at time t-1 remains in state S (conditional on surviving the interval) rather than the probability of transitioning into another state; however, since there are only two observable states, persistence in this case is simply the complement of transition probability.)

\boldsymbol{\Delta}_{i,t} = \begin{bmatrix} {\phi_1}_{i,t}{\psi_1}_{i,t} & {\phi_1}_{i,t}(1-{\psi_1}_{i,t}) & (1-{\phi_1}_{i,t}) \\ {\phi_2}_{i,t}(1-{\psi_2}_{i,t}) & {\phi_2}_{i,t}{\psi_2}_{i,t} & (1-{\phi_2}_{i,t}) \\ 0 & 0 & 1 \end{bmatrix}

The emission matrix \textbf{P} models the observation process of (state-, individual-, and time-varying) detection, {p_S}_{i,t}:

\textbf{P}_{i,t} = \begin{bmatrix} {p_1}_{i,t} & 0 & (1-{p_1}_{i,t}) \\ 0 & {p_2}_{i,t} & (1-{p_2}_{i,t}) \\ 0 & 0 & 1 \end{bmatrix}

I’ve not included any allowances for classification error.

Finally, each of the demographic and detection parameters is modelled as the outcome of a regression, with population (“fixed”) effects (\boldsymbol{\alpha}, \boldsymbol{\beta}, and \boldsymbol{\gamma}) and (in the case of survival and state persistence) group effects (“random intercepts”) (\textbf{r}_\phi, \textbf{r}_\psi):

\begin{align} \boldsymbol{\phi}_S &= \text{logit}^{-1}\left(\textbf{W}_S \boldsymbol{\alpha}_S + {\textbf{r}_{\phi}}_S \right) \\ \boldsymbol{\psi}_S &= \text{logit}^{-1}\left(\textbf{X}_S \boldsymbol{\beta}_S + {\textbf{r}_{\psi}}_S \right) \\ \textbf{p}_S &= \text{logit}^{-1}\left(\textbf{Z}_S \boldsymbol{\gamma}_S \right) \end{align}

Technical details

  • I’ve already coded the model for within-chain parallelisation with reduce_sum() (reduce_sum() calls partial_sum_hmm_lpmf(), which re-indexes and reassembles the chunked data before passing it on to hmm_logp() to compute the [partial] log density). However, as this is my first time actually using reduce_sum(), it’s quite possible I’m not implementing it correctly or to its greatest effect.
  • At present, all population and group effects are assumed to be completely independent across breeding states (i.e., the effect of, say, sex on survival in non-breeders is completely independent of its effect in breeders). I have considered making this more hierarchical, as I suspect in most cases there is some correlation between the effects; however, I’m not entirely sure how to implement this, especially since some effects only appear in one context or the other (a rough sketch of what I have in mind follows this list).
  • I’m using a non-centred parameterisation for the group effects (using affine transforms).
  • All continuous covariates are mean-centred and have a unit scale.
  • I’m using sum-to-zero contrasts for all categorical population effect predictors.
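
For concreteness, here’s the kind of thing I have in mind for the correlated version (a minimal, untested sketch of the usual LKJ/Cholesky construction; all names here are my own placeholders):

parameters {
  cholesky_factor_corr[2] L_phi_y;   // correlation of year effects across states
  vector<lower=0>[2] sigma_phi_y;    // per-state scales
  matrix[2, Y] z_phi_y;              // standard-normal innovations
}
transformed parameters {
  // rows = breeding states, columns = years (non-centred parameterisation)
  matrix[2, Y] r_phi_y = diag_pre_multiply(sigma_phi_y, L_phi_y) * z_phi_y;
}
model {
  L_phi_y ~ lkj_corr_cholesky(2.0);
  sigma_phi_y ~ normal(0.0, 0.5);    // half-normal via the <lower=0> bound
  to_vector(z_phi_y) ~ std_normal();
}

(The matrix would then be indexed per observation as r_phi_y'[y[o_seq], :] in place of the append_col() I use below; I still don’t see how to extend this neatly to effects that appear in only one state.)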

Profiling and performance

The model runs relatively quickly when there are no group/random effects. With relatively simple population effect design matrices (survival and persistence varying by sex and age only; intercept-only model for detection) and no group effects, the model finishes sampling in 40 minutes on my machine (1e3 warmup iterations, 1e3 sampling iterations, 4 chains running in parallel, 2 threads per chain). For the purpose of troubleshooting, I’ve run some profiling tests with just 20 warmup and 20 sampling iterations as well. Without group effects, sampling finishes in just over a minute; with group effects for year on survival and persistence (as coded below), it takes close to 4 minutes.

If full execution of the model with group effects scales from 40 to 2,000 iterations similarly to the model without group effects (and I don’t know enough about the mechanics of HMC to know whether or not that’s to be expected…), then the full model should take ~4 hrs to run. That seems quite reasonable to me, and generally speaking I’d be satisfied with that. However, there are a few additional group effects that I think should be accounted for (e.g., social group, as these birds live in family groups with a single breeding pair; all members of individual groups tend to experience relatively similar conditions, some [but maybe not all?] of which are captured by the population-effect predictors). There are many more social groups than there are years, which leads me to expect that adding that component will seriously impact sampling time.

In light of that, I’d be interested to know whether there’s anything I could try to improve the sampling speed. Also, although I’m not strictly speaking a “Stan newbie”, I’m still (and ever) growing in my understanding of Stan and Bayesian approaches generally, so I’d welcome any input others might have on the model specification and implementation.

Code

functions {
  /* Add a constant to each element of an integer array
   * 
   *  n - length of the input array
   *  x - input array
   *  d - constant to be added to each element of `x`
   * 
   * Returns `y`, an integer array of length `n`.
   */
  array[] int add(int n, array[] int x, int d){
    array[n] int y;
    for(i in 1:n)
      y[i] = x[i] + d;
    return y;
  }
  
  /* Log-probability of state observations in hidden Markov, multistate 
   * Cormack-Jolly-Seber model.
   * 
   * Compute state-specific, individual- and time-varying parameters phi 
   * (survival), psi (state persistence), and p (detection), then use these to
   * construct transition (ecological process) and emission (detection/observation 
   * process) matrices Delta and Rho.
   * 
   * Called from partial_sum_hmm_lpmf as part of a within-chain parallelised
   * sampling regime.
   * 
   *    n_ind - number of individuals
   *    n_obs - number of observations
   *       S0 - initial state of each individual
   *        S - state observations (post-initial)
   *       W1 - design matrix for survival of non-breeders
   *       W2 - design matrix for survival of breeders
   *       X1 - design matrix for state persistence of non-breeders
   *       X2 - design matrix for state persistence of breeders
   *       Z1 - design matrix for detection of non-breeders
   *       Z2 - design matrix for detection of breeders
   *    first - index of first observation of each individual
   *     last - index of last detection observation of each individual
   *      end - index of last observation (detection OR not) of each individual
   *       a1 - population parameters for non-breeder survival
   *       a2 - population parameters for breeder survival
   *       b1 - population parameters for non-breeder state persistence
   *       b2 - population parameters for breeder state persistence
   *       c1 - population parameters for non-breeder detection
   *       c2 - population parameters for breeder detection
   *  r_phi_y - group intercepts for survival (year effect)
   *  r_psi_y - group intercepts for state persistence (year effect)
   * 
   * Returns 'ptarget', the (partial) log-density of the posterior.
   * 
   */
  real hmm_logp(
    data int n_ind, data int n_obs, data array[] int S0, data array[] int S, 
    data matrix W1, data matrix W2, 
    data matrix X1, data matrix X2, 
    data matrix Z1, data matrix Z2, 
    data array[] int first, data array[] int last, data array[] int end, 
    vector a1, vector a2, vector b1, vector b2, vector c1, vector c2, 
    matrix r_phi_y, matrix r_psi_y
  ){
    matrix[n_obs,2] phi, psi, p;
    array[n_obs] matrix[3,3] Delta, Rho;
    real ptarget = 0.0;
    
    // compute linear predictors for survival, state persistence, and detection
    phi = inv_logit(append_col(W1*a1, W2*a2) + r_phi_y);
    psi = inv_logit(append_col(X1*b1, X2*b2) + r_psi_y);
    p   = inv_logit(append_col(Z1*c1, Z2*c2));
    
    for(t in 1:n_obs){      // Construct ecological process (transition) matrix
      Delta[t][1,1] = phi[t,1] * psi[t,1];         // 
      Delta[t][2,1] = phi[t,2] * (1.0 - psi[t,2]); // 
      Delta[t][3,1] = 0.0;                         // 
      Delta[t][1,2] = phi[t,1] * (1.0 - psi[t,1]); //
      Delta[t][2,2] = phi[t,2] * psi[t,2]; // 
      Delta[t][3,2] = 0.0;                 //   phi1*psi1   phi1*(1-psi1)  (1-phi1)
      Delta[t][1,3] = 1.0 - phi[t,1];      // phi2*(1-psi2)   phi2*psi2    (1-phi2)
      Delta[t][2,3] = 1.0 - phi[t,2];      //       0             0            1
      Delta[t][3,3] = 1.0;                 // 
    }
    
    for(t in 1:n_obs){          // Construct observation process (emission) matrix
      Rho[t][1,1] = p[t,1];       // 
      Rho[t][2,1] = 0.0;          // 
      Rho[t][3,1] = 0.0;          // 
      Rho[t][1,2] = 0.0;          //    p1   0  1-p1
      Rho[t][2,2] = p[t,2];       //     0  p2  1-p2
      Rho[t][3,2] = 0.0;          //     0   0    1
      Rho[t][1,3] = 1.0 - p[t,1]; // 
      Rho[t][2,3] = 1.0 - p[t,2]; // 
      Rho[t][3,3] = 1.0;          // 
    }
    
    // likelihood
    for(i in 1:n_ind){
      // initialise the marginal probability vector as the initial state vector
      // (reflects the individual's status at first capture; S0 is either 1 or 2)
      row_vector[3] omega = one_hot_row_vector(3, S0[i]);
      
      // condition the live states (1 and 2) on observations from the first to 
      // the last detection event
      for(t in first[i]:last[i])
        omega = (omega * Delta[t]) .* Rho[t][:,S[t]]';
      omega[3] = 0.0; // fix the p(dead) to 0 at last detection
      
      // condition all states on non-detection after last detection
      for(t in (last[i]+1):end[i])
        omega = (omega * Delta[t]) .* Rho[t][:,S[t]]';
      
      ptarget += log(sum(omega)); // increment [partial] log density
    }
    
    return ptarget; // return the [partial] log density
    
  }
  
  /* Chunk data for multi-threading. Called from reduce_sum as a wrapper for
   * hmm_logp. Reassembles and re-indexes the chunked data before passing it
   * on to hmm_logp.
   * 
   *    i_seq - indices of individuals in current chunk
   *  i_start - index of first individual in current chunk
   *    i_end - index of last individual in current chunk
   *       S0 - initial state of each individual
   *        S - state observations (post-initial)
   *       W1 - design matrix for survival of non-breeders
   *       W2 - design matrix for survival of breeders
   *       X1 - design matrix for state persistence of non-breeders
   *       X2 - design matrix for state persistence of breeders
   *       Z1 - design matrix for detection of non-breeders
   *       Z2 - design matrix for detection of breeders
   *    first - index of first observation of each individual
   *     last - index of last detection observation of each individual
   *      end - index of last observation (detection OR not) of each individual
   *        y - year identifier/index for each observation
   *  alpha_1 - population parameters for non-breeder survival
   *  alpha_2 - population parameters for breeder survival
   *   beta_1 - population parameters for non-breeder state persistence
   *   beta_2 - population parameters for breeder state persistence
   *  gamma_1 - population parameters for non-breeder detection
   *  gamma_2 - population parameters for breeder detection
   *  r_phi_y - group intercepts for survival (year effect)
   *  r_psi_y - group intercepts for state persistence (year effect)
   * 
   * Returns `ptarget`, the output of hmm_logp using the chunked and re-
   * indexed data.
   */
  real partial_sum_hmm_lpmf(
    data array[] int i_seq, data int i_start, data int i_end, 
    data array[] int S0, data array[] int S, 
    data matrix W1, data matrix W2, 
    data matrix X1, data matrix X2, 
    data matrix Z1, data matrix Z2, 
    data array[] int first, data array[] int last, data array[] int end, 
    data array[] int y, 
    vector alpha_1, vector alpha_2, 
    vector beta_1,  vector beta_2, 
    vector gamma_1, vector gamma_2, 
    matrix r_phi_y, matrix r_psi_y
  ){
    int n_ind   = i_end - i_start + 1;
    int o_start = first[i_start];      // use individual indices to find the
    int o_end   = end[i_end];          // corresponding observation indices
    int n_obs   = o_end - o_start + 1; // and the number of observations
    array[n_obs] int o_seq = linspaced_int_array(n_obs, o_start, o_end);
    array[n_ind] int t_first = add(n_ind, first[i_seq], 1-o_start);
    array[n_ind] int t_last  = add(n_ind, last[i_seq],  1-o_start);
    array[n_ind] int t_end   = add(n_ind, end[i_seq],   1-o_start);
    
    // subset individual-based data with i_seq and observation-based data 
    // with o_seq; pass on data and parameters to hmm_logp
    return hmm_logp(n_ind, n_obs, S0[i_seq], S[o_seq], 
                    W1[o_seq,:], W2[o_seq,:], 
                    X1[o_seq,:], X2[o_seq,:], 
                    Z1[o_seq,:], Z2[o_seq,:], 
                    t_first, t_last, t_end, 
                    alpha_1, alpha_2, beta_1, beta_2, gamma_1, gamma_2, 
                    r_phi_y[y[o_seq],:], r_psi_y[y[o_seq],:]);
  }
}

data {
                     // Dimensions
  int<lower=1> N;    //   observations
  int<lower=1> I;    //   individuals
  int<lower=1> Y;    //   [austral] years
                     //   Population effect parameter sizes
  int<lower=1> K1;   //     survival [non-breeder]
  int<lower=1> K2;   //              [breeder]
  int<lower=1> L1;   //     persistence [non-breeder]
  int<lower=1> L2;   //                 [breeder]
  int<lower=1> M1;   //     detection [non-breeder]
  int<lower=1> M2;   //               [breeder]
                                            // Identifiers
  array[N] int<lower=1, upper=Y> y;         //   year ID
                                            // Indices
  array[I] int<lower=1, upper=N> first;     //   indices of first observations
  array[I] int<lower=0, upper=N> last;      //   indices of last observations
  array[I] int<lower=last, upper=N> end;    //   indices of final records
                                            // Data
  array[I] int<lower=1, upper=2> S0;        //   initial status
  array[N] int<lower=1, upper=3> S;         //   status [1=non-breeder, 
                                            //           2=breeder, 
                                            //           3=not seen]
                                            //   Covariate matrices
  matrix[N,K1] W1;                          //   survival [non-breeder]
  matrix[N,K2] W2;                          //            [breeder]
  matrix[N,L1] X1;                          //   persistence [non-breeder]
  matrix[N,L2] X2;                          //               [breeder]
  matrix[N,M1] Z1;                          //   detection [non-breeder]
  matrix[N,M2] Z2;                          //             [breeder]
}

parameters {
                                // Population effects
  vector[K1] alpha_1;           //   survival [non-breeder]
  vector[K2] alpha_2;           //            [breeder]
  vector[L1] beta_1;            //   persistence [non-breeder]
  vector[L2] beta_2;            //               [breeder]
  vector[M1] gamma_1;           //   detection [non-breeder]
  vector[M2] gamma_2;           //             [breeder]
                                // Group effects
  real<lower=0> sigma_phi_y_1;  //   survival-by-year [non-breeder]
  vector<offset=0.0, multiplier=sigma_phi_y_1>[Y] r_phi_y_1;
  real<lower=0> sigma_phi_y_2;  //                    [breeder]
  vector<offset=0.0, multiplier=sigma_phi_y_2>[Y] r_phi_y_2;
  real<lower=0> sigma_psi_y_1;  //   persistence-by-year [non-breeder]
  vector<offset=0.0, multiplier=sigma_psi_y_1>[Y] r_psi_y_1;
  real<lower=0> sigma_psi_y_2;  //                       [breeder]
  vector<offset=0.0, multiplier=sigma_psi_y_2>[Y] r_psi_y_2;
}

transformed parameters {
  real lprior = 0.0;
  lprior += student_t_lpdf(alpha_1[1] | 3, 0.0, 1); // Student's t(3, 0, 1)
  lprior += student_t_lpdf(alpha_2[1] | 3, 0.0, 1); //   for all population effect
  lprior += student_t_lpdf(beta_1[1]  | 3, 0.0, 1); //   intercept parameters
  lprior += student_t_lpdf(beta_2[1]  | 3, 0.0, 1);
  lprior += student_t_lpdf(gamma_1[1] | 3, 0.0, 1);
  lprior += student_t_lpdf(gamma_2[1] | 3, 0.0, 1);
  if(K1 > 1)                                        // Student's t(5, 0, 0.5)
    for(k in 2:K1)                                  //   for all population effect
      lprior += student_t_lpdf(alpha_1[k] | 5, 0.0, 0.5); // slope parameters
  if(K2 > 1)
    for(k in 2:K2)
      lprior += student_t_lpdf(alpha_2[k] | 5, 0.0, 0.5);
  if(L1 > 1)
    for(l in 2:L1)
      lprior += student_t_lpdf(beta_1[l]  | 5, 0.0, 0.5);
  if(L2 > 1)
    for(l in 2:L2)
      lprior += student_t_lpdf(beta_2[l]  | 5, 0.0, 0.5);
  if(M1 > 1)
    for(m in 2:M1)
      lprior += student_t_lpdf(gamma_1[m] | 5, 0.0, 0.5);
  if(M2 > 1)
    for(m in 2:M2)
      lprior += student_t_lpdf(gamma_2[m] | 5, 0.0, 0.5);
  lprior += normal_lpdf(sigma_phi_y_1 | 0.0, 0.5) - 1.0 * // Half-Normal(0, 0.5)
    normal_lccdf(0.0 | 0.0, 0.5);                         //   for all group effect
  lprior += normal_lpdf(r_phi_y_1 | 0.0, sigma_phi_y_1);  //   scale parameters
  lprior += normal_lpdf(sigma_phi_y_2 | 0.0, 0.5) - 1.0 * // Normal(0, <sd_r>)
    normal_lccdf(0.0 | 0.0, 0.5);                         //  for all group effects
  lprior += normal_lpdf(r_phi_y_2 | 0.0, sigma_phi_y_2);
  lprior += normal_lpdf(sigma_psi_y_1 | 0.0, 0.5) - 1.0 * 
    normal_lccdf(0.0 | 0.0, 0.5);
  lprior += normal_lpdf(r_psi_y_1 | 0.0, sigma_psi_y_1);
  lprior += normal_lpdf(sigma_psi_y_2 | 0.0, 0.5) - 1.0 * 
    normal_lccdf(0.0 | 0.0, 0.5);
  lprior += normal_lpdf(r_psi_y_2 | 0.0, sigma_psi_y_2);
}

model {
  array[I] int i_inds = linspaced_int_array(I, 1, I); // construct individual index
  target += reduce_sum(partial_sum_hmm_lpmf, i_inds, 1, S0, S, 
                       W1, W2, X1, X2, Z1, Z2, 
                       first, last, end, y, 
                       alpha_1, alpha_2, beta_1, beta_2, gamma_1, gamma_2, 
                       append_col(r_phi_y_1, r_phi_y_2), 
                       append_col(r_psi_y_1, r_psi_y_2));
  target += lprior;
}

Technical specifications

Processor: Intel Core i5-8365U CPU @ 1.60 GHz / 1.90 GHz
RAM: 16.0 GB
R: v4.4.1
cmdstan: v2.35.0
cmdstanR: v0.8.1

Do you have sample data or simulated data code you’d be willing to provide? I believe that this can be sped up with some tweaks. I have some code waiting for approval to be released by my agency that fits a particular form of a multistate model (assuming transition matrices are upper triangular), but which could be readily adapted to this situation. However, I will also take a stab at rewriting what you have here (omitting the reduce_sum).

Some thoughts on the priors:

  • Move the priors to the model block and drop the lprior += business (a sketch applying these suggestions follows this list).
  • The if statements are not necessary, since if (e.g.) K1 < 2, the loop for (k in 2:K1) simply will not run.
  • The constraints are already handled in the parameters block since those parameters were declared with constraints (with Jacobians added as necessary). Therefore you don’t need the normal_lccdf bit
  • Try more zero-avoiding priors for the standard deviations of the random effects (like logistic(1, 4)).
  • One thing that might be slowing the model down is that it gets trapped in a funnel when the sigma_ variables approach zero. As a test, try setting their lower bound to some value greater than zero (like 1e-10).
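
For example, a sketch of the priors written directly in the model block with your parameter names (the lognormal here is just one zero-avoiding possibility, not a specific recommendation):

model {
  // intercept and slope priors; an empty slice (e.g., when K1 == 1)
  // contributes nothing, so no if() guard is needed
  alpha_1[1] ~ student_t(3, 0, 1);
  alpha_1[2:] ~ student_t(5, 0, 0.5);
  
  // the <lower=0> declaration already handles the truncation, so no
  // normal_lccdf correction is needed; a zero-avoiding prior keeps sigma
  // away from the neck of the funnel
  sigma_phi_y_1 ~ lognormal(-1, 0.5);
  r_phi_y_1 ~ normal(0, sigma_phi_y_1);
}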

For calculating the likelihood, you really don’t need the dead state; there are ways to avoid including it.

For the CJS model, the likelihood component for each individual is constructed by taking the product of survival terms from the individual’s first capture to its last, terms accounting for the probability of being captured (or not) at each occasion after the first, and an additional term for the probability of remaining uncaptured after the last observation. One way to define this likelihood is:

g(y_{i,k}|\mathbf{\phi_i},\mathbf{p_i}) = \left(\phi_{i,x_i} \prod_{k = x_i + 1}^{z_i - 1}\left[\phi_{i,k}(1 - p_{i,k}) \left( \frac{p_{i,k}}{1-p_{i,k}}\right)^{y_{i,k}}\right] p_{i,z_i} \right)\chi_{i,z_i}

where y_{i,k} is individual i’s capture history, a Boolean indicator of whether the individual was captured at occasion k; x_i is the index of the individual’s first capture and z_i the index of its last capture, with x_i, z_i \in \{1, \ldots, K\}. For capture histories where x_i = z_i, the entire term in large parentheses is omitted, leaving only

g(y_{i,k}|\mathbf{\phi_i},\mathbf{p_i}) = \chi_{i,z_i},

while for x_i = z_i - 1 only the product term in the middle of that expression is omitted, leaving

g(y_{i,k}|\mathbf{\phi_i},\mathbf{p_i}) = \phi_{i,x_i}p_{i,z_i}\chi_{i,z_i}.

The term \chi_{i,k} is the probability of never being observed again after occasion k, either because the individual died at some point after its last capture or because it remained alive but uncaptured. In most of the literature this term is defined and expressed recursively, both in theory and in computation. Setting \chi_{i,K} = 1, the preceding terms back through \chi_{i,1} are usually written as:

\chi_{i,k} = (1 - \phi_{i,k}) + \phi_{i,k}(1-p_{i,k + 1})\chi_{i,k+1}

The intuition behind the recursion is that at each capture occasion we have to account for two possibilities: (1) the individual did not survive to the next capture occasion, with probability 1 - \phi_{k} (the i subscript dropped temporarily for clarity), or (2) the individual did survive to the next capture occasion with probability \phi_{k}, but was not captured then (with probability 1 - p_{k + 1}) and was also never captured again after being alive at occasion k + 1 (with probability \chi_{k + 1}).

However, this recursive definition of the \chi terms can be avoided, which may be particularly useful in individual-based models as both I and K grow. The term can be re-expressed by recognizing that the probability of never being captured again after the final capture is simply one minus the probability of being captured again after the final capture. We don’t need to step through each either/or possibility at every occasion; we only need to account for the set of circumstances that would result in at least one more capture. Either the individual is seen again at occasion k + 1, with probability \phi_k p_{k + 1}; or at occasion k + 2, with probability \phi_k(1 - p_{k + 1})\phi_{k + 1}p_{k + 2}; and so forth, up to the probability that it is seen again at the final capture occasion:

\phi_k\left(\prod_{j=k+1}^{K-1} \left[(1 - p_{j})\phi_{j}\right]\right)p_{K}

These two expressions are mathematically equivalent. For example, with three capture occasions, the probability of never being seen again after a final capture on occasion two is:

\chi_{2} = (1 - \phi_{2}) + \phi_{2}(1-p_{3}) = 1 - \phi_2 + \phi_2 - \phi_2 p_3 = 1 - \phi_2 p_3

while the probability of never being seen again after a final capture on occasion one is:

\begin{align} \chi_{1} &= (1 - \phi_{1}) + \phi_{1}(1-p_{2})\chi_2 \\ &= (1 - \phi_{1}) + \phi_{1}(1-p_{2})\left[1 - \phi_2 p_3\right] \\ &= (1 - \phi_1) + \phi_1(1 - p_2) - \phi_1(1 - p_2)\phi_2 p_3 \\ &= 1 - \left[\phi_1 p_2 + \phi_1(1 - p_2)\phi_2 p_3\right] \end{align}

More generally we can write this as:

\chi_{i,k} = 1 - \sum_{j = k}^{K-1}\left[\left( \prod_{l = k}^j \phi_{i,l}(1-p_{i,l + 1})\right)\frac{p_{i,j + 1}}{1-p_{i,j + 1}} \right]

The nice thing about the non-recursive definition is that its calculation can be vectorized. Here are some functions I wrote for calculating the likelihood of the CJS model on an individual basis on the log scale.

functions {
  // count the number of occasions on which the individual was detected
  int n_detections(array[] int y_i){
    int n_det = 0;
    
    for (k in 1:size(y_i))
      if (y_i[k] > 0)
        n_det += 1;
    
    return(n_det);
  }

  // occasion indices of the n detections in capture history y_i
  array[] int detection_indices(int n, array[] int y_i){
    array[n] int det_idxs;
    int occ_counter = 1;
    int det_counter = 1;

    while(det_counter < n + 1){
      if (y_i[occ_counter] > 0){
        det_idxs[det_counter] = occ_counter;
        det_counter += 1;
      }
      occ_counter += 1;
    }
    
    return(det_idxs);
  }  

  // log-probability of one capture history: y_i is 0/1 over K occasions,
  // log_phi = log survival (length K - 1), logit_p = logit detection and
  // log1m_p = log(1 - p) (both length K)
  real caphist_logprob(array[] int y_i, vector log_phi, 
                       vector logit_p, vector log1m_p){
    if(size(log_phi) != size(y_i) - 1)
      reject("log_phi must be size K - 1");
    if(size(log_phi) != size(logit_p) - 1)
      reject("log_phi must be size K - 1 and logit_p must be size K");
    
    int K = size(y_i);
    int n = n_detections(y_i);
    array[n] int d = detection_indices(n, y_i);
    int f = d[1];
    int l = d[n];
    real lp = 0;
    
    if (n > 1)
      lp += sum(log_phi[f:(l - 1)]) + 
            sum(log1m_p[(f + 1):l]) + 
            sum(logit_p[d[2:]]);

    if (l < K){
      // non-recursive chi: each element is one log-scale summand of 1 - chi,
      // combined below with log_sum_exp and log1m_exp
      vector[K - l] log1m_chi;
      
      log1m_chi = cumulative_sum(log_phi[l:] + log1m_p[(l + 1):]) +
                  logit_p[(l + 1):];
      
      lp  += log1m_exp(log_sum_exp(log1m_chi));
    }
    
    return(lp);
  } 
  
}
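
As a hypothetical usage sketch (assuming declarations along the lines of array[I, K] int y for the 0/1 capture histories, array[I] vector[K - 1] phi, and array[I] vector[K] p; none of these names are fixed):

model {
  for (i in 1:I)
    target += caphist_logprob(y[i], log(phi[i]), logit(p[i]), log1m(p[i]));
}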

Thanks for the shoutout! I don’t immediately see something in your model code (but that doesn’t mean much, I’m pretty slow). I’d be curious to see some traceplots for the random effect hyperparameters. Honestly, I’m not entirely sure what to expect in terms of speed differences when you add such random effects, but I’m a bit surprised it’s taking that much longer. Have you tried the actual profiling tools included in cmdstanr? Are you positive all the indexing is going well in the partial sum function? I’m struggling to follow it. I do have a few general comments:

  • I think where you condition on live states only, you just want to index 1:2, i.e. omega[1:2] = (omega[1:2] * Delta[t][1:2, 1:2]) .* Rho[t][1:2,S[t]]';. You also don’t need to set omega[3] = 0, as it remains 0 from the initialisation (see the sketch after this list).
  • Then, after the last capture, you don’t need the S[t] index into the observation matrix anymore, as you know it’s a 3 (not detected).
  • You can declare i_inds in the transformed data block.
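
Putting those points together, the likelihood loop might look something like this (an untested sketch with your variable names):

for (i in 1:n_ind) {
  row_vector[3] omega = one_hot_row_vector(3, S0[i]);
  
  // update the live states only up to the last detection; omega[3] stays 0
  for (t in first[i]:last[i])
    omega[1:2] = (omega[1:2] * Delta[t][1:2, 1:2]) .* Rho[t][1:2, S[t]]';
  
  // after the last detection, the observation is always "not seen" (column 3)
  for (t in (last[i] + 1):end[i])
    omega = (omega * Delta[t]) .* Rho[t][:, 3]';
  
  ptarget += log(sum(omega));
}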

I agree with Dalton that the lprior stuff isn’t necessary but I sometimes add it when I use the priorsense package.


A lot of your math was redacted because of the use of an unknown “\boldsymbol” macro.

You have a non-homogeneous HMM (time-varying transition and emission matrices), and I’m afraid that can’t be converted to our built-in HMM functions.

You can, by the way, greatly simplify some of the code, e.g.,

Delta[t] = [[a, b, c], 
            [d, e, f],
            [g, h, i]];
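
For example, the transition matrix could be filled in one statement with its actual entries (a sketch using the phi/psi from the code above):

for (t in 1:n_obs)
  Delta[t] = [[phi[t,1] * psi[t,1], phi[t,1] * (1 - psi[t,1]), 1 - phi[t,1]],
              [phi[t,2] * (1 - psi[t,2]), phi[t,2] * psi[t,2], 1 - phi[t,2]],
              [0, 0, 1]];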

A lot of the other code can be made more efficient, but this is a daunting amount of code to try to review. For example, whenever you see transpositions, consider reorienting the parameter. For instance, rather than defining Rho as you have it and then using Rho[t][:, S[t]]', instead transpose the last two dimensions of Rho so you can write Rho[t, S[t]]. I’m not sure about the state of our indexing optimizations, but that should be at least as efficient, if not more so, than writing Rho[t, :, S[t]]'.
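
Something like this (untested sketch):

array[n_obs] matrix[3,3] RhoT;   // store the transpose: RhoT[t] = Rho[t]'
for (t in 1:n_obs)
  RhoT[t] = [[p[t,1], 0, 0],
             [0, p[t,2], 0],
             [1 - p[t,1], 1 - p[t,2], 1]];
// in the likelihood, the needed row is then contiguous:
// omega = (omega * Delta[t]) .* RhoT[t][S[t]];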

Rather than adding to arrays and then indexing into them, you might be able to cut out the middle step and just calculate the indexes directly.
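
For example (a sketch, assuming the per-individual loop were pulled up into the partial-sum function so the shifted indices could be used directly):

for (j in 1:n_ind) {
  int t_first_j = first[i_seq[j]] - o_start + 1;
  int t_last_j  = last[i_seq[j]]  - o_start + 1;
  int t_end_j   = end[i_seq[j]]   - o_start + 1;
  // ... per-individual likelihood using t_first_j, t_last_j, t_end_j
}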

I would try tightening up some of the Student-t priors, though by the time you get to 5 degrees of freedom it’s already pretty normal-like.

You want to check if centered or non-centered parameterizations are better—when you have a lot of data and/or strong priors, the centered parameterizations can be more efficient.
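
For example, the centred version of one of the year effects would simply drop the affine transform (sketch):

parameters {
  real<lower=0> sigma_phi_y_1;
  vector[Y] r_phi_y_1;                    // no offset/multiplier
}
model {
  r_phi_y_1 ~ normal(0, sigma_phi_y_1);   // same prior, now centred
}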

I have no idea how you’d check correctness of something this complicated. You’re also running it on a very slow computer—a 1.6 GHz i5?


Thanks Dalton, Matthijs, and Bob! I appreciate you each taking the time to have a look and a think—I know it is a lot of code to wade through! At this stage, I still benefit a lot from relatively simple/general tips (like moving the priors into the model block, thinking about transpositions, and zero-avoiding priors)—so those are very useful!

You’ve given me some great starting points for tuning the code, and I’m hoping to get some time in the next week or so to start tweaking things a bit (I’m in the field presently, so I [like my code and computer] will continue to operate a bit slowly :). I’ll be sure to follow up with any more specific issues I encounter.

Thanks again!