# When to take the log? Computational efficiency versus risk of underflow in a large dimensional capture-recapture model

I find myself regularly fitting large multi-state capture-recapture models in Stan. By large, I mean a large number of individuals (100,000+) with a large number of states (100+) but a smallish number of capture occasions (<10). These models can equivalently be considered hidden Markov models (HMMs) with large transition matrices and Bernoulli observations. Specifically, I am talking about temporally-stratified Cormack-Jolly-Seber (TSCJS) models (attached) that we use to estimate travel time, detection, and survival of migrating fish. However, I think the issues I want to highlight here are more general to capture-recapture and perhaps to HMMs.

My goal is the most efficient algorithm for calculating the likelihood of the observed data. I’ve come up with 4 ways to do it for the purposes of this post, but for the TL;DR crowd here is the heart of my question: the most efficient algorithm I’ve come up with relies on matrix multiplication on the natural scale, and I only take the logarithm after taking the product of multiple probabilities (some of which can potentially be quite small). Based on my simulations, this algorithm is at least 5 times faster than any other algorithm I’ve been able to come up with that takes the logarithm of individual terms prior to the multiplication and addition required to sum over the latent discrete states. Obviously, I want to use this algorithm, but I’m seeking advice from this community on the potential risks to inference due to underflow.
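To make the underflow risk concrete before the details: a double-precision float underflows around 1e-308, so a natural-scale product of a few hundred smallish probabilities can become exactly 0.0 while the equivalent sum of logs is an ordinary number. A standalone Python sketch (not tied to the model below):

```python
import math

# 200 events that each occur with probability 0.01
probs = [0.01] * 200

# Natural scale: the running product underflows to exactly 0.0,
# since 1e-400 is below the smallest representable double (~5e-324)
prod = 1.0
for q in probs:
    prod *= q
print(prod)  # 0.0

# Log scale: the same quantity is a perfectly ordinary number
log_prod = sum(math.log(q) for q in probs)
print(log_prod)  # about -921.03
```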

Okay? With me?

Now, if you still care, here’s all the nitty-gritty:

Here is a brief description of the model copied from the paper.

Consider a migratory population that can be captured at K sequential fixed-location monitoring stations along the migratory corridor. Each station is continuously monitored for a period of time divided into T equal-duration strata, with t \in (1, \ldots, T) representing the same physical unit of time at each station. A total of n individuals are divided among multiple distinct releases occurring at a single location above the first station, where a release is defined as occurring on a single stratum. If an individual is captured at a given station, the stratum in which it was captured is recorded.

The observed data consists of:

• y_i = (y_{i,0}, ..., y_{i,K}) the capture history for individual i, where,
• y_{i,0} = t if individual i was initially released on stratum t
• y_{i,k} = t if individual i was captured at station k on stratum t; 0 otherwise

I also use the following summary variables (calculated as transformed data):

• m_{j,k,s,t} = the number of individuals released on stratum s that were next recaptured at station k on stratum t
• l_{k,t} = the number of individuals released from station k on stratum t that were never recaptured

The data generating process assumes that:

1. All individuals that departed station k on stratum s have the same probability of surviving to station k + 1, \phi_{k,s}; where 0 \leq \phi_{k,s} \leq 1
2. All surviving individuals that departed station k on stratum s have the same probability of arriving at station k + 1 on stratum t, \alpha_{k,s,t}; where, 0 \leq \alpha_{k,s,t}, \sum_{t = s}^{T} \alpha_{k,s,t} = 1, and \alpha_{k,s,t} = 0 for s > t (note: we allow that individuals may depart station k and arrive at station k + 1 on the same stratum, e.g. s = t)
3. All individuals arriving at station k on stratum t have the same probability of capture, p_{k,t}; where 0 \leq p_{k,t} \leq 1

A set of recursive terms is used in defining the likelihood. \lambda_{j,k,s,t} is the probability of an individual (re)released at station j on stratum s surviving to station k and arriving on stratum t, having passed each station prior to k uncaptured. Setting \lambda_{j,j + 1,s,t} = \phi_{j,s}\alpha_{j,s,t}, we define the recursion:

\lambda_{j,k,s,t} = \sum_{u = s}^t \left(\lambda_{j,k-1,s,u} (1 - p_{k-1,u}) \phi_{k-1,u} \alpha_{k-1,u,t}\right) \:\mathrm{for} \: 0 \leq j < k - 1

The product \lambda_{j,k,s,t}p_{k,t} then defines the probability that an individual released at station j on stratum s is first recaptured at station k on stratum t, where \lambda_{j,k,s,t} represents the sum of all possible mutually-exclusive combinations of unobserved passage times at each station between stations j and k. Thus, the recursive \lambda terms can be used together with the p_{k,t} terms to define the probability of any pair of a release (i.e. initial release or a subsequent release after a recapture) and the next recapture.
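It may be worth noting that one step of this recursion is exactly an upper-triangular matrix product, which is what the matrix-multiplication version below exploits. A minimal numpy sketch (all names and dimensions here are mine, purely illustrative) confirming the elementwise recursion and the matrix form agree:

```python
import numpy as np

rng = np.random.default_rng(1)
T = 6

def upper_simplex(T, rng):
    """Random arrival matrix: row s is a simplex over strata t >= s."""
    A = np.triu(rng.random((T, T)))
    return A / A.sum(axis=1, keepdims=True)

phi = rng.random(T)                 # survival, indexed by departure stratum
p = rng.random(T)                   # detection, indexed by arrival stratum
alpha_prev = upper_simplex(T, rng)  # arrival probabilities, first segment
alpha_next = upper_simplex(T, rng)  # arrival probabilities, next segment

# First-segment lambda: phi_s * alpha_{s,t}
lam1 = np.diag(phi) @ alpha_prev

# Elementwise recursion: lam2[s,t] = sum_u lam1[s,u] (1 - p_u) phi_u alpha_next[u,t]
lam2_loop = np.zeros((T, T))
for s in range(T):
    for t in range(s, T):
        for u in range(s, t + 1):
            lam2_loop[s, t] += lam1[s, u] * (1 - p[u]) * phi[u] * alpha_next[u, t]

# The same quantity as a single matrix product; triangularity makes the
# out-of-range u terms vanish, so the full sum equals the restricted sum
lam2_mat = lam1 @ np.diag((1 - p) * phi) @ alpha_next

print(np.allclose(lam2_loop, lam2_mat))  # True
```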

The second recursive term represents the probability that an individual is last captured at station k on stratum t. Letting \chi_{K,t} = 1, we have:
\chi_{k,t} = (1 - \phi_{k,t}) + \phi_{k,t} \sum_{u = t}^{T} \left(\alpha_{k,t,u} (1 - p_{k+1,u}) \chi_{k + 1,u}\right) \:\mathrm{for} \: k < K

Thus \chi_{k,t} defines the probability that an individual is last captured at station k on stratum t, as the nested sums of mutually exclusive probabilities of dying prior to reaching the next station or of surviving but passing the next station uncaptured on a subsequent stratum.

Letting \mathbf{y} denote the entire set of observed capture histories and \mathbf{\theta} the entire set of parameters, the observed data likelihood can then be expressed as:

Pr(\mathbf{y} | \mathbf{\theta}) \propto \prod_{j = 0}^{K-1} \prod_{k = j + 1}^K \prod_{s = 1}^{T} \prod_{t = s}^{T} (\lambda_{j,k,s,t} p_{k,t})^{m_{j,k,s,t}} \times \prod_{k = 0}^K \prod_{t = 1}^{T} \chi_{k,t}^{l_{k,t}}
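One useful consequence of the product-multinomial structure (which comes up again in the discussion below): for each release, the \lambda p terms and the \chi term must sum to one. A toy two-station numpy check of that identity, using my own throwaway implementation of the recursions above (not the attached code):

```python
import numpy as np

rng = np.random.default_rng(7)
T = 5

def upper_simplex(T, rng):
    """Random arrival matrix: row s is a simplex over strata t >= s."""
    A = np.triu(rng.random((T, T)))
    return A / A.sum(axis=1, keepdims=True)

phi0, phi1 = rng.random(T), rng.random(T)  # survival to stations 1 and 2
p1, p2 = rng.random(T), rng.random(T)      # detection at stations 1 and 2
a0, a1 = upper_simplex(T, rng), upper_simplex(T, rng)  # arrival probabilities

# lambda terms: reach station 1; reach station 2 having been missed at 1
lam1 = np.diag(phi0) @ a0
lam2 = lam1 @ np.diag((1 - p1) * phi1) @ a1

# chi terms built in reverse, with chi = 1 at the final station
chi2 = np.ones(T)
chi1 = (1 - phi1) + phi1 * (a1 @ ((1 - p2) * chi2))
chi0 = (1 - phi0) + phi0 * (a0 @ ((1 - p1) * chi1))

# each release stratum: first captured at 1, first captured at 2, or never
total = lam1 @ p1 + lam2 @ p2 + chi0
print(np.allclose(total, 1.0))  # True
```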

So that’s the background. I set up a simulation with K = 4 capture occasions after the initial release of G groups of fish. The simulation is flexible as to the size of n, G, and T; the run times below were achieved with n = 10,000, G = 10, and T = 50. In practice, \alpha_{k,s,.} only needs to be a simplex, but I’ve set it up based on a discretization of the lognormal distribution, where the log-mean travel time, \mu, was simulated and is estimated as a function of temporally stratified covariates. Similarly, \phi and p are simulated and estimated as functions of covariates. Here’s the Stan function that takes parameters \mu and \sigma and calculates the simplex for each row of the \alpha matrix (I also have a log version that returns log_softmax(base)):

row_vector arrival_probs_ln(int T, real mu, real sigma) {
  vector[T] t = linspaced_vector(T, 0.5, T - 0.5);
  vector[T] base = -square(log(t) - mu) / (2 * square(sigma)) - log(t);

  return softmax(base)';
}
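For readers who want to poke at it outside Stan, here is a direct Python translation of that function (my translation, not one of the attached scripts); the returned vector is a simplex by construction:

```python
import numpy as np

def arrival_probs_ln(T, mu, sigma):
    """Discretized lognormal travel-time distribution, normalized via softmax."""
    t = np.arange(T) + 0.5  # stratum midpoints 0.5, 1.5, ..., T - 0.5
    base = -np.square(np.log(t) - mu) / (2 * sigma**2) - np.log(t)
    base -= base.max()      # stabilize before exponentiating, as softmax does
    w = np.exp(base)
    return w / w.sum()

a = arrival_probs_ln(50, mu=2.0, sigma=0.5)
print(round(a.sum(), 12))  # 1.0
print((a >= 0).all())      # True
```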


I simulated some data and wrote 4 versions of a Stan model to estimate the parameters. These models are essentially identical through the data, transformed data, parameters, and transformed parameters blocks (see attached). The only exception is that the two “early log” versions use the log version of the arrival probabilities. The 4 versions of the Stan model are:

1. “Forward-algorithm” natural-scale: the recursive utility parameters \lambda and \chi are calculated based on the logic of the familiar forward algorithm from HMM modelling. This is basically a series of for loops that calculate each recursive term by taking products of the \phi, p and \alpha terms, taking the log and multiplying by the observed counts only in the final step.
2. “Forward-algorithm” log-scale: in this version I also make use of for loops, but take the logarithm of the basic parameters first, then sum and log_sum_exp as needed to calculate the final likelihood term for each observed count.
3. Matrix multiplication natural-scale: inspired by Fujiwara and Caswell, this version lets the magic of linear algebra calculate the recursive terms.
4. An attempt at using linear algebra and vectorization on the log scale to try to make something faster than 2.
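For what it's worth, versions 2 and 4 compute exactly the log of what version 3 computes: matrix multiplication in log space via log-sum-exp equals the log of the natural-scale product. A numpy sketch of that equivalence (illustrative names only), which is why the choice between them comes down purely to numerics and speed:

```python
import numpy as np

def log_matmul(logA, logB):
    """Matrix product computed entirely in log space via log-sum-exp."""
    # logC[i, j] = log sum_k exp(logA[i, k] + logB[k, j])
    X = logA[:, :, None] + logB[None, :, :]
    m = X.max(axis=1, keepdims=True)
    return (m + np.log(np.exp(X - m).sum(axis=1, keepdims=True)))[:, 0, :]

rng = np.random.default_rng(3)
A = rng.uniform(0.1, 1.0, (4, 4))
B = rng.uniform(0.1, 1.0, (4, 4))

nat = np.log(A @ B)                    # natural scale, log taken at the end
lg = log_matmul(np.log(A), np.log(B))  # log scale throughout

print(np.allclose(nat, lg))  # True
```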

Relevant sections of the Stan code for each:
1.

array[4, G] vector[T]      Lambda_0 = rep_array(rep_vector(0, T), 4, G);
array[6, T] vector[T]      Lambda   = rep_array(rep_vector(0, T), 6, T);
array[3] vector[T]         Chi;
vector[G]                  Chi_0;

//Calculate Chi terms in reverse time order (to allow for recursion)
for (rev_s in 1:T) {
  int s = T - rev_s + 1;
  Chi[3, s] = 1 - phi[3, s];
  Chi[2, s] = 1 - phi[2, s];
  Chi[1, s] = 1 - phi[1, s];
  for (t in s:T)
    Chi[2, s] += phi[2, s] * alpha[2, s, t] * (1 - p[3, t]) * Chi[3, t];
  for (t in s:T)
    Chi[1, s] += phi[1, s] * alpha[1, s, t] * (1 - p[2, t]) * Chi[2, t];
  //increment log-probability by observed "last capture" counts
  target += log(Chi[3, s]) * L[3, s];
  target += log(Chi[2, s]) * L[2, s];
  target += log(Chi[1, s]) * L[1, s];
}

//Calculate log probability for individuals never seen after release, Chi_0
//and for individuals seen at first monitoring station after release, Lambda_0
for (g in 1:G) {
  Chi_0[g] = 1 - phi_0[g];
  for (t in rls_day[g]:T) {
    Chi_0[g] += phi_0[g] * alpha_0[g, t] * (1 - p[1, t]) * Chi[1, t];
    Lambda_0[1, g, t] = phi_0[g] * alpha_0[g, t];
    target += (log(Lambda_0[1, g, t]) + log(p[1, t])) * M_0[1, g, t];
  }
  target += log(Chi_0[g]) * L_0[g];
}
// Recursive calculation for remaining Lambda_0 terms
for (g in 1:G)
  for (t in rls_day[g]:T) {
    for (k in 1:3) {
      for (u in rls_day[g]:t)
        Lambda_0[k + 1, g, t] += Lambda_0[k, g, u] *
                                 (1 - p[k, u]) *
                                 phi[k, u] *
                                 alpha[k, u, t];
    }
    target += (log(Lambda_0[2, g, t]) + log(p[2, t])) * M_0[2, g, t];
    target += (log(Lambda_0[3, g, t]) + log(p[3, t])) * M_0[3, g, t];
    target += log(Lambda_0[4, g, t]) * M_0[4, g, t];
  }
//Remaining Lambda terms
for (s in 1:T)
  for (t in s:T) {
    Lambda[1, s, t] = phi[1, s] * alpha[1, s, t];
    Lambda[2, s, t] = phi[2, s] * alpha[2, s, t];
    Lambda[3, s, t] = phi[3, s] * alpha[3, s, t];
    target += (log(Lambda[1, s, t]) + log(p[2, t])) * M[1, s, t];
    target += (log(Lambda[2, s, t]) + log(p[3, t])) * M[2, s, t];
    target += log(Lambda[3, s, t]) * M[3, s, t];
  }

for (s in 1:T)
  for (t in s:T) {
    for (u in s:t) {
      Lambda[4, s, t] += Lambda[1, s, u] * (1 - p[2, u]) * phi[2, u] * alpha[2, u, t];
      Lambda[5, s, t] += Lambda[2, s, u] * (1 - p[3, u]) * phi[3, u] * alpha[3, u, t];
      Lambda[6, s, t] += Lambda[4, s, u] * (1 - p[3, u]) * phi[3, u] * alpha[3, u, t];
    }
    target += (log(Lambda[4, s, t]) + log(p[3, t])) * M[4, s, t];
    target += log(Lambda[5, s, t]) * M[5, s, t];
    target += log(Lambda[6, s, t]) * M[6, s, t];
  }

2.

array[4, G] vector[T]      l_Lambda_0 = rep_array(rep_vector(log(0), T), 4, G);
array[6, T] vector[T]      l_Lambda   = rep_array(rep_vector(log(0), T), 6, T);
array[3] vector[T]         l_Chi;
vector[G]                  l_Chi_0;

for (rev_s in 1:T) {
  int s = T - rev_s + 1;
  vector[T - s + 1] l_chi_temp;
  l_Chi[3, s] = log1m(phi[3, s]);

  for (t in s:T)
    l_chi_temp[t - s + 1] = log(phi[2, s]) + l_alpha[2, s, t] +
                            log1m(p[3, t]) + l_Chi[3, t];

  l_Chi[2, s] = log_sum_exp(log1m(phi[2, s]), log_sum_exp(l_chi_temp));

  for (t in s:T)
    l_chi_temp[t - s + 1] = log(phi[1, s]) + l_alpha[1, s, t] +
                            log1m(p[2, t]) + l_Chi[2, t];

  l_Chi[1, s] = log_sum_exp(log1m(phi[1, s]), log_sum_exp(l_chi_temp));

  target += l_Chi[3, s] * L[3, s];
  target += l_Chi[2, s] * L[2, s];
  target += l_Chi[1, s] * L[1, s];
}

for (g in 1:G) {
  vector[T - rls_day[g] + 1] l_chi_temp;
  for (t in rls_day[g]:T) {
    l_chi_temp[t - rls_day[g] + 1] = log(phi_0[g]) + l_alpha_0[g, t] +
                                     log1m(p[1, t]) + l_Chi[1, t];
    l_Lambda_0[1, g, t] = log(phi_0[g]) + l_alpha_0[g, t];
    target += (l_Lambda_0[1, g, t] + log(p[1, t])) * M_0[1, g, t];
  }
  l_Chi_0[g] = log_sum_exp(log1m(phi_0[g]), log_sum_exp(l_chi_temp));
  target += l_Chi_0[g] * L_0[g];
}

for (g in 1:G)
  for (t in rls_day[g]:T) {
    vector[t - rls_day[g] + 1] temp;
    for (k in 1:3) {
      for (u in rls_day[g]:t)
        temp[u - rls_day[g] + 1] = l_Lambda_0[k, g, u] + log1m(p[k, u]) +
                                   log(phi[k, u]) + l_alpha[k, u, t];

      l_Lambda_0[k + 1, g, t] = log_sum_exp(temp);
    }
    target += (l_Lambda_0[2, g, t] + log(p[2, t])) * M_0[2, g, t];
    target += (l_Lambda_0[3, g, t] + log(p[3, t])) * M_0[3, g, t];
    target += l_Lambda_0[4, g, t] * M_0[4, g, t];
  }

for (s in 1:T)
  for (t in s:T) {
    l_Lambda[1, s, t] = log(phi[1, s]) + l_alpha[1, s, t];
    l_Lambda[2, s, t] = log(phi[2, s]) + l_alpha[2, s, t];
    l_Lambda[3, s, t] = log(phi[3, s]) + l_alpha[3, s, t];
    target += (l_Lambda[1, s, t] + log(p[2, t])) * M[1, s, t];
    target += (l_Lambda[2, s, t] + log(p[3, t])) * M[2, s, t];
    target += l_Lambda[3, s, t] * M[3, s, t];
  }

for (s in 1:T)
  for (t in s:T) {
    vector[t - s + 1] temp;
    for (u in s:t)
      temp[u - s + 1] = l_Lambda[1, s, u] + log1m(p[2, u]) +
                        log(phi[2, u]) + l_alpha[2, u, t];

    l_Lambda[4, s, t] = log_sum_exp(temp);

    for (u in s:t)
      temp[u - s + 1] = l_Lambda[2, s, u] + log1m(p[3, u]) +
                        log(phi[3, u]) + l_alpha[3, u, t];

    l_Lambda[5, s, t] = log_sum_exp(temp);

    for (u in s:t)
      temp[u - s + 1] = l_Lambda[4, s, u] + log1m(p[3, u]) +
                        log(phi[3, u]) + l_alpha[3, u, t];

    l_Lambda[6, s, t] = log_sum_exp(temp);

    target += (l_Lambda[4, s, t] + log(p[3, t])) * M[4, s, t];
    target += l_Lambda[5, s, t] * M[5, s, t];
    target += l_Lambda[6, s, t] * M[6, s, t];
  }

3.

array[4] matrix[G, T]      Lambda_0;
array[6] matrix[T, T]      Lambda;
array[3] vector[T]         Chi;
vector[G]                  Chi_0;

Chi[3]   = (1 - phi[3]);
Chi[2]   = (1 - phi[2]) + phi[2] .* (alpha[2] * ((1 - p[3]) .* Chi[3]));
Chi[1]   = (1 - phi[1]) + phi[1] .* (alpha[1] * ((1 - p[2]) .* Chi[2]));
Chi_0    = (1 - phi_0) + phi_0 .* (alpha_0 * ((1 - p[1]) .* Chi[1]));

Lambda_0[1] = diag_pre_multiply(phi_0, alpha_0);
for (k in 1:3)
  Lambda_0[k + 1] = diag_post_multiply(Lambda_0[k], (1 - p[k]) .* phi[k]) * alpha[k];

Lambda[1] = diag_pre_multiply(phi[1], alpha[1]);
Lambda[2] = diag_pre_multiply(phi[2], alpha[2]);
Lambda[3] = diag_pre_multiply(phi[3], alpha[3]);
Lambda[4] = diag_post_multiply(Lambda[1], (1 - p[2]) .* phi[2]) * alpha[2];
Lambda[5] = diag_post_multiply(Lambda[2], (1 - p[3]) .* phi[3]) * alpha[3];
Lambda[6] = diag_post_multiply(Lambda[4], (1 - p[3]) .* phi[3]) * alpha[3];

target += log(Chi_0)' * L_0;
for (g in 1:G) {
  int s = rls_day[g];
  for (k in 1:3)
    target += (log(Lambda_0[k, g, s:]) + log(p[k, s:])') * M_0[k, g, s:];

  target += log(Lambda_0[4, g, s:]) * M_0[4, g, s:];
}

for (k in 1:3)
  target += log(Chi[k])' * L[k];

for (s in 1:T) {
  target += (log(Lambda[1, s, s:]) + log(p[2, s:])') * M[1, s, s:];
  target += (log(Lambda[2, s, s:]) + log(p[3, s:])') * M[2, s, s:];
  target += (log(Lambda[4, s, s:]) + log(p[3, s:])') * M[4, s, s:];

  target += log(Lambda[3, s, s:]) * M[3, s, s:];
  target += log(Lambda[5, s, s:]) * M[5, s, s:];
  target += log(Lambda[6, s, s:]) * M[6, s, s:];
}

4.

array[4, G] vector[T]      l_Lambda_0 = rep_array(rep_vector(log(0), T), 4, G);
array[6, T] vector[T]      l_Lambda   = rep_array(rep_vector(log(0), T), 6, T);
array[3] vector[T]         l_Chi;
vector[G]                  l_Chi_0;

l_Chi[3] = log1m(phi[3]);
target += dot_product(l_Chi[3], L[3]);
for (u in 1:T) {
  int s = T - u + 1;
  real temp_chi;

  temp_chi    = log_sum_exp(l_alpha[2, s, s:]' + log1m(p[3, s:]) + l_Chi[3, s:]);
  l_Chi[2, s] = log_sum_exp(log1m(phi[2, s]), log(phi[2, s]) + temp_chi);
  target += l_Chi[2, s] * L[2, s];
  temp_chi    = log_sum_exp(l_alpha[1, s, s:]' + log1m(p[2, s:]) + l_Chi[2, s:]);
  l_Chi[1, s] = log_sum_exp(log1m(phi[1, s]), log(phi[1, s]) + temp_chi);
  target += l_Chi[1, s] * L[1, s];
}

for (g in 1:G) {
  int s = rls_day[g];
  real temp;

  temp       = log_sum_exp(l_alpha_0[g, s:] + log1m(p[1, s:]) + l_Chi[1, s:]);
  l_Chi_0[g] = log_sum_exp(log1m(phi_0[g]), log(phi_0[g]) + temp);
  target += l_Chi_0[g] * L_0[g];

  l_Lambda_0[1, g, s:] = log(phi_0[g]) + l_alpha_0[g, s:];
  for (k in 1:3)
    for (t in s:T) {
      l_Lambda_0[k + 1, g, t] = log_sum_exp(l_Lambda_0[k, g, s:t] +
                                            log1m(p[k, s:t]) +
                                            log(phi[k, s:t]) +
                                            l_alpha[k, s:t, t]);
      //increment target elementwise for k = 1:3
      target += (l_Lambda_0[k, g, t] + log(p[k, t])) * M_0[k, g, t];
    }
  target += dot_product(l_Lambda_0[4, g, s:], M_0[4, g, s:]);
}

for (s in 1:T) {
  l_Lambda[1, s, s:] = log(phi[1, s]) + l_alpha[1, s, s:]';
  target += dot_product(l_Lambda[1, s, s:] + log(p[2, s:]), M[1, s, s:]);
  l_Lambda[2, s, s:] = log(phi[2, s]) + l_alpha[2, s, s:]';
  target += dot_product(l_Lambda[2, s, s:] + log(p[3, s:]), M[2, s, s:]);
  l_Lambda[3, s, s:] = log(phi[3, s]) + l_alpha[3, s, s:]';
  target += dot_product(l_Lambda[3, s, s:], M[3, s, s:]);

  for (t in s:T) {
    l_Lambda[4, s, t] = log_sum_exp(l_Lambda[1, s, s:t] + log1m(p[2, s:t]) +
                                    log(phi[2, s:t]) + l_alpha[2, s:t, t]);
    target += (l_Lambda[4, s, t] + log(p[3, t])) * M[4, s, t];
    l_Lambda[5, s, t] = log_sum_exp(l_Lambda[2, s, s:t] + log1m(p[3, s:t]) +
                                    log(phi[3, s:t]) + l_alpha[3, s:t, t]);
    target += l_Lambda[5, s, t] * M[5, s, t];
    l_Lambda[6, s, t] = log_sum_exp(l_Lambda[4, s, s:t] + log1m(p[3, s:t]) +
                                    log(phi[3, s:t]) + l_alpha[3, s:t, t]);
    target += l_Lambda[6, s, t] * M[6, s, t];
  }
}


Representative run times for each version with 500 warmup and 500 iterations:
1.


$total
361.4987

$chains
  chain_id  warmup sampling   total
1        1 199.835  147.020 346.855
2        2 217.435  143.271 360.706
3        3 217.878  143.112 360.990
4        4 192.084  148.456 340.540


2.

$total
581.7066

$chains
  chain_id  warmup sampling   total
1        1 321.129  221.260 542.389
2        2 366.751  214.394 581.145
3        3 334.843  218.219 553.062
4        4 359.310  217.396 576.706


3.

$total
76.42366

$chains
  chain_id warmup sampling  total
1        1 42.626   33.593 76.219
2        2 39.638   33.797 73.435
3        3 41.791   33.490 75.281
4        4 38.522   34.027 72.549


4.

$total
574.1323

$chains
  chain_id  warmup sampling   total
1        1 326.993  231.003 557.996
2        2 320.375  235.021 555.396
3        3 321.529  223.134 544.663
4        4 346.883  226.526 573.409


The natural-scale matrix-multiplication version is roughly 8x faster than either log-scale version (and nearly 5x faster than the natural-scale forward algorithm). I would much prefer to use this version, but what am I risking by doing so?
simulate_TSCJS_4K.R (11.2 KB)
TSCJS_4K_Simulated_NoRE.stan (9.5 KB)
TSCJS_4K_Simulated_NoRE_FA_logscale.stan (11.5 KB)
TSCJS_4K_Simulated_NoRE_FA_natscale.stan (10.3 KB)
TSCJS_4K_Simulated_NoRE_logscale_matrixform.stan (10.9 KB)
biom.13171.pdf (1.4 MB)


Hey @Dalton , I haven’t yet grokked all the math here to know for sure whether what I have to say is relevant, but just in case it helps:

1. If desired, you can save the necessary quantities to check for underflow explicitly. Perhaps it would provide some comfort if you can confirm after-the-fact that you never see literal zeros in places where literal zeros shouldn’t happen.
2. You might not need to worry so much when some probabilities are small and others are large. For example if a product of transition matrices has some not-very-small elements, and some truly minuscule elements, it might be of little practical consequence for inference if the minuscule elements underflow (i.e. if they represent vanishingly small probabilities that are genuinely negligible).
3. On the other hand, if all elements get very small, then inference could get really wonky if some happen to underflow and others do not. However, since all elements are getting small, you can just multiply everything by a scalar constant at various points during the computation, and then divide (subtract) out those scalar constants on the log scale at the end. This is related to the numerical trick that log_sum_exp uses for stability.
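Point 3 is the standard scaling trick from the HMM filtering literature. A numpy sketch on a hypothetical toy chain (not the TSCJS likelihood): renormalize the forward vector at every step, accumulate the log of the scale factors, and you recover the same log mass as a full log-space recursion, even though the naive natural-scale product would underflow:

```python
import numpy as np

rng = np.random.default_rng(5)
T, steps = 20, 400

# Row-stochastic transition matrix, plus tiny per-step survival weights
# chosen so the naive 400-step natural-scale product would underflow to 0.0
P = rng.random((T, T))
P /= P.sum(axis=1, keepdims=True)
w = rng.uniform(0.001, 0.01, T)

# Scaled forward recursion: renormalize each step, accumulate log scales
f = np.ones(T) / T
log_scale = 0.0
for _ in range(steps):
    f = (f * w) @ P
    c = f.sum()
    f /= c
    log_scale += np.log(c)

# Reference: the same recursion done entirely in log space
lf = np.log(np.ones(T) / T)
lw, lP = np.log(w), np.log(P)
for _ in range(steps):
    X = (lf + lw)[:, None] + lP
    m = X.max(axis=0)
    lf = m + np.log(np.exp(X - m).sum(axis=0))

m = lf.max()
log_total = m + np.log(np.exp(lf - m).sum())
print(np.isclose(log_scale, log_total))  # True
```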

Thanks for the insight @jsocolar.

1. If desired, you can save the necessary quantities to check for underflow explicitly. Perhaps it would provide some comfort if you can confirm after-the-fact that you never see literal zeros in places where literal zeros shouldn’t happen.

This is a good check. In practice I would also expect to get errors about the log probability equaling log(0) whenever an underflowed term appears on the right-hand side of a ‘target +=’ statement. In Stan log(0) * 0 is negative infinity. Due to the nature of the problem, my transition matrices are usually upper triangular, and so the lower triangle of the \lambda matrices is all zeros. That’s why I have to use a for loop for these terms, to index out the known zero terms.

2. You might not need to worry so much when some probabilities are small and others are large. For example if a product of transition matrices has some not-very-small elements, and some truly minuscule elements, it might be of little practical consequence for inference if the minuscule elements underflow (i.e. if they represent vanishingly small probabilities that are genuinely negligible).

Well it’s mark-recapture, so it’s a product multinomial. That means the sum of all the \lambda p and \chi terms is 1. Assuming that the transition matrices have no zero elements, the number of terms that must sum to 1 is something like \sum_{i = 1}^K (i \times T^2) + (K \times T). The limit of where I’d like to take this model eventually is to follow a few million tagged fish as they migrate past 8 dams of the Federal Columbia River Power System over a 200-day monitoring period in a given year. So that’s like 750,000 terms in the likelihood (assuming triangular transition matrices). Not sure I’ll ever get there, but that’s the motivation for the question. In that circumstance, there are going to be some really small terms, especially since in recent years detection probability can be around 0.01 or less.
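A quick back-of-envelope evaluation of that count at the target scale (K = 8, T = 200; halving for triangular \alpha matrices is my own rough adjustment, which lands near the 750,000 figure):

```python
K, T = 8, 200

# number of lambda*p and chi terms with dense alpha matrices
terms = sum(i * T**2 for i in range(1, K + 1)) + K * T
print(terms)       # 1441600

# roughly halved once the alpha matrices are upper triangular
print(terms // 2)  # 720800
```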

3. On the other hand, if all elements get very small, then inference could get really wonky if some happen to underflow and others do not. However, since all elements are getting small, you can just multiply everything by a scalar constant at various points during the computation, and then divide (subtract) out those scalar constants on the log scale at the end. This is related to the numerical trick that log_sum_exp uses for stability.

This is brilliant. This might just be the solution I need if I start running into underflow, particularly if I just use it for the transition matrices, which is where I’m most likely to run into underflow issues. So maybe I just return each row of the transition matrix as something like exp(log(c) + base - log_sum_exp(base)), where base is the row on the log scale. That’d be sort of like a softmax that sums to some constant c rather than 1.
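That generalized softmax is easy to sanity-check in a few lines (a standalone sketch; `base` stands in for a log-scale row of the transition matrix):

```python
import numpy as np

def softmax_to_c(base, c):
    """Like softmax, but the result sums to c instead of 1."""
    m = base.max()
    lse = m + np.log(np.exp(base - m).sum())  # log_sum_exp(base)
    return np.exp(np.log(c) + base - lse)

base = np.array([-3.0, -1.0, -400.0, 0.5])  # one entry deep in the tail
out = softmax_to_c(base, c=100.0)

print(out.sum())   # 100.0 (up to floating point)
print(out[2] > 0)  # True: the tiny term is scaled up, not underflowed
```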
