Hi all,
I have a reasonably complex partial sum function that computes the forward algorithm for a hidden Markov model with two layers of hidden states. In the partial sum function, I use an `if` statement to either compute the probabilities directly or to work with the log probabilities. I would like to have both methods available in the partial sum function.
/**
 * Compute partial sum over individuals of forward algorithm for hidden Markov model (HMM)
 *
 * Intended to be passed to reduce_sum()/reduce_sum_static(). Each call handles the
 * individuals start..end (indices into the full data); all shared arguments are
 * indexed with the original individual index, so seq_ind itself is only the slicing
 * argument and is not read in the body.
 *
 * The indexing in the body is hard-coded for 3 ecological states (two alive states
 * and an absorbing third state), 3 observed states and 2 diagnostic outcomes.
 *
 * @param seq_ind   Sliced argument (sequence of individuals); unused in the body
 * @param start     Index of first individual of the slice
 * @param end       Index of last individual of the slice
 * @param y         Diagnostic outcomes; y[ii, t, k] is used as column indices into the diagnostic tpm
 * @param q         Detection indicator per individual, primary and secondary (0 = not detected, 1 = detected)
 * @param n_s       Number of ecological states
 * @param n_o       Number of observed states
 * @param n_d       Number of diagnostic states
 * @param n_prim    Number of primary occasions
 * @param n_sec     Number of secondary occasions per primary
 * @param first     Primary occasion of first capture per individual
 * @param last      Primary occasion of last capture per individual
 * @param first_sec Secondary occasion of first capture per individual
 * @param tau       Multipliers applied to the transition rate matrices (time intervals between primaries -- TODO confirm)
 * @param eta_a     Probability of entering as infected
 * @param phi_a     Mortality rates of the two alive states
 * @param phi_b     Coefficient of m on the log mortality rate of the second alive state
 * @param psi_a     Infection state transition rates
 * @param p_a       Detection probabilities of the two alive states
 * @param r         Parameters of the individual (r[1]) and sample (r[2]) infection detection probabilities
 * @param fr        Probabilities of a state-2 observation (fr[1]) / diagnostic outcome (fr[2]) for state 1 (presumably false-positive rates -- TODO confirm)
 * @param m         Matrix with one column per individual and one row per primary occasion (individual infection intensities -- TODO confirm)
 * @param n_z       Standard-normal deviates for the sample infection intensities, one matrix per individual
 * @param sigma     Scale of the log sample infection intensities
 * @param log_p     If non-zero, all forward-algorithm calculations are done on the log scale
 * @return Log probability for subset of individuals
 */
