I’m hoping to improve the sampling speed of a multistate capture–mark–recapture model of survival and transitions between breeding states in a population of individually marked wild birds. I’ve based the model on this (excellent) post by @mhollanders, which formulates the basic multistate Cormack–Jolly–Seber model as a hidden Markov model. I’ve adapted the code presented there to suit my own problem, mainly by making the demographic and observation parameters (survival, transition probability, and detection) vary by individual and over time, and by adding group effects (“random intercepts”).
I have the model up and running, and everything seems to be working smoothly: no divergences, no warnings about ESS, E-BFMI, or treedepth, and sensible estimates for population effects when using relatively simple regressions for survival and transition probability. However, adding just a single group effect (year) slows things down substantially, and I suspect it may be because I’ve done something foolish.
Alternatively, this may just be the way it is, and if so I’m quite happy to accept that; however, I would be keen to learn of any ways to make things more efficient (or to find out where I’ve gone awry and done something ill-advised, if that’s the case). I haven’t yet developed a good intuition for writing efficient Stan code.
I’ll start with a few details about the data and general model structure; look for the full code at the end of the post.
Data
- 1,318 individuals
- discrete time, with 37 intervals over 19 years (38 twice-yearly census occasions)
- 8,916 observation intervals (potential recaptures of individual birds)
- 2 observable states (non-breeding, breeding)
- 1 unobservable state (dead/non-detection)
Model
Here’s what (at least as I understand it) I’m trying to do:
My Markov chain tracks the marginal probability ${\omega_{i,t}}_S$ that individual $i$ is in state $S$ at time $t$. I’ve indexed the states so that $S_{\text{non-breeding}}=1$, $S_{\text{breeding}}=2$, and $S_{\text{dead/not detected}}=3$. Each individual’s initial state $\boldsymbol{\omega}_{i, {t_0}_i}$ is encoded by a one-hot vector (either $\langle\ 1\ 0\ 0\ \rangle$ or $\langle\ 0\ 1\ 0\ \rangle$, since no individual enters the model dead or undetected). I use the forward algorithm to update these marginal probabilities at each time step using the transition matrix $\boldsymbol{\Delta}_{i,t}$ and the column of the emission matrix $\textbf{P}_{i,t}$ that corresponds to the individual’s (apparent) state at time $t$, $S_{i,t}$:

$$\boldsymbol{\omega}_{i,t} = \left(\boldsymbol{\omega}_{i,t-1}\,\boldsymbol{\Delta}_{i,t}\right) \odot \left[\textbf{P}_{i,t}\right]_{:,\,S_{i,t}}^{\top}$$
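(where $\odot$ is the elementwise product). After the individual’s final occasion, its contribution to the log-likelihood is $\log\sum_S {\omega_{i,t}}_S$. The transition matrix $\boldsymbol{\Delta}$ models the ecological processes of (state-, individual-, and time-varying) survival, ${\phi_S}_{i,t}$, and state persistence, ${\psi_S}_{i,t}$:

$$\boldsymbol{\Delta}_{i,t} = \begin{bmatrix} {\phi_1}_{i,t}\,{\psi_1}_{i,t} & {\phi_1}_{i,t}\,(1-{\psi_1}_{i,t}) & 1-{\phi_1}_{i,t} \\ {\phi_2}_{i,t}\,(1-{\psi_2}_{i,t}) & {\phi_2}_{i,t}\,{\psi_2}_{i,t} & 1-{\phi_2}_{i,t} \\ 0 & 0 & 1 \end{bmatrix}$$

(Note that I’m modelling the probability that an individual in state $S$ at time $t-1$ remains in state $S$, conditional on surviving the interval, rather than the probability of transitioning into another state; however, since there are only two observable states, persistence in this case is simply the complement of the transition probability.)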
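For reference, the transition matrix and a single forward-algorithm update can be written as small stand-alone Stan functions. This is a minimal sketch, assuming the per-occasion matrices for one individual have already been built; `transition_matrix`, `forward_step`, and `history_loglik` are hypothetical names and are not used by the model at the end of the post.

```stan
functions {
  /* Transition (ecological process) matrix for one individual and interval.
   * phi - survival probability by state (1 = non-breeder, 2 = breeder)
   * psi - state-persistence probability by state
   * Rows index the state at t-1, columns the state at t; each row sums to 1.
   */
  matrix transition_matrix(vector phi, vector psi) {
    matrix[3, 3] Delta;
    Delta[1] = [phi[1] * psi[1], phi[1] * (1 - psi[1]), 1 - phi[1]];
    Delta[2] = [phi[2] * (1 - psi[2]), phi[2] * psi[2], 1 - phi[2]];
    Delta[3] = [0.0, 0.0, 1.0];  // dead is absorbing
    return Delta;
  }

  /* One forward-algorithm update: propagate omega through Delta, then
   * weight each state by the probability of emitting observation s.
   */
  row_vector forward_step(row_vector omega, matrix Delta, matrix P, int s) {
    return (omega * Delta) .* P[:, s]';
  }

  /* Log-likelihood of one capture history (hypothetical helper).
   * s0 - state at first capture (1 or 2)
   * s  - observations after first capture (1, 2, or 3 = not seen)
   */
  real history_loglik(int s0, array[] int s,
                      array[] matrix Delta, array[] matrix P) {
    row_vector[3] omega = one_hot_row_vector(3, s0);
    for (t in 1:size(s)) {
      omega = forward_step(omega, Delta[t], P[t], s[t]);
    }
    return log(sum(omega));
  }
}
```

As far as I can tell, a single loop over all occasions like this gives the same value as the two loops in `hmm_logp`: because a dead bird is never seen ($\textbf{P}_{3,1} = \textbf{P}_{3,2} = 0$), every detection already sets $\omega_3$ to zero, so the explicit `omega[3] = 0.0` step there has no effect on the result.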
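The emission matrix $\textbf{P}$ models the observation process of (state-, individual-, and time-varying) detection, ${p_S}_{i,t}$ (rows: true state; columns: observed state, with 3 = not seen):

$$\textbf{P}_{i,t} = \begin{bmatrix} {p_1}_{i,t} & 0 & 1-{p_1}_{i,t} \\ 0 & {p_2}_{i,t} & 1-{p_2}_{i,t} \\ 0 & 0 & 1 \end{bmatrix}$$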
I’ve not included any allowances for classification error.
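The emission matrix can be sketched the same way. This assumes, as stated, no classification error: a live bird is either recorded in its true state or not seen, and a dead bird is never seen. `emission_matrix` is a hypothetical name.

```stan
functions {
  /* Emission (observation process) matrix for one individual and occasion.
   * p - detection probability by state (1 = non-breeder, 2 = breeder)
   * Rows index the true state; columns index the observation
   * (1 = seen as non-breeder, 2 = seen as breeder, 3 = not seen).
   */
  matrix emission_matrix(vector p) {
    matrix[3, 3] P = rep_matrix(0, 3, 3);  // zeros: no misclassification
    P[1, 1] = p[1];
    P[1, 3] = 1 - p[1];
    P[2, 2] = p[2];
    P[2, 3] = 1 - p[2];
    P[3, 3] = 1;  // dead birds are never seen
    return P;
  }
}
```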
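Finally, each of the demographic and detection parameters is modelled as the outcome of a logit-linear regression, with population (“fixed”) effects ($\boldsymbol{\alpha}$, $\boldsymbol{\beta}$, and $\boldsymbol{\gamma}$) and, in the case of survival and state persistence, group effects (“random intercepts”) ($\textbf{r}_\phi$, $\textbf{r}_\psi$):

$$\begin{aligned}
\operatorname{logit}({\phi_S}_{i,t}) &= \textbf{w}_{S,i,t}\,\boldsymbol{\alpha}_S + {r_{\phi_S}}_{y(t)}, & {r_{\phi_S}}_{y} &\sim \text{Normal}(0, \sigma_{\phi_S}) \\
\operatorname{logit}({\psi_S}_{i,t}) &= \textbf{x}_{S,i,t}\,\boldsymbol{\beta}_S + {r_{\psi_S}}_{y(t)}, & {r_{\psi_S}}_{y} &\sim \text{Normal}(0, \sigma_{\psi_S}) \\
\operatorname{logit}({p_S}_{i,t}) &= \textbf{z}_{S,i,t}\,\boldsymbol{\gamma}_S
\end{aligned}$$

where $\textbf{w}_{S,i,t}$, $\textbf{x}_{S,i,t}$, and $\textbf{z}_{S,i,t}$ are the rows of the design matrices W1/W2, X1/X2, and Z1/Z2 for that observation, and $y(t)$ is the year of occasion $t$.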
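The linear predictors for survival and persistence share the same form, so one helper could compute either. A minimal sketch mirroring the `phi` and `psi` lines in `hmm_logp` at the end of the post; the function name `state_probs` and its argument names are hypothetical.

```stan
functions {
  /* State-specific probabilities from a logit-linear predictor plus
   * year-level random intercepts.
   * X1, X2 - design matrices (one row per observation) for states 1 and 2
   * b1, b2 - population-effect coefficients for states 1 and 2
   * r      - Y x 2 matrix of year intercepts (column = state)
   * yr     - year index of each observation
   * Returns an N x 2 matrix of probabilities (column = state).
   */
  matrix state_probs(matrix X1, matrix X2, vector b1, vector b2,
                     matrix r, array[] int yr) {
    return inv_logit(append_col(X1 * b1, X2 * b2) + r[yr, :]);
  }
}
```

For example, `state_probs(W1, W2, alpha_1, alpha_2, append_col(r_phi_y_1, r_phi_y_2), y)` would give the survival probabilities; detection, which has no group effect, corresponds to passing `rep_matrix(0, Y, 2)` as `r`.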
Technical details
- I’ve already coded the model for within-chain parallelisation with `reduce_sum()` (`reduce_sum()` calls `partial_sum_hmm_lpmf()`, which re-indexes and reassembles the chunked data before passing it on to `hmm_logp()` to compute the log-likelihood). However, as this is my first time actually using `reduce_sum()`, it’s quite possible I’m not implementing it correctly or to its greatest effect.
- At present, all population and group effects are assumed to be completely independent across breeding states (i.e., the effect of, e.g., sex on survival in non-breeders is completely independent of its effect in breeders). I have considered making this more hierarchical, as I suspect in most cases there is some correlation between the effects; however, I’m not entirely sure how to implement this, especially since some effects only appear in one context or the other.
- I’m using a non-centred parameterisation for the group effects (via Stan’s `offset`/`multiplier` affine transforms).
- All continuous covariates are mean-centred and have a unit scale.
- I’m using sum-to-zero contrasts for all categorical population effect predictors.
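The `offset`/`multiplier` declaration used for the year effects should be equivalent, for the sampler, to the explicit non-centred form in which standardised effects are scaled by the group SD. A minimal prior-only sketch with a single year effect, assuming the Half-Normal(0, 0.5) scale prior from the code; `z`, `r`, and `sigma` are hypothetical names:

```stan
data {
  int<lower=1> Y;  // number of years
}
parameters {
  real<lower=0> sigma;  // between-year SD
  vector[Y] z;          // standardised year effects
}
transformed parameters {
  // explicit non-centring: r ~ normal(0, sigma) is implied by z ~ N(0, 1)
  vector[Y] r = sigma * z;
}
model {
  sigma ~ normal(0, 0.5);  // half-normal, via the lower bound on sigma
  z ~ std_normal();
}
```

The affine-transform version (`vector<offset=0, multiplier=sigma>[Y] r;` with `r ~ normal(0, sigma);`) gives the same unconstrained geometry while keeping `r` on its natural scale in the output.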
Profiling and performance
The model runs relatively quickly when there are no group/random effects. With relatively simple population effect design matrices (survival and persistence varying by sex and age only; intercept-only model for detection) and no group effects, the model finishes sampling in 40 minutes on my machine (1e3 warmup iterations, 1e3 sampling iterations, 4 chains running in parallel, 2 threads per chain). For the purpose of troubleshooting, I’ve run some profiling tests with just 20 warmup and 20 sampling iterations as well. Without group effects, sampling finishes in just over a minute; with group effects for year on survival and persistence (as coded below), it takes close to 4 minutes.
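Stan’s `profile` statements (available since Stan 2.26, so covered by the CmdStan version listed at the end) can time the likelihood separately from the rest of the model. A sketch of a drop-in replacement for the `model` block in the code below; the profile name is arbitrary, and in cmdstanr the timings are returned by `fit$profiles()`.

```stan
model {
  array[I] int i_inds = linspaced_int_array(I, 1, I);  // individual index
  // time the forward-algorithm likelihood (forward and reverse passes)
  profile("hmm_likelihood") {
    target += reduce_sum(partial_sum_hmm_lpmf, i_inds, 1, S0, S,
                         W1, W2, X1, X2, Z1, Z2,
                         first, last, end, y,
                         alpha_1, alpha_2, beta_1, beta_2, gamma_1, gamma_2,
                         append_col(r_phi_y_1, r_phi_y_2),
                         append_col(r_psi_y_1, r_psi_y_2));
  }
  target += lprior;  // priors are computed in transformed parameters
}
```

Comparing the mean number of leapfrog steps per iteration (`n_leapfrog__` in `fit$sampler_diagnostics()`) between runs with and without year effects would also help separate two possible causes of the slowdown: more expensive gradient evaluations versus longer trajectories.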
If full execution of the model with group effects scales from 40 to 2,000 iterations similarly to the model without group effects (and I don’t know enough about the mechanics of HMC to know whether or not that’s to be expected…), then the full model should take ~4 hrs to run. That seems quite reasonable to me, and generally speaking I’d be satisfied with that. However, there are a few additional group effects that I think should be accounted for (e.g., social group: these birds live in family groups with a single breeding pair, and all members of a group tend to experience relatively similar conditions, some [but maybe not all?] of which are captured by the population-effect predictors). There are many more social groups than there are years, so I expect that adding that component will seriously affect sampling time.
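A rough check of that extrapolation, using only the timings above and treating “just over a minute” as about 1 min and “close to 4 minutes” as about 4 min (per chain: 1,000 + 1,000 iterations in the full run versus 20 + 20 in the short runs):

$$
\begin{aligned}
\text{iteration ratio:}\quad & \frac{1000 + 1000}{20 + 20} = 50 \\
\text{naive scaling:}\quad & 4\ \text{min} \times 50 = 200\ \text{min} \approx 3.3\ \text{h} \\
\text{observed ratio without group effects:}\quad & \frac{40\ \text{min}}{\approx 1\ \text{min}} \approx 40 \\
\text{empirical scaling:}\quad & 4\ \text{min} \times 40 \approx 160\ \text{min} \approx 2.7\ \text{h}
\end{aligned}
$$

Both estimates land under 4 h, so ~4 hrs looks like a conservative round figure. With only 20 warmup iterations, though, step-size and mass-matrix adaptation is far from complete, so per-iteration cost in the short runs may not be representative of a full run.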
In light of that, I’d be interested to know whether there’s anything I could try to improve the sampling speed. Also, although I’m not strictly speaking a “Stan newbie”, I’m still (and always) growing in my understanding of Stan and Bayesian approaches generally, so I’d welcome any input others might have on the model specification and implementation.
Code
functions {
/* Add a constant to each element of an integer array
*
* n - length of the input array
* x - input array
* d - constant to be added to each element of `x`
*
* Returns `y`, an integer array of length `n`.
*/
array[] int add(int n, array[] int x, int d){
array[n] int y;
for(i in 1:n)
y[i] = x[i] + d;
return y;
}
/* Log-probability of state observations in hidden Markov, multistate
* Cormack-Jolly-Seber model.
*
* Compute state-specific, individual- and time-varying parameters phi
* (survival), psi (state persistence), and p (detection), then use these to
* construct transition (ecological process) and emission (detection/observation
* process) matrices Delta and Rho.
*
* Called from partial_sum_hmm_lpmf as part of a within-chain parallelised
* sampling regime.
*
* n_ind - number of individuals
* n_obs - number of observations
* S0 - initial state of each individual
* S - state observations (post-initial)
* W1 - design matrix for survival of non-breeders
* W2 - design matrix for survival of breeders
* X1 - design matrix for state persistence of non-breeders
* X2 - design matrix for state persistence of breeders
* Z1 - design matrix for detection of non-breeders
* Z2 - design matrix for detection of breeders
* first - index of first observation of each individual
* last - index of last detection observation of each individual
* end - index of last observation (detection OR not) of each individual
* a1 - population parameters for non-breeder survival
* a2 - population parameters for breeder survival
* b1 - population parameters for non-breeder state persistence
* b2 - population parameters for breeder state persistence
* c1 - population parameters for non-breeder detection
* c2 - population parameters for breeder detection
* r_phi_y - group intercepts for survival (year effect)
* r_psi_y - group intercepts for state persistence (year effect)
*
* Returns `ptarget`, the (partial) log-likelihood (priors are added separately).
*
*/
real hmm_logp(
data int n_ind, data int n_obs, data array[] int S0, data array[] int S,
data matrix W1, data matrix W2,
data matrix X1, data matrix X2,
data matrix Z1, data matrix Z2,
data array[] int first, data array[] int last, data array[] int end,
vector a1, vector a2, vector b1, vector b2, vector c1, vector c2,
matrix r_phi_y, matrix r_psi_y
){
matrix[n_obs,2] phi, psi, p;
array[n_obs] matrix[3,3] Delta, Rho;
real ptarget = 0.0;
// compute linear predictors for survival, state persistence, and detection
phi = inv_logit(append_col(W1*a1, W2*a2) + r_phi_y);
psi = inv_logit(append_col(X1*b1, X2*b2) + r_psi_y);
p = inv_logit(append_col(Z1*c1, Z2*c2));
for(t in 1:n_obs){ // Construct ecological process (transition) matrix
Delta[t][1,1] = phi[t,1] * psi[t,1]; //
Delta[t][2,1] = phi[t,2] * (1.0 - psi[t,2]); //
Delta[t][3,1] = 0.0; //
Delta[t][1,2] = phi[t,1] * (1.0 - psi[t,1]); //
Delta[t][2,2] = phi[t,2] * psi[t,2]; //
Delta[t][3,2] = 0.0; // phi1*psi1 phi1*(1-psi1) (1-phi1)
Delta[t][1,3] = 1.0 - phi[t,1]; // phi2*(1-psi2) phi2*psi2 (1-phi2)
Delta[t][2,3] = 1.0 - phi[t,2]; // 0 0 1
Delta[t][3,3] = 1.0; //
}
for(t in 1:n_obs){ // Construct observation process (emission) matrix
Rho[t][1,1] = p[t,1]; //
Rho[t][2,1] = 0.0; //
Rho[t][3,1] = 0.0; //
Rho[t][1,2] = 0.0; // p1 0 1-p1
Rho[t][2,2] = p[t,2]; // 0 p2 1-p2
Rho[t][3,2] = 0.0; // 0 0 1
Rho[t][1,3] = 1.0 - p[t,1]; //
Rho[t][2,3] = 1.0 - p[t,2]; //
Rho[t][3,3] = 1.0; //
}
// likelihood
for(i in 1:n_ind){
// initialise the marginal probability vector as the initial state vector
// (reflects the individual's status at first capture; S0 is either 1 or 2)
row_vector[3] omega = one_hot_row_vector(3, S0[i]);
// condition the live states (1 and 2) on observations from the first to
// the last detection event
for(t in first[i]:last[i])
omega = (omega * Delta[t]) .* Rho[t][:,S[t]]';
omega[3] = 0.0; // fix the p(dead) to 0 at last detection
// condition all states on non-detection after last detection
for(t in (last[i]+1):end[i])
omega = (omega * Delta[t]) .* Rho[t][:,S[t]]';
ptarget += log(sum(omega)); // increment [partial] log density
}
return ptarget; // return the [partial] log density
}
/* Chunk data for multi-threading. Called from reduce_sum, as a wrapper for
* hmm_logp. Reassembles and re-indexes the chunked data before passing it on to hmm_logp.
*
* i_seq - indices of individuals in current chunk
* i_start - index of first individual in current chunk
* i_end - index of last individual in current chunk
* S0 - initial state of each individual
* S - state observations (post-initial)
* W1 - design matrix for survival of non-breeders
* W2 - design matrix for survival of breeders
* X1 - design matrix for state persistence of non-breeders
* X2 - design matrix for state persistence of breeders
* Z1 - design matrix for detection of non-breeders
* Z2 - design matrix for detection of breeders
* first - index of first observation of each individual
* last - index of last detection observation of each individual
* end - index of last observation (detection OR not) of each individual
* y - year identifier/index for each observation
* alpha_1 - population parameters for non-breeder survival
* alpha_2 - population parameters for breeder survival
* beta_1 - population parameters for non-breeder state persistence
* beta_2 - population parameters for breeder state persistence
* gamma_1 - population parameters for non-breeder detection
* gamma_2 - population parameters for breeder detection
* r_phi_y - group intercepts for survival (year effect)
* r_psi_y - group intercepts for state persistence (year effect)
*
* Returns `ptarget`, the output of hmm_logp using the chunked and re-
* indexed data.
*/
real partial_sum_hmm_lpmf(
data array[] int i_seq, data int i_start, data int i_end,
data array[] int S0, data array[] int S,
data matrix W1, data matrix W2,
data matrix X1, data matrix X2,
data matrix Z1, data matrix Z2,
data array[] int first, data array[] int last, data array[] int end,
data array[] int y,
vector alpha_1, vector alpha_2,
vector beta_1, vector beta_2,
vector gamma_1, vector gamma_2,
matrix r_phi_y, matrix r_psi_y
){
int n_ind = i_end - i_start + 1;
int o_start = first[i_start]; // use individual indices to find the
int o_end = end[i_end]; // corresponding observation indices and the
int n_obs = o_end - o_start + 1; // number of observations
array[n_obs] int o_seq = linspaced_int_array(n_obs, o_start, o_end);
array[n_ind] int t_first = add(n_ind, first[i_seq], 1-o_start);
array[n_ind] int t_last = add(n_ind, last[i_seq], 1-o_start);
array[n_ind] int t_end = add(n_ind, end[i_seq], 1-o_start);
// subset individual-based data with i_seq and observation-based data
// with o_seq; pass on data and parameters to hmm_logp
return hmm_logp(n_ind, n_obs, S0[i_seq], S[o_seq],
W1[o_seq,:], W2[o_seq,:],
X1[o_seq,:], X2[o_seq,:],
Z1[o_seq,:], Z2[o_seq,:],
t_first, t_last, t_end,
alpha_1, alpha_2, beta_1, beta_2, gamma_1, gamma_2,
r_phi_y[y[o_seq],:], r_psi_y[y[o_seq],:]);
}
}
data {
// Dimensions
int<lower=1> N; // observations
int<lower=1> I; // individuals
int<lower=1> Y; // [austral] years
// Population effect parameter sizes
int<lower=1> K1; // survival [non-breeder]
int<lower=1> K2; // [breeder]
int<lower=1> L1; // persistence [non-breeder]
int<lower=1> L2; // [breeder]
int<lower=1> M1; // detection [non-breeder]
int<lower=1> M2; // [breeder]
// Identifiers
array[N] int<lower=1, upper=Y> y; // year ID
// Indices
array[I] int<lower=1, upper=N> first; // indices of first observations
array[I] int<lower=0, upper=N> last; // indices of last observations
array[I] int<lower=last, upper=N> end; // indices of final records
// Data
array[I] int<lower=1, upper=2> S0; // initial status
array[N] int<lower=1, upper=3> S; // status [1=non-breeder,
// 2=breeder,
// 3=not seen]
// Covariate matrices
matrix[N,K1] W1; // survival [non-breeder]
matrix[N,K2] W2; // [breeder]
matrix[N,L1] X1; // persistence [non-breeder]
matrix[N,L2] X2; // [breeder]
matrix[N,M1] Z1; // detection [non-breeder]
matrix[N,M2] Z2; // [breeder]
}
parameters {
// Population effects
vector[K1] alpha_1; // survival [non-breeder]
vector[K2] alpha_2; // [breeder]
vector[L1] beta_1; // persistence [non-breeder]
vector[L2] beta_2; // [breeder]
vector[M1] gamma_1; // detection [non-breeder]
vector[M2] gamma_2; // [breeder]
// Group effects
real<lower=0> sigma_phi_y_1; // survival-by-year [non-breeder]
vector<offset=0.0, multiplier=sigma_phi_y_1>[Y] r_phi_y_1;
real<lower=0> sigma_phi_y_2; // [breeder]
vector<offset=0.0, multiplier=sigma_phi_y_2>[Y] r_phi_y_2;
real<lower=0> sigma_psi_y_1; // persistence-by-year [non-breeder]
vector<offset=0.0, multiplier=sigma_psi_y_1>[Y] r_psi_y_1;
real<lower=0> sigma_psi_y_2; // [breeder]
vector<offset=0.0, multiplier=sigma_psi_y_2>[Y] r_psi_y_2;
}
transformed parameters {
real lprior = 0.0;
lprior += student_t_lpdf(alpha_1[1] | 3, 0.0, 1); // Student's t(3, 0, 1)
lprior += student_t_lpdf(alpha_2[1] | 3, 0.0, 1); // for all population effect
lprior += student_t_lpdf(beta_1[1] | 3, 0.0, 1); // intercept parameters
lprior += student_t_lpdf(beta_2[1] | 3, 0.0, 1);
lprior += student_t_lpdf(gamma_1[1] | 3, 0.0, 1);
lprior += student_t_lpdf(gamma_2[1] | 3, 0.0, 1);
if(K1 > 1) // Student's t(5, 0, 0.5)
for(k in 2:K1) // for all population effect
lprior += student_t_lpdf(alpha_1[k] | 5, 0.0, 0.5); // slope parameters
if(K2 > 1)
for(k in 2:K2)
lprior += student_t_lpdf(alpha_2[k] | 5, 0.0, 0.5);
if(L1 > 1)
for(l in 2:L1)
lprior += student_t_lpdf(beta_1[l] | 5, 0.0, 0.5);
if(L2 > 1)
for(l in 2:L2)
lprior += student_t_lpdf(beta_2[l] | 5, 0.0, 0.5);
if(M1 > 1)
for(m in 2:M1)
lprior += student_t_lpdf(gamma_1[m] | 5, 0.0, 0.5);
if(M2 > 1)
for(m in 2:M2)
lprior += student_t_lpdf(gamma_2[m] | 5, 0.0, 0.5);
lprior += normal_lpdf(sigma_phi_y_1 | 0.0, 0.5) - 1.0 * // Half-Normal(0, 0.5)
normal_lccdf(0.0 | 0.0, 0.5); // for all group effect
lprior += normal_lpdf(r_phi_y_1 | 0.0, sigma_phi_y_1); // scale parameters
lprior += normal_lpdf(sigma_phi_y_2 | 0.0, 0.5) - 1.0 * // Normal(0, <sd_r>)
normal_lccdf(0.0 | 0.0, 0.5); // for all group effects
lprior += normal_lpdf(r_phi_y_2 | 0.0, sigma_phi_y_2);
lprior += normal_lpdf(sigma_psi_y_1 | 0.0, 0.5) - 1.0 *
normal_lccdf(0.0 | 0.0, 0.5);
lprior += normal_lpdf(r_psi_y_1 | 0.0, sigma_psi_y_1);
lprior += normal_lpdf(sigma_psi_y_2 | 0.0, 0.5) - 1.0 *
normal_lccdf(0.0 | 0.0, 0.5);
lprior += normal_lpdf(r_psi_y_2 | 0.0, sigma_psi_y_2);
}
model {
array[I] int i_inds = linspaced_int_array(I, 1, I); // construct individual index
target += reduce_sum(partial_sum_hmm_lpmf, i_inds, 1, S0, S,
W1, W2, X1, X2, Z1, Z2,
first, last, end, y,
alpha_1, alpha_2, beta_1, beta_2, gamma_1, gamma_2,
append_col(r_phi_y_1, r_phi_y_2),
append_col(r_psi_y_1, r_psi_y_2));
target += lprior;
}
Technical specifications
Processor: Intel Core i5-8365U CPU @ 1.60 GHz / 1.90 GHz
RAM: 16.0 GB
R: v4.4.1
CmdStan: v2.35.0
cmdstanr: v0.8.1