I find myself regularly fitting large multi-state capture-recapture models in Stan. By large, I mean a large number of individuals (100,000+) with a large number of states (100+) but a smallish number of capture occasions (<10). These models can equivalently be viewed as hidden Markov models (HMMs) with large transition matrices and Bernoulli observations. Specifically, I am talking about temporally-stratified Cormack-Jolly-Seber (TSCJS) models (attached) that we use to estimate travel time, detection and survival of migrating fish. However, I think the issues I want to highlight here apply more generally to capture-recapture models and perhaps to HMMs.
My goal is the most efficient algorithm for calculating the likelihood of the observed data. I've come up with 4 ways to do it for the purposes of this post, but for the TL;DR crowd, here is the heart of my question: the most efficient algorithm I've come up with relies on matrix multiplication on the natural scale, and I only take the logarithm after taking the product of multiple probabilities (some of which can potentially be quite small). Based on my simulations, this algorithm appears to be at least 5 times faster than any algorithm I've been able to come up with that takes the logarithm of individual terms before the multiplication and addition required to sum over latent discrete states. Obviously, I want to use this algorithm, but I'm seeking advice from this community on potential risks to inference due to underflow.
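To make the trade-off concrete, here is a minimal Python sketch (standard library only; the helper names are hypothetical and not taken from the attached models) comparing the two orderings: multiply probabilities on the natural scale and take one logarithm at the end, versus summing logarithms. With only a handful of factors per term, the two agree to rounding error; the natural-scale product only breaks once it falls below the smallest positive double (about 4.9e-324), where it becomes exactly 0 and its log is -inf.

```python
import math


def log_of_product(probs):
    """Natural scale first: multiply all factors, then take one log."""
    prod = math.prod(probs)
    # math.log(0.0) raises in Python; Stan's log(0) returns -inf instead
    return math.log(prod) if prod > 0.0 else -math.inf


def sum_of_logs(probs):
    """Log scale first: take the log of each factor, then add."""
    return math.fsum(math.log(q) for q in probs)


# Four small factors (about one per station): the product, ~1e-80, is far
# above the smallest positive double, so both orderings agree.
short_chain = [1e-20] * 4
print(log_of_product(short_chain), sum_of_logs(short_chain))

# Twenty such factors: ~1e-400 underflows to exactly 0.0 on the natural
# scale, while the log-scale sum is still an ordinary finite number.
long_chain = [1e-20] * 20
print(log_of_product(long_chain), sum_of_logs(long_chain))
```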
Okay? With me?
Now, if you still care, here's all the nitty-gritty:
Here is a brief description of the model copied from the paper.
Consider a migratory population that can be captured at K sequential fixed-location monitoring stations along the migratory corridor. Each station is continuously monitored for a period of time divided into T equal-duration strata, with t \in (1, \ldots, T) representing the same physical unit of time at each station. A total of n individuals are divided among multiple distinct releases occurring at a single location above the first station, where each release occurs within a single stratum. If an individual is captured at a given station, the stratum in which it was captured is recorded.
The observed data consists of:
- y_i = (y_{i,0}, ..., y_{i,K}), the capture history for individual i, where:
  - y_{i,0} = t if individual i was initially released on stratum t
  - y_{i,k} = t if individual i was captured at station k on stratum t, and 0 otherwise
I also use the following summary variables (calculated in the `transformed data` block):
- m_{j,k,s,t} = the number of individuals released (or re-released after capture) at station j on stratum s that were next recaptured at station k on stratum t
- l_{k,t} = the number of individuals released from station k on stratum t that were never recaptured
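As an illustration of how these summaries relate to the capture histories, here is a minimal Python sketch. The function name `summarize` and the plain-list input format are assumptions for illustration only (the attached models compute the counts in `transformed data`). Each history is a list `[y0, y1, ..., yK]` with 0 meaning "not captured"; each pair of consecutive sightings adds to m, and the final sighting adds to l.

```python
from collections import Counter


def summarize(histories):
    """Return (m, l) Counters keyed by (j, k, s, t) and (k, t).

    Station 0 is the release, so y[0] is always a stratum >= 1; later
    entries are the capture stratum at station k, or 0 if not captured.
    """
    m = Counter()
    l = Counter()
    for y in histories:
        # (station, stratum) pairs where the individual was released or seen
        seen = [(k, t) for k, t in enumerate(y) if t != 0]
        # each (re)release is paired with the next recapture
        for (j, s), (k, t) in zip(seen, seen[1:]):
            m[(j, k, s, t)] += 1
        # the last sighting is followed by no further recapture
        l[seen[-1]] += 1
    return m, l


histories = [
    [1, 2, 0, 4],  # released on stratum 1, seen at station 1 (stratum 2) and station 3 (stratum 4)
    [1, 0, 0, 0],  # released on stratum 1, never recaptured
    [1, 0, 3, 3],  # missed at station 1, seen at stations 2 and 3 on stratum 3
]
m, l = summarize(histories)
print(dict(m))
print(dict(l))
```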
The data generating process assumes that:
- All individuals that departed station k on stratum s have the same probability of surviving to station k + 1, \phi_{k,s}; where 0 \leq \phi_{k,s} \leq 1
- All surviving individuals that departed station k on stratum s have the same probability of arriving at station k + 1 on stratum t, \alpha_{k,s,t}; where, 0 \leq \alpha_{k,s,t}, \sum_{t = s}^{T} \alpha_{k,s,t} = 1, and \alpha_{k,s,t} = 0 for s > t (note: we allow that individuals may depart station k and arrive at station k + 1 on the same stratum, i.e. s = t)
- All individuals arriving at station k on stratum t have the same probability of capture, p_{k,t}; where 0 \leq p_{k,t} \leq 1
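These three assumptions fully specify how a single capture history is generated. The following is a minimal Python sketch of that process, assuming hypothetical constant toy parameters, a hypothetical helper `simulate_history`, and (as in the \lambda recursion) that the stratum on which an individual arrives at a station is also the stratum on which it departs.

```python
import random


def upper_uniform(T):
    """Toy arrival matrix: from stratum s, arrival is equally likely on s..T-1."""
    return [[0.0 if t < s else 1.0 / (T - s) for t in range(T)] for s in range(T)]


def simulate_history(release_stratum, phi, alpha, p, rng):
    """Simulate one capture history [y0, y1, ..., yK].

    Strata in the returned history are 1-based and 0 means "not captured";
    internally strata are 0-based list indices.
    phi[k][s]   survival from station k to k + 1, departing on stratum s
    alpha[k][s] arrival-stratum probabilities at station k + 1 (0 for t < s)
    p[k][t]     capture probability at station k on stratum t (p[0] unused)
    """
    K = len(phi)
    y = [release_stratum] + [0] * K
    s = release_stratum - 1  # stratum on which the individual departs station 0
    for k in range(K):
        if rng.random() >= phi[k][s]:
            break  # died between station k and station k + 1
        # arrival stratum at station k + 1, also its departure stratum there
        s = rng.choices(range(len(alpha[k][s])), weights=alpha[k][s])[0]
        if rng.random() < p[k + 1][s]:
            y[k + 1] = s + 1  # captured at station k + 1 on stratum s
    return y


T, K = 4, 3
phi = [[0.9] * T for _ in range(K)]
alpha = [upper_uniform(T) for _ in range(K)]
p = [None] + [[0.6] * T for _ in range(K)]
rng = random.Random(2024)
histories = [simulate_history(1, phi, alpha, p, rng) for _ in range(5)]
print(histories)
```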
A set of recursive terms is used in defining the likelihood. \lambda_{j,k,s,t} is the probability of an individual (re)released at station j on stratum s surviving to station k and arriving on stratum t, having passed each station prior to k uncaptured. Setting \lambda_{j,j + 1,s,t} = \phi_{j,s}\alpha_{j,s,t}, we define the recursion:
\lambda_{j,k,s,t} = \sum_{u = s}^t \left(\lambda_{j,k-1,s,u} (1 - p_{k-1,u}) \phi_{k-1,u} \alpha_{k-1,u,t}\right) \:\mathrm{for} \: 0 \leq j < k - 1
The product \lambda_{j,k,s,t}p_{k,t} then defines the probability that an individual released at station j on stratum s is first recaptured at station k on stratum t, where \lambda_{j,k,s,t} represents the sum of all possible mutually-exclusive combinations of unobserved passage times at each station between stations j and k. Thus, the recursive \lambda terms can be used together with the p_{k,t} terms to define the probability of any pair of a release (i.e. initial release or a subsequent release after a recapture) and the next recapture.
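The \lambda recursion can be sketched directly as nested sums. This is a minimal Python sketch under stated assumptions: strata are 0-based list indices, `toy_params` builds hypothetical constant parameters with uniform upper-triangular \alpha rows, and `lambda_forward` is an illustrative name, not a function from the attached models.

```python
def toy_params(T, K, phi_val, p_val):
    """Hypothetical constant parameters; alpha rows are uniform over strata s..T-1."""
    phi = [[phi_val] * T for _ in range(K)]                       # phi[k][s], k = 0..K-1
    alpha = [[[0.0 if t < s else 1.0 / (T - s) for t in range(T)] for s in range(T)]
             for _ in range(K)]                                   # alpha[k][s][t]
    p = [None] + [[p_val] * T for _ in range(K)]                  # p[k][t], k = 1..K
    return phi, alpha, p


def lambda_forward(j, k, s, phi, alpha, p):
    """Return [lambda_{j,k,s,t} for t in 0..T-1] by the recursion above."""
    T = len(alpha[j][s])
    # base case: lambda_{j,j+1,s,t} = phi_{j,s} * alpha_{j,s,t}
    lam = [phi[j][s] * alpha[j][s][t] for t in range(T)]
    # each step passes station m - 1 uncaptured and arrives at station m
    for m in range(j + 2, k + 1):
        lam = [sum(lam[u] * (1 - p[m - 1][u]) * phi[m - 1][u] * alpha[m - 1][u][t]
                   for u in range(s, t + 1))
               for t in range(T)]
    return lam


phi, alpha, p = toy_params(T=4, K=3, phi_val=0.9, p_val=0.5)
lam = lambda_forward(0, 3, 0, phi, alpha, p)
# probability of first recapture at station 3 on each stratum: lambda * p
first_recapture = [lam[t] * p[3][t] for t in range(4)]
print(first_recapture)
```

With constant \phi and p, the total mass of \lambda_{0,3,s,\cdot} is \phi^3 (1 - p)^2, which gives an easy sanity check.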
The second recursive term represents the probability that an individual is last captured at station k on stratum t. Letting \chi_{K,t} = 1, we have
\chi_{k,t} = (1 - \phi_{k,t}) + \phi_{k,t} \sum_{u = t}^{T} \left(\alpha_{k,t,u} (1 - p_{k + 1,u}) \chi_{k + 1,u}\right) \:\mathrm{for} \: k < K
Thus \chi_{k,t} defines the probability that an individual is last captured at station k on stratum t, as the nested sums of the mutually exclusive probabilities of dying before reaching the next station or of surviving but passing the next station uncaptured on the same or a later stratum.
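A backward pass over stations computes every \chi_{k,t}. The sketch below is a minimal Python version under the same assumptions as before (0-based strata, hypothetical constant toy parameters); the capture probability that enters each step is that of station k + 1, the station being passed, which matches the Stan code's use of `p[3]` when computing `Chi[2]`.

```python
def toy_params(T, K, phi_val, p_val):
    """Hypothetical constant parameters; alpha rows are uniform over strata s..T-1."""
    phi = [[phi_val] * T for _ in range(K)]
    alpha = [[[0.0 if t < s else 1.0 / (T - s) for t in range(T)] for s in range(T)]
             for _ in range(K)]
    p = [None] + [[p_val] * T for _ in range(K)]
    return phi, alpha, p


def chi_backward(phi, alpha, p):
    """Return chi[k][t] for stations k = 0..K and 0-based strata t, with chi[K][t] = 1."""
    K, T = len(phi), len(phi[0])
    chi = [[0.0] * T for _ in range(K)] + [[1.0] * T]
    for k in range(K - 1, -1, -1):  # work backwards from the last station
        for t in range(T):
            # survive, arrive at station k + 1 on stratum u >= t, and pass it uncaptured
            passed = sum(alpha[k][t][u] * (1 - p[k + 1][u]) * chi[k + 1][u]
                         for u in range(t, T))
            chi[k][t] = (1 - phi[k][t]) + phi[k][t] * passed
    return chi


phi, alpha, p = toy_params(T=4, K=3, phi_val=0.9, p_val=0.5)
chi = chi_backward(phi, alpha, p)
print(chi[0])  # probability of never being recaptured after release, by release stratum
```

Two limiting cases make useful checks: with p = 0 every \chi is 1, and with p = 1 each \chi_{k,t} reduces to 1 - \phi_{k,t}.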
Letting \mathbf{y} denote the entire set of observed capture histories and \mathbf{\theta} the entire set of parameters, the observed data likelihood can then be expressed as:
Pr(\mathbf{y} | \mathbf{\theta}) \propto \prod_{j = 0}^{K-1} \prod_{k = j + 1}^K \prod_{s = 1}^{T} \prod_{t = s}^{T} (\lambda_{j,k,s,t} p_{k,t})^{m_{j,k,s,t}} \times \prod_{k = 0}^K \prod_{t = 1}^{T} \chi_{k,t}^{l_{k,t}}
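One way to check an implementation of this likelihood is that, for any release at station j on stratum s, the cell probabilities "first recaptured at station k on stratum t" (\lambda_{j,k,s,t} p_{k,t}) and "never recaptured" (\chi_{j,s}) are exhaustive and must sum to 1. A minimal Python sketch, assuming 0-based strata, hypothetical constant toy parameters, and hypothetical count dictionaries; the log-likelihood is computed up to the multinomial constant implied by the \propto.

```python
import math


def toy_params(T, K, phi_val, p_val):
    """Hypothetical constant parameters; alpha rows are uniform over strata s..T-1."""
    phi = [[phi_val] * T for _ in range(K)]
    alpha = [[[0.0 if t < s else 1.0 / (T - s) for t in range(T)] for s in range(T)]
             for _ in range(K)]
    p = [None] + [[p_val] * T for _ in range(K)]
    return phi, alpha, p


def first_recapture_probs(j, s, phi, alpha, p):
    """{(k, t): lambda_{j,k,s,t} * p_{k,t}} for every later station k and stratum t."""
    K, T = len(phi), len(phi[0])
    lam = [phi[j][s] * alpha[j][s][t] for t in range(T)]  # lambda_{j,j+1,s,.}
    probs = {}
    for k in range(j + 1, K + 1):
        if k > j + 1:  # advance the lambda recursion by one station
            lam = [sum(lam[u] * (1 - p[k - 1][u]) * phi[k - 1][u] * alpha[k - 1][u][t]
                       for u in range(s, t + 1)) for t in range(T)]
        for t in range(T):
            probs[(k, t)] = lam[t] * p[k][t]
    return probs


def chi_backward(phi, alpha, p):
    """chi[k][t] for k = 0..K, with chi[K][t] = 1."""
    K, T = len(phi), len(phi[0])
    chi = [[0.0] * T for _ in range(K)] + [[1.0] * T]
    for k in range(K - 1, -1, -1):
        for t in range(T):
            passed = sum(alpha[k][t][u] * (1 - p[k + 1][u]) * chi[k + 1][u]
                         for u in range(t, T))
            chi[k][t] = (1 - phi[k][t]) + phi[k][t] * passed
    return chi


def log_likelihood(m_counts, l_counts, phi, alpha, p):
    """Sum of m * log(lambda * p) plus l * log(chi)."""
    chi = chi_backward(phi, alpha, p)
    ll = 0.0
    for (j, k, s, t), count in m_counts.items():
        ll += count * math.log(first_recapture_probs(j, s, phi, alpha, p)[(k, t)])
    for (k, t), count in l_counts.items():
        ll += count * math.log(chi[k][t])
    return ll


phi, alpha, p = toy_params(T=4, K=3, phi_val=0.9, p_val=0.5)
chi = chi_backward(phi, alpha, p)
for j in range(3):
    for s in range(4):
        total = sum(first_recapture_probs(j, s, phi, alpha, p).values()) + chi[j][s]
        print(j, s, round(total, 12))

# Hypothetical counts keyed by 0-based strata
m_counts = {(0, 1, 0, 1): 3, (1, 3, 1, 2): 1}
l_counts = {(0, 0): 5, (3, 2): 1}
print(log_likelihood(m_counts, l_counts, phi, alpha, p))
```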
So that's the background. I set up a simulation with K = 4 capture occasions following G releases of fish. The simulation is flexible as to the size of n, G, and T; the run times below were achieved with n = 10,000, G = 10, and T = 50. In practice, each row \alpha_{k,s,.} only needs to be a simplex, but I've set them up based on a discretization of the lognormal distribution, where the log-mean travel time, \mu, was simulated and is estimated as a function of temporally stratified covariates. Similarly, \phi and p are simulated and estimated as functions of covariates. Here's a version of the Stan function that takes parameters \mu and \sigma and calculates the simplex for each row of the \alpha matrix (I also have a log version that returns `log_softmax(base)`):
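```stan
row_vector arrival_probs_ln(int T, real mu, real sigma) {
  vector[T] base;
  // midpoints of the T strata: 0.5, 1.5, ..., T - 0.5
  vector[T] t = linspaced_vector(T, 0.5, T - 0.5);
  // lognormal log-density at each midpoint, up to an additive constant
  base = -square(log(t) - mu) / (2 * square(sigma)) - log(t);
  // normalize to a simplex over travel-time strata
  return softmax(base)';
}
```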
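The difference between the natural-scale and log-scale arrival probabilities matters most in the tails. Below is a minimal Python port of the function above and of a log version (`log_softmax(base)`, written out as `base - log_sum_exp(base)`); the helper names and the tight, hypothetical choice of \mu and \sigma are for illustration only. For a narrow travel-time distribution, far-tail probabilities underflow to exactly 0.0 on the natural scale, while the log version stays finite.

```python
import math


def _base(T, mu, sigma):
    """Unnormalized lognormal log-density at stratum midpoints 0.5, 1.5, ..., T - 0.5."""
    mids = [i + 0.5 for i in range(T)]
    return [-(math.log(t) - mu) ** 2 / (2 * sigma ** 2) - math.log(t) for t in mids]


def arrival_probs_ln(T, mu, sigma):
    """Natural scale: softmax(base), computed with the usual max shift."""
    base = _base(T, mu, sigma)
    mx = max(base)
    w = [math.exp(b - mx) for b in base]
    z = sum(w)
    return [x / z for x in w]


def log_arrival_probs_ln(T, mu, sigma):
    """Log scale: log_softmax(base) = base - log_sum_exp(base)."""
    base = _base(T, mu, sigma)
    mx = max(base)
    lse = mx + math.log(sum(math.exp(b - mx) for b in base))
    return [b - lse for b in base]


# A tight travel-time distribution (median 3 strata, sigma = 0.05) over T = 50 strata
probs = arrival_probs_ln(50, math.log(3.0), 0.05)
log_probs = log_arrival_probs_ln(50, math.log(3.0), 0.05)
print(probs[-1], log_probs[-1])  # 0.0 on the natural scale, finite on the log scale
```

Any observed count that falls in a cell whose natural-scale probability has underflowed to exactly 0.0 would contribute log(0) = -inf to the target.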
I simulated some data and wrote 4 versions of a Stan model to estimate the parameters. These models are essentially identical through the `data`, `transformed data`, `parameters`, and `transformed parameters` blocks (see attached). The only exception is that the two "early log" versions use the log version of the arrival probabilities. The 4 versions of the Stan model are:
1. "Forward-algorithm" natural scale: the recursive utility parameters \lambda and \chi are calculated using logic from the familiar forward algorithm for HMMs. This is basically a series of for loops that calculate each recursive term by taking products of the \phi, p and \alpha terms, then taking the log and multiplying by the observed counts in the final step.
2. "Forward-algorithm" log scale: this version also uses for loops, but takes the logarithm of the basic parameters first. Then I sum and `log_sum_exp` as needed to calculate the final likelihood term for each observed count.
3. Matrix multiplication natural scale: inspired by Fujiwara and Caswell, this version lets the magic of linear algebra calculate the recursive terms.
4. An attempt at using linear algebra and vectorization on the log scale to make something faster than version 2.
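The recursions are the same in all four versions; only the order of operations differs. Here is a minimal Python sketch, with hypothetical toy parameters and helper names, contrasting the idea behind the matrix-multiplication natural-scale version (one diagonal scaling and one vector-matrix product per station) with the idea behind the log-scale versions (logs of every factor first, then `log_sum_exp`) for \lambda_{j,k,s,\cdot}. For well-scaled inputs the two agree to rounding error.

```python
import math


def toy_params(T, K, phi_val, p_val):
    """Hypothetical constant parameters; alpha rows are uniform over strata s..T-1."""
    phi = [[phi_val] * T for _ in range(K)]
    alpha = [[[0.0 if t < s else 1.0 / (T - s) for t in range(T)] for s in range(T)]
             for _ in range(K)]
    p = [None] + [[p_val] * T for _ in range(K)]
    return phi, alpha, p


def vec_mat(v, A):
    """Row vector times matrix: (v A)[t] = sum_u v[u] * A[u][t]."""
    return [sum(v[u] * A[u][t] for u in range(len(v))) for t in range(len(A[0]))]


def lambda_matrix(j, k, s, phi, alpha, p):
    """Natural scale: scale by (1 - p) * phi, then multiply by the arrival matrix."""
    lam = [phi[j][s] * a for a in alpha[j][s]]
    for m in range(j + 2, k + 1):
        w = [lam[u] * (1 - p[m - 1][u]) * phi[m - 1][u] for u in range(len(lam))]
        lam = vec_mat(w, alpha[m - 1])
    return lam


def log_sum_exp(xs):
    mx = max(xs)
    if mx == -math.inf:
        return -math.inf
    return mx + math.log(sum(math.exp(x - mx) for x in xs))


def safe_log(x):
    """log that returns -inf at 0, like Stan's log(0)."""
    return math.log(x) if x > 0.0 else -math.inf


def log_lambda(j, k, s, phi, alpha, p):
    """Log scale: take logs of every factor first, then combine with log_sum_exp."""
    T = len(alpha[j][s])
    l_alpha = [[[safe_log(a) for a in row] for row in A] for A in alpha]
    l_lam = [math.log(phi[j][s]) + l_alpha[j][s][t] for t in range(T)]
    for m in range(j + 2, k + 1):
        l_lam = [
            log_sum_exp([l_lam[u] + math.log1p(-p[m - 1][u]) + math.log(phi[m - 1][u])
                         + l_alpha[m - 1][u][t] for u in range(s, t + 1)])
            if t >= s else -math.inf
            for t in range(T)
        ]
    return l_lam


phi, alpha, p = toy_params(T=5, K=3, phi_val=0.9, p_val=0.4)
natural = lambda_matrix(0, 3, 1, phi, alpha, p)
logged = log_lambda(0, 3, 1, phi, alpha, p)
print([safe_log(x) - lx for x, lx in zip(natural, logged) if x > 0.0])  # rounding-level differences
```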
Relevant sections of the Stan code for each:
1. Forward algorithm, natural scale:

```stan
// Natural-scale accumulators start at zero for the recursive sums
array[4, G] vector[T] Lambda_0 = rep_array(rep_vector(0, T), 4, G);
array[6, T] vector[T] Lambda = rep_array(rep_vector(0, T), 6, T);
array[3] vector[T] Chi;
vector[G] Chi_0;
// Calculate Chi terms in reverse time order (to allow for recursion)
for (rev_s in 1:T) {
  int s = T - rev_s + 1;
  Chi[3, s] = 1 - phi[3, s];
  Chi[2, s] = 1 - phi[2, s];
  Chi[1, s] = 1 - phi[1, s];
  // Chi[2, s] must be complete before it is used in Chi[1, s]
  for (t in s:T)
    Chi[2, s] += phi[2, s] * alpha[2, s, t] * (1 - p[3, t]) * Chi[3, t];
  for (t in s:T)
    Chi[1, s] += phi[1, s] * alpha[1, s, t] * (1 - p[2, t]) * Chi[2, t];
  // increment log-probability by observed "last capture" counts
  target += log(Chi[3, s]) * L[3, s];
  target += log(Chi[2, s]) * L[2, s];
  target += log(Chi[1, s]) * L[1, s];
}
// Log probability for individuals never seen after release (Chi_0)
// and for individuals seen at the first monitoring station after release (Lambda_0[1])
for (g in 1:G) {
  Chi_0[g] = 1 - phi_0[g];
  for (t in rls_day[g]:T) {
    Chi_0[g] += phi_0[g] * alpha_0[g, t] * (1 - p[1, t]) * Chi[1, t];
    Lambda_0[1, g, t] = phi_0[g] * alpha_0[g, t];
    target += (log(Lambda_0[1, g, t]) + log(p[1, t])) * M_0[1, g, t];
  }
  target += log(Chi_0[g]) * L_0[g];
}
// Recursive calculation for the remaining Lambda_0 terms
for (g in 1:G)
  for (t in rls_day[g]:T) {
    for (k in 1:3) {
      for (u in rls_day[g]:t)
        Lambda_0[k + 1, g, t] += Lambda_0[k, g, u] * (1 - p[k, u])
                                 * phi[k, u] * alpha[k, u, t];
    }
    target += (log(Lambda_0[2, g, t]) + log(p[2, t])) * M_0[2, g, t];
    target += (log(Lambda_0[3, g, t]) + log(p[3, t])) * M_0[3, g, t];
    target += log(Lambda_0[4, g, t]) * M_0[4, g, t];
  }
// Remaining Lambda terms: one station after a recapture
for (s in 1:T)
  for (t in s:T) {
    Lambda[1, s, t] = phi[1, s] * alpha[1, s, t];
    Lambda[2, s, t] = phi[2, s] * alpha[2, s, t];
    Lambda[3, s, t] = phi[3, s] * alpha[3, s, t];
    target += (log(Lambda[1, s, t]) + log(p[2, t])) * M[1, s, t];
    target += (log(Lambda[2, s, t]) + log(p[3, t])) * M[2, s, t];
    target += log(Lambda[3, s, t]) * M[3, s, t];
  }
// Two or more stations after a recapture
for (s in 1:T)
  for (t in s:T) {
    for (u in s:t) {
      Lambda[4, s, t] += Lambda[1, s, u] * (1 - p[2, u]) * phi[2, u] * alpha[2, u, t];
      Lambda[5, s, t] += Lambda[2, s, u] * (1 - p[3, u]) * phi[3, u] * alpha[3, u, t];
      Lambda[6, s, t] += Lambda[4, s, u] * (1 - p[3, u]) * phi[3, u] * alpha[3, u, t];
    }
    target += (log(Lambda[4, s, t]) + log(p[3, t])) * M[4, s, t];
    target += log(Lambda[5, s, t]) * M[5, s, t];
    target += log(Lambda[6, s, t]) * M[6, s, t];
  }
```
2. Forward algorithm, log scale:

```stan
// Log-scale accumulators, initialized to log(0) = -inf
array[4, G] vector[T] l_Lambda_0 = rep_array(rep_vector(log(0), T), 4, G);
array[6, T] vector[T] l_Lambda = rep_array(rep_vector(log(0), T), 6, T);
array[3] vector[T] l_Chi;
vector[G] l_Chi_0;
// log Chi terms in reverse time order (to allow for recursion)
for (rev_s in 1:T) {
  int s = T - rev_s + 1;
  vector[T - s + 1] l_chi_temp;
  l_Chi[3, s] = log1m(phi[3, s]);
  for (t in s:T)
    l_chi_temp[t - s + 1] = log(phi[2, s]) + l_alpha[2, s, t]
                            + log1m(p[3, t]) + l_Chi[3, t];
  l_Chi[2, s] = log_sum_exp(log1m(phi[2, s]), log_sum_exp(l_chi_temp));
  for (t in s:T)
    l_chi_temp[t - s + 1] = log(phi[1, s]) + l_alpha[1, s, t]
                            + log1m(p[2, t]) + l_Chi[2, t];
  l_Chi[1, s] = log_sum_exp(log1m(phi[1, s]), log_sum_exp(l_chi_temp));
  target += l_Chi[3, s] * L[3, s];
  target += l_Chi[2, s] * L[2, s];
  target += l_Chi[1, s] * L[1, s];
}
// Never seen after release (l_Chi_0) and first seen at station 1 (l_Lambda_0[1])
for (g in 1:G) {
  vector[T - rls_day[g] + 1] l_chi_temp;
  for (t in rls_day[g]:T) {
    l_chi_temp[t - rls_day[g] + 1] = log(phi_0[g]) + l_alpha_0[g, t]
                                     + log1m(p[1, t]) + l_Chi[1, t];
    l_Lambda_0[1, g, t] = log(phi_0[g]) + l_alpha_0[g, t];
    target += (l_Lambda_0[1, g, t] + log(p[1, t])) * M_0[1, g, t];
  }
  l_Chi_0[g] = log_sum_exp(log1m(phi_0[g]), log_sum_exp(l_chi_temp));
  target += l_Chi_0[g] * L_0[g];
}
// Recursive calculation for the remaining l_Lambda_0 terms
for (g in 1:G)
  for (t in rls_day[g]:T) {
    vector[t - rls_day[g] + 1] temp;
    for (k in 1:3) {
      for (u in rls_day[g]:t)
        temp[u - rls_day[g] + 1] = l_Lambda_0[k, g, u] + log1m(p[k, u])
                                   + log(phi[k, u]) + l_alpha[k, u, t];
      l_Lambda_0[k + 1, g, t] = log_sum_exp(temp);
    }
    target += (l_Lambda_0[2, g, t] + log(p[2, t])) * M_0[2, g, t];
    target += (l_Lambda_0[3, g, t] + log(p[3, t])) * M_0[3, g, t];
    target += l_Lambda_0[4, g, t] * M_0[4, g, t];
  }
// Remaining l_Lambda terms: one station after a recapture
for (s in 1:T)
  for (t in s:T) {
    l_Lambda[1, s, t] = log(phi[1, s]) + l_alpha[1, s, t];
    l_Lambda[2, s, t] = log(phi[2, s]) + l_alpha[2, s, t];
    l_Lambda[3, s, t] = log(phi[3, s]) + l_alpha[3, s, t];
    target += (l_Lambda[1, s, t] + log(p[2, t])) * M[1, s, t];
    target += (l_Lambda[2, s, t] + log(p[3, t])) * M[2, s, t];
    target += l_Lambda[3, s, t] * M[3, s, t];
  }
// Two or more stations after a recapture
for (s in 1:T)
  for (t in s:T) {
    vector[t - s + 1] temp;
    for (u in s:t)
      temp[u - s + 1] = l_Lambda[1, s, u] + log1m(p[2, u])
                        + log(phi[2, u]) + l_alpha[2, u, t];
    l_Lambda[4, s, t] = log_sum_exp(temp);
    for (u in s:t)
      temp[u - s + 1] = l_Lambda[2, s, u] + log1m(p[3, u])
                        + log(phi[3, u]) + l_alpha[3, u, t];
    l_Lambda[5, s, t] = log_sum_exp(temp);
    for (u in s:t)
      temp[u - s + 1] = l_Lambda[4, s, u] + log1m(p[3, u])
                        + log(phi[3, u]) + l_alpha[3, u, t];
    l_Lambda[6, s, t] = log_sum_exp(temp);
    target += (l_Lambda[4, s, t] + log(p[3, t])) * M[4, s, t];
    target += l_Lambda[5, s, t] * M[5, s, t];
    target += l_Lambda[6, s, t] * M[6, s, t];
  }
```
3. Matrix multiplication, natural scale:

```stan
array[4] matrix[G, T] Lambda_0;
array[6] matrix[T, T] Lambda;
array[3] vector[T] Chi;
vector[G] Chi_0;
// Chi terms: one matrix-vector product per station, working backwards
Chi[3] = 1 - phi[3];
Chi[2] = (1 - phi[2]) + phi[2] .* (alpha[2] * ((1 - p[3]) .* Chi[3]));
Chi[1] = (1 - phi[1]) + phi[1] .* (alpha[1] * ((1 - p[2]) .* Chi[2]));
Chi_0 = (1 - phi_0) + phi_0 .* (alpha_0 * ((1 - p[1]) .* Chi[1]));
// Lambda terms: scale rows by survival, then propagate through the arrival matrices
Lambda_0[1] = diag_pre_multiply(phi_0, alpha_0);
for (k in 1:3)
  Lambda_0[k + 1] = diag_post_multiply(Lambda_0[k], (1 - p[k]) .* phi[k]) * alpha[k];
Lambda[1] = diag_pre_multiply(phi[1], alpha[1]);
Lambda[2] = diag_pre_multiply(phi[2], alpha[2]);
Lambda[3] = diag_pre_multiply(phi[3], alpha[3]);
Lambda[4] = diag_post_multiply(Lambda[1], (1 - p[2]) .* phi[2]) * alpha[2];
Lambda[5] = diag_post_multiply(Lambda[2], (1 - p[3]) .* phi[3]) * alpha[3];
Lambda[6] = diag_post_multiply(Lambda[4], (1 - p[3]) .* phi[3]) * alpha[3];
// Logs are taken only here, when the target is incremented
target += log(Chi_0)' * L_0;
for (g in 1:G) {
  int s = rls_day[g];
  for (k in 1:3)
    target += (log(Lambda_0[k, g, s:]) + log(p[k, s:])') * M_0[k, g, s:];
  target += log(Lambda_0[4, g, s:]) * M_0[4, g, s:];
}
for (k in 1:3)
  target += log(Chi[k])' * L[k];
for (s in 1:T) {
  target += (log(Lambda[1, s, s:]) + log(p[2, s:])') * M[1, s, s:];
  target += (log(Lambda[2, s, s:]) + log(p[3, s:])') * M[2, s, s:];
  target += (log(Lambda[4, s, s:]) + log(p[3, s:])') * M[4, s, s:];
  target += log(Lambda[3, s, s:]) * M[3, s, s:];
  target += log(Lambda[5, s, s:]) * M[5, s, s:];
  target += log(Lambda[6, s, s:]) * M[6, s, s:];
}
```
4. Vectorized, log scale:

```stan
array[4, G] vector[T] l_Lambda_0 = rep_array(rep_vector(log(0), T), 4, G);
array[6, T] vector[T] l_Lambda = rep_array(rep_vector(log(0), T), 6, T);
array[3] vector[T] l_Chi;
vector[G] l_Chi_0;
l_Chi[3] = log1m(phi[3]);
target += dot_product(l_Chi[3], L[3]);
// log Chi terms in reverse time order, vectorized over arrival strata
for (u in 1:T) {
  int s = T - u + 1;
  real temp_chi;
  temp_chi = log_sum_exp(l_alpha[2, s, s:]' + log1m(p[3, s:]) + l_Chi[3, s:]);
  l_Chi[2, s] = log_sum_exp(log1m(phi[2, s]), log(phi[2, s]) + temp_chi);
  target += l_Chi[2, s] * L[2, s];
  temp_chi = log_sum_exp(l_alpha[1, s, s:]' + log1m(p[2, s:]) + l_Chi[2, s:]);
  l_Chi[1, s] = log_sum_exp(log1m(phi[1, s]), log(phi[1, s]) + temp_chi);
  target += l_Chi[1, s] * L[1, s];
}
for (g in 1:G) {
  int s = rls_day[g];
  real temp;
  temp = log_sum_exp(l_alpha_0[g, s:] + log1m(p[1, s:]) + l_Chi[1, s:]);
  l_Chi_0[g] = log_sum_exp(log1m(phi_0[g]), log(phi_0[g]) + temp);
  target += l_Chi_0[g] * L_0[g];
  l_Lambda_0[1, g, rls_day[g]:] = log(phi_0[g]) + l_alpha_0[g, rls_day[g]:];
  for (k in 1:3)
    for (t in s:T) {
      l_Lambda_0[k + 1, g, t] = log_sum_exp(l_Lambda_0[k, g, s:t] + log1m(p[k, s:t])
                                            + log(phi[k, s:t]) + l_alpha[k, s:t, t]);
      // increment target elementwise for k = 1:3
      target += (l_Lambda_0[k, g, t] + log(p[k, t])) * M_0[k, g, t];
    }
  target += dot_product(l_Lambda_0[4, g, s:], M_0[4, g, s:]);
}
for (s in 1:T) {
  // survival is indexed by the departure stratum s (a scalar), as in versions 1 to 3
  l_Lambda[1, s, s:] = log(phi[1, s]) + l_alpha[1, s, s:]';
  target += dot_product(l_Lambda[1, s, s:] + log(p[2, s:]), M[1, s, s:]);
  l_Lambda[2, s, s:] = log(phi[2, s]) + l_alpha[2, s, s:]';
  target += dot_product(l_Lambda[2, s, s:] + log(p[3, s:]), M[2, s, s:]);
  l_Lambda[3, s, s:] = log(phi[3, s]) + l_alpha[3, s, s:]';
  target += dot_product(l_Lambda[3, s, s:], M[3, s, s:]);
  for (t in s:T) {
    l_Lambda[4, s, t] = log_sum_exp(l_Lambda[1, s, s:t] + log1m(p[2, s:t])
                                    + log(phi[2, s:t]) + l_alpha[2, s:t, t]);
    target += (l_Lambda[4, s, t] + log(p[3, t])) * M[4, s, t];
    l_Lambda[5, s, t] = log_sum_exp(l_Lambda[2, s, s:t] + log1m(p[3, s:t])
                                    + log(phi[3, s:t]) + l_alpha[3, s:t, t]);
    target += l_Lambda[5, s, t] * M[5, s, t];
    l_Lambda[6, s, t] = log_sum_exp(l_Lambda[4, s, s:t] + log1m(p[3, s:t])
                                    + log(phi[3, s:t]) + l_alpha[3, s:t, t]);
    target += l_Lambda[6, s, t] * M[6, s, t];
  }
}
```
Representative run times for each version with 500 warmup and 500 sampling iterations:

1. Forward algorithm, natural scale:

```
$total
[1] 361.4987

$chains
  chain_id  warmup sampling   total
1        1 199.835  147.020 346.855
2        2 217.435  143.271 360.706
3        3 217.878  143.112 360.990
4        4 192.084  148.456 340.540
```

2. Forward algorithm, log scale:

```
$total
[1] 581.7066

$chains
  chain_id  warmup sampling   total
1        1 321.129  221.260 542.389
2        2 366.751  214.394 581.145
3        3 334.843  218.219 553.062
4        4 359.310  217.396 576.706
```

3. Matrix multiplication, natural scale:

```
$total
[1] 76.42366

$chains
  chain_id  warmup sampling   total
1        1  42.626   33.593  76.219
2        2  39.638   33.797  73.435
3        3  41.791   33.490  75.281
4        4  38.522   34.027  72.549
```

4. Vectorized, log scale:

```
$total
[1] 574.1323

$chains
  chain_id  warmup sampling   total
1        1 326.993  231.003 557.996
2        2 320.375  235.021 555.396
3        3 321.529  223.134 544.663
4        4 346.883  226.526 573.409
```
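The matrix-multiplication natural-scale version (3) is roughly 7.5 times faster than either log-scale version (2 and 4), and about 4.7 times faster than the natural-scale forward-algorithm version (1). I would much prefer to use this version, but what am I risking by doing so?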
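For reference, the speed-up ratios follow directly from the reported `$total` values; this is a trivial Python calculation using only the numbers above (the dictionary labels are just descriptive names).

```python
# Reported $total run times, one per version
totals = {
    "1: forward algorithm, natural scale": 361.4987,
    "2: forward algorithm, log scale": 581.7066,
    "3: matrix multiplication, natural scale": 76.42366,
    "4: vectorized, log scale": 574.1323,
}
fastest = totals["3: matrix multiplication, natural scale"]
# ratio of each version's run time to that of version 3
speedups = {name: t / fastest for name, t in totals.items()}
for name, ratio in speedups.items():
    print(f"{name}: {ratio:.1f}x the run time of version 3")
```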
Attachments:
- simulate_TSCJS_4K.R
- TSCJS_4K_Simulated_NoRE.stan
- TSCJS_4K_Simulated_NoRE_FA_logscale.stan
- TSCJS_4K_Simulated_NoRE_FA_natscale.stan
- TSCJS_4K_Simulated_NoRE_logscale_matrixform.stan
- biom.13171.pdf