real partial_sum_lpmf(
  data array[] int seq_ind, data int start, data int end,
  data array[,,,] int y, data array[,,] int q, data int n_s, data int n_o, data int n_d, data int n_prim, data int n_sec, data array[] int first, data array[] int last, data array[] int first_sec, data vector tau,
  real eta_a, vector phi_a, real phi_b, vector psi_a, vector p_a, vector r, vector fr, matrix m, array[] matrix n_z, real sigma, data int log_p) {
  // number of individuals in partial sum and initialise partial target
  int n_ind = end - start + 1;
  real ptarget = 0;
  // parameters
  real eta;
  array[n_s - 1] vector[n_prim - 1] phi, psi;
  array[n_s - 1] matrix[n_sec, n_prim] p;
  array[n_ind] matrix[n_sec, n_prim] n;
  tuple(vector[n_prim], matrix[n_sec, n_prim]) delta;
  // transform intercepts
  vector[n_s - 1] log_phi_a = log(phi_a);
  // transition rate/probability matrices (trm/tpm)
  // NOTE: these containers are shared by all individuals of the slice; only the
  // entries from first[ii] onwards are overwritten for each individual
  tuple(array[n_prim - 1] matrix[n_s, n_s - 1],
        array[n_prim, n_sec] matrix[n_s - 1, n_o],
        array[n_prim, n_sec] matrix[n_o - 1, n_d - 1]) tpm;
  array[n_prim - 1] matrix[n_s, n_s] trm;
  // initialise (log) probabilities
  array[n_prim] matrix[n_o - 1, n_sec] acc;
  matrix[n_s, n_prim] gamma;
  // for each individual
  for (i in 1:n_ind) {
    // index that maps to original n_ind
    int ii = i + start - 1;
    // probability of entering as infected
    eta = eta_a;
    // mortality rates
    phi[1] = rep_vector(phi_a[1], n_prim - 1);
    phi[2] = exp(log_phi_a[2] + phi_b * m[1:(n_prim - 1), ii]);
    // infection state transition rates
    psi[1] = rep_vector(psi_a[1], n_prim - 1);
    psi[2] = rep_vector(psi_a[2], n_prim - 1);
    // for successive primary occasions
    for (t in first[ii]:(n_prim - 1)) {
      // ecological trm
      trm[t][1, 1] = -(psi[1][t] + phi[1][t]);
      trm[t][1, 2] = psi[2][t];
      trm[t][2, 1] = psi[1][t];
      trm[t][2, 2] = -(psi[2][t] + phi[2][t]);
      trm[t][3, 1] = phi[1][t];
      trm[t][3, 2] = phi[2][t];
      trm[t][:, 3] = zeros_vector(3);
      // ecological (log) tpm
      tpm.1[t] = matrix_exp(trm[t] * tau[t])[:, 1:2];
      // log-transform for log probabilities; done per occasion so that only the
      // entries recomputed for this individual are transformed. Transforming the
      // whole array would take the log of stale (already logged or uninitialised)
      // entries t < first[ii] left by earlier individuals in the slice, producing
      // NaNs that contaminate the gradient.
      if (log_p) {
        tpm.1[t] = log(tpm.1[t]);
      }
    } // t
    // sample infection intensities
    ptarget += std_normal_lupdf(to_vector(n_z[ii]));
    n[i] = exp(rep_matrix(log(m[:, ii]), n_sec)' + n_z[ii] * sigma);
    // individual and sample infection detection probabilities
    delta.1 = 1 - pow(1 - r[1], m[:, ii]);
    delta.2 = 1 - pow(1 - r[2], n[i]);
    // detection probabilities (fixed at 1 for secondary of first capture)
    for (s in 1:(n_s - 1)) {
      p[s] = rep_matrix(p_a[s], n_sec, n_prim);
      p[s][first_sec[ii], first[ii]] = 1;
    } // s
    // for each primary occasion
    for (t in first[ii]:n_prim) {
      // for each secondary occasion
      for (k in 1:n_sec) {
        // observation and diagnostic processes
        if (log_p) {
          // observation tpm
          tpm.2[t, k][1, 1] = log(p[1][k, t]) + log1m(fr[1]);
          tpm.2[t, k][1, 2] = log(p[1][k, t]) + log(fr[1]);
          tpm.2[t, k][1, 3] = log1m(p[1][k, t]);
          tpm.2[t, k][2, 1] = log(p[2][k, t]) + log1m(delta.1[t]);
          tpm.2[t, k][2, 2] = log(p[2][k, t]) + log(delta.1[t]);
          tpm.2[t, k][2, 3] = log1m(p[2][k, t]);
          // diagnostic tpm
          tpm.3[t, k][1, 1] = log1m(fr[2]);
          tpm.3[t, k][1, 2] = log(fr[2]);
          tpm.3[t, k][2, 1] = log1m(delta.2[k, t]);
          tpm.3[t, k][2, 2] = log(delta.2[k, t]);
          // log probability of each alive observed state conditioned on data
          if (q[ii, t, k] == 1) {
            for (o in 1:(n_o - 1)) {
              acc[t][o, k] = sum(tpm.3[t, k][o, y[ii, t, k]]);
            }
          }
        } else {
          // observation tpm
          tpm.2[t, k][1, 1] = p[1][k, t] * (1 - fr[1]);
          tpm.2[t, k][1, 2] = p[1][k, t] * fr[1];
          tpm.2[t, k][1, 3] = 1 - p[1][k, t];
          tpm.2[t, k][2, 1] = p[2][k, t] * (1 - delta.1[t]);
          tpm.2[t, k][2, 2] = p[2][k, t] * delta.1[t];
          tpm.2[t, k][2, 3] = 1 - p[2][k, t];
          // diagnostic tpm
          tpm.3[t, k][1, 1] = 1 - fr[2];
          tpm.3[t, k][1, 2] = fr[2];
          tpm.3[t, k][2, 1] = 1 - delta.2[k, t];
          tpm.3[t, k][2, 2] = delta.2[k, t];
          // probability of each alive observed state conditioned on data
          if (q[ii, t, k] == 1) {
            for (o in 1:(n_o - 1)) {
              acc[t][o, k] = prod(tpm.3[t, k][o, y[ii, t, k]]);
            }
          }
        }
      } // k
      // initialise (log) marginal probability with first secondary (required to avoid gamma = 0)
      if (q[ii, t, 1] == 0) {
        gamma[1:2, t] = tpm.2[t, 1][:, 3];
      } else {
        if (log_p) {
          gamma[1:2, t] = log_mat_prod(tpm.2[t, 1][:, 1:2], acc[t][:, 1]);
        } else {
          gamma[1:2, t] = tpm.2[t, 1][:, 1:2] * acc[t][:, 1];
        }
      }
      // for successive secondaries
      for (k in 2:n_sec) {
        // if individual was not detected
        if (q[ii, t, k] == 0) {
          // marginal detection (log) probabilities
          if (log_p) {
            gamma[1:2, t] += tpm.2[t, k][:, 3];
          } else {
            gamma[1:2, t] .*= tpm.2[t, k][:, 3];
          }
        } else {
          // marginal (log) probability after each secondary for each alive ecological state
          if (log_p) {
            gamma[1:2, t] += log_mat_prod(tpm.2[t, k][:, 1:2], acc[t][:, k]);
          } else {
            gamma[1:2, t] .*= tpm.2[t, k][1:2, 1:2] * acc[t][:, k];
          }
        }
      } // k
    } // t
    // marginal (log) probability of alive states at first capture
    if (log_p) {
      gamma[1:2, first[ii]] += [ log1m(eta), log(eta) ]';
    } else {
      gamma[1:2, first[ii]] .*= [ 1 - eta, eta ]';
    }
    // marginal (log) probability of alive states between first and last primary with captures
    if (first[ii] < last[ii]) {
      for (t in (first[ii] + 1):last[ii]) {
        if (log_p) {
          gamma[1:2, t] += log_mat_prod(tpm.1[t - 1][1:2, :], gamma[1:2, t - 1]);
        } else {
          gamma[1:2, t] .*= tpm.1[t - 1][1:2, 1:2] * gamma[1:2, t - 1];
        }
      } // t
    }
    // if last captured in the last primary
    if (last[ii] == n_prim) {
      // increment target with marginal log probabilities of alive states
      if (log_p) {
        ptarget += log_sum_exp(gamma[1:2, n_prim]);
      } else {
        ptarget += log(sum(gamma[1:2, n_prim]));
      }
    } else {
      // marginal (log) probabilities of all ecological states in primary after last capture
      if (log_p) {
        gamma[1:2, last[ii] + 1] += log_mat_prod(tpm.1[last[ii]][1:2, :], gamma[1:2, last[ii]]);
        gamma[3, last[ii] + 1] = log_mat_prod(tpm.1[last[ii]][3, :], gamma[1:2, last[ii]]);
      } else {
        gamma[1:2, last[ii] + 1] .*= tpm.1[last[ii]][1:2, 1:2] * gamma[1:2, last[ii]];
        gamma[3, last[ii] + 1] = tpm.1[last[ii]][3, 1:2] * gamma[1:2, last[ii]];
      }
      // if first occasion after last capture is the last primary
      if ((last[ii] + 1) == n_prim) {
        // increment target with marginal log probabilities of all ecological states
        if (log_p) {
          ptarget += log_sum_exp(gamma[:, n_prim]);
        } else {
          ptarget += log(sum(gamma[:, n_prim]));
        }
      }
      else {
        // marginal (log) probabilities of all ecological states until last primary
        for (t in (last[ii] + 2):n_prim) {
          if (log_p) {
            gamma[1:2, t] += log_mat_prod(tpm.1[t - 1][1:2, :], gamma[1:2, t - 1]);
            gamma[3, t] = log_sum_exp(log_mat_prod(tpm.1[t - 1][3, :], gamma[1:2, t - 1]), gamma[3, t - 1]);
          } else {
            gamma[1:2, t] .*= tpm.1[t - 1][1:2, 1:2] * gamma[1:2, t - 1];
            gamma[3, t] = tpm.1[t - 1][3, 1:2] * gamma[1:2, t - 1] + gamma[3, t - 1];
          }
        } // t
        // increment target with marginal probabilities of all ecological states
        if (log_p) {
          ptarget += log_sum_exp(gamma[:, n_prim]);
        } else {
          ptarget += log(sum(gamma[:, n_prim]));
        }
      }
    }
  } // i
  return(ptarget);
}
In the `model` block, I call the partial sum function as follows to allow experimentation with different grainsizes:
// grainsize == 1 selects reduce_sum(); any other grainsize selects
// reduce_sum_static(). Both calls receive exactly the same arguments.
if (grainsize != 1) {
  target += reduce_sum_static(
    partial_sum_lupmf, seq_ind, grainsize,
    y, q, n_s, n_o, n_d, n_prim, n_sec, first, last, first_sec, tau,
    eta_a, phi_a, phi_b, psi_a, p_a, r, fr, m, n_z, sigma[2], log_p);
} else {
  target += reduce_sum(
    partial_sum_lupmf, seq_ind, grainsize,
    y, q, n_s, n_o, n_d, n_prim, n_sec, first, last, first_sec, tau,
    eta_a, phi_a, phi_b, psi_a, p_a, r, fr, m, n_z, sigma[2], log_p);
}
Now, when I give `grainsize = 1` in R when I call `cmdstanr::sample()`, the model runs fine with both `log_p = 0` (doing the likelihood calculations on the probability scale) and with `log_p = 1` (doing the likelihood calculations on the log scale), and I get identical results. However, when I change `grainsize` to something else, say `grainsize = floor(n_ind / n_threads / 2)`, I get the following long list of errors:
Chain 1 Rejecting initial value:
Chain 1 Gradient evaluated at the initial value is not finite.
Chain 1 Stan can't start sampling from this initial value.
Chain 1 Initialization between (-0.1, 0.1) failed after 100 attempts.
Chain 1 Try specifying initial values, reducing ranges of constrained values, or reparameterizing the model.
Chain 1 Initialization failed.
So, it strikes me that the error is coming from changing the grainsize in the partial sum function. I can share the whole Stan program including data simulation, but in the meantime, is this known/expected behaviour?
thanks,
Matt