Hi @wesmith2, I’ve used multi-threading with multi-event models before (e.g. here), but that may not be the simplest model to begin with. Below is a Stan program with a reduce_sum implementation of the example model you linked to:
// Modification of code at
// https://github.com/stan-dev/example-models/blob/master/BPA/Ch.09/ms3_multinomlogit.stan
// -------------------------------------------------
// States (S):
// 1 alive at A
// 2 alive at B
// 3 alive at C
// 4 dead
// Observations (O):
// 1 seen at A
// 2 seen at B
// 3 seen at C
// 4 not seen
// -------------------------------------------------
functions {
/**
 * Occasion of the first capture in a single capture history.
 *
 * Scans the history left to right and reports the first occasion whose
 * observation differs from 4 ("not seen"). This function is derived from
 * the Stan Modeling Language User's Guide and Reference Manual.
 *
 * @param y_i Observed capture history for one individual
 * @return Occasion of first capture, or 0 if the individual was never seen
 */
int first_capture(array[] int y_i) {
  int n_occ = size(y_i);
  for (t in 1 : n_occ) {
    if (y_i[t] != 4) {
      return t;
    }
  }
  return 0;
}
/**
 * Multinomial-logit inverse link with the last category as reference
 * (thanks to Bob Carpenter). For a 2-vector lp this yields:
 *   p[1] = exp(lp[1]) / (1.0 + exp(lp[1]) + exp(lp[2]));
 *   p[2] = exp(lp[2]) / (1.0 + exp(lp[1]) + exp(lp[2]));
 *   p[3] = 1.0 - p[1] - p[2];
 *
 * @param lp N-dimension vector
 * @return (N+1)-simplex of given vector and 0
 */
vector softmax_0(vector lp) {
  // Append the structural zero for the reference category, then softmax.
  return softmax(append_row(lp, 0));
}
/**
 * Return log-likelihood from subset of capture histories, for use
 * with reduce_sum.
 *
 * Indexing: `first`, `ps`, and `po` cover the FULL data set, so the
 * absolute index i (start..end) is used for them, while the relative
 * index iy (1..end-start+1) picks the matching row of y_slice.
 *
 * @param y_slice the subset of capture histories
 * @param start index (in y) of first capture history in subset
 * @param end index (in y) of last capture history in subset
 * @param n_occasions number of occasions
 * @param first first capture occasions
 * @param ps array of state transition probabilities
 * @param po array of state-dependent detection probabilities
 * @return log-likelihood for the given subset of capture histories
 */
real partial_sum(
  array[,] int y_slice,
  int start,
  int end,
  int n_occasions,
  array[] int first,
  array[,,] vector ps,
  array[,,] vector po
) {
  real ll_term = 0;
  // Likelihood using forward algorithm derived from Stan Modeling Language
  // User's Guide and Reference Manual
  for (i in start:end) {
    int iy = i - start + 1; // position of individual i within y_slice
    array[4] real acc;
    array[n_occasions] vector[4] gamma;
    if (first[i]>0) { // skip individuals never captured (first == 0)
      // At first capture the state is known: it equals the observation.
      for (k in 1:4) {
        gamma[first[i], k] = y_slice[iy, first[i]] == k;
      }
      // Forward recursion over subsequent occasions; note po uses time
      // index t - 1 for the observation made at occasion t.
      for (t in (first[i] + 1):n_occasions) {
        for (k in 1:4) {
          for (j in 1:4) {
            acc[j] = gamma[t - 1, j] * ps[j, i, t - 1, k]
                     * po[k, i, t - 1, y_slice[iy, t]];
          }
          gamma[t, k] = sum(acc);
        }
      }
      // Marginal likelihood of this history = sum over final states.
      ll_term = ll_term + log(sum(gamma[n_occasions]));
    }
  }
  return ll_term;
}
}
data {
  int<lower=0> nind; // number of individuals (capture histories)
  int<lower=0> n_occasions; // number of capture occasions
  // Observed state per individual and occasion:
  // 1-3 = seen at site A/B/C, 4 = not seen
  array[nind, n_occasions] int<lower=1, upper=4> y;
}
transformed data {
  int n_occ_minus_1 = n_occasions - 1; // number of inter-occasion intervals
  // Occasion of first capture per individual; 0 if never captured
  array[nind] int<lower=0, upper=n_occasions> first;
  for (i in 1 : nind) {
    first[i] = first_capture(y[i]);
  }
}
parameters {
  real<lower=0, upper=1> phiA; // Survival probability at site A
  real<lower=0, upper=1> phiB; // Survival probability at site B
  real<lower=0, upper=1> phiC; // Survival probability at site C
  real<lower=0, upper=1> pA; // Recapture probability at site A
  real<lower=0, upper=1> pB; // Recapture probability at site B
  real<lower=0, upper=1> pC; // Recapture probability at site C
  // Logits of the movement probabilities; the third (last) category is
  // the softmax_0 reference, so each site needs only 2 free parameters.
  vector[2] lpsiA; // Logit of movement probability from site A
  vector[2] lpsiB; // Logit of movement probability from site B
  vector[2] lpsiC; // Logit of movement probability from site C
}
transformed parameters {
  simplex[3] psiA; // Movement probability from site A
  simplex[3] psiB; // Movement probability from site B
  simplex[3] psiC; // Movement probability from site C
  // ps[j, i, t] = distribution of the state at occasion t+1 given state j
  // at occasion t; po[k, i, t] = distribution of the observation at
  // occasion t+1 given state k at occasion t+1.
  // NOTE(review): as written, ps and po are identical for every i and t
  // (the indexing mirrors the original BPA example); they could be
  // collapsed to save memory — confirm before refactoring.
  array[4, nind, n_occ_minus_1] simplex[4] ps;
  array[4, nind, n_occ_minus_1] simplex[4] po;
  // Constrain the transitions such that their sum is < 1
  psiA = softmax_0(lpsiA);
  psiB = softmax_0(lpsiB);
  psiC = softmax_0(lpsiC);
  // Define state-transition and observation matrices
  for (i in 1 : nind) {
    // Define probabilities of state S(t+1) given S(t).
    // Row logic: survive (phi) and move (psi), or die (state 4, absorbing).
    for (t in 1 : n_occ_minus_1) {
      ps[1, i, t, 1] = phiA * psiA[1];
      ps[1, i, t, 2] = phiA * psiA[2];
      ps[1, i, t, 3] = phiA * psiA[3];
      ps[1, i, t, 4] = 1.0 - phiA;
      ps[2, i, t, 1] = phiB * psiB[1];
      ps[2, i, t, 2] = phiB * psiB[2];
      ps[2, i, t, 3] = phiB * psiB[3];
      ps[2, i, t, 4] = 1.0 - phiB;
      ps[3, i, t, 1] = phiC * psiC[1];
      ps[3, i, t, 2] = phiC * psiC[2];
      ps[3, i, t, 3] = phiC * psiC[3];
      ps[3, i, t, 4] = 1.0 - phiC;
      ps[4, i, t, 1] = 0.0;
      ps[4, i, t, 2] = 0.0;
      ps[4, i, t, 3] = 0.0;
      ps[4, i, t, 4] = 1.0;
      // Define probabilities of O(t) given S(t): detected at the true
      // site with probability p, otherwise "not seen" (observation 4);
      // dead individuals are never seen.
      po[1, i, t, 1] = pA;
      po[1, i, t, 2] = 0.0;
      po[1, i, t, 3] = 0.0;
      po[1, i, t, 4] = 1.0 - pA;
      po[2, i, t, 1] = 0.0;
      po[2, i, t, 2] = pB;
      po[2, i, t, 3] = 0.0;
      po[2, i, t, 4] = 1.0 - pB;
      po[3, i, t, 1] = 0.0;
      po[3, i, t, 2] = 0.0;
      po[3, i, t, 3] = pC;
      po[3, i, t, 4] = 1.0 - pC;
      po[4, i, t, 1] = 0.0;
      po[4, i, t, 2] = 0.0;
      po[4, i, t, 3] = 0.0;
      po[4, i, t, 4] = 1.0;
    }
  }
}
model {
  // Fix: removed the unused locals `acc` and `gamma` — leftovers from the
  // non-parallel version; the forward algorithm now lives in partial_sum.
  int grainsize = 1; // let the reduce_sum scheduler pick slice sizes
  // Priors
  // Survival and recapture: uniform(0, 1), implicitly defined by the
  // <lower=0, upper=1> constraints on phiA..phiC and pA..pC.
  // Diffuse normal priors on the logits of all but one transition prob
  // per site (sd = sqrt(1000), matching the original example).
  lpsiA ~ normal(0, sqrt(1000));
  lpsiB ~ normal(0, sqrt(1000));
  lpsiC ~ normal(0, sqrt(1000));
  // Likelihood using reduce_sum: partial_sum returns the forward-algorithm
  // log-likelihood for each slice of capture histories in parallel.
  target += reduce_sum(partial_sum, y, grainsize, n_occasions, first, ps, po);
}
The likelihood calculation using the forward algorithm is moved into the partial_sum function. If I understand correctly (based on material here and here), partial_sum is fed a subset of capture histories (something like y[101:200, ]) as the argument y_slice and returns the log-likelihood for that whole subset. One has to be careful that indices into y_slice and into the other objects (first, ps, po) refer to the same individual: first, ps and po cover the full data set, whereas y_slice contains only the subset. In the loop over individuals, for (i in start:end) {...}, the absolute index i therefore indexes into first, ps and po, whilst the relative index iy = i - start + 1 picks out the corresponding individual in y_slice.
I tested the above with the following R code to simulate data under the model and fit it (assuming you have the above Stan program saved as multistate_parallel.stan). I just ran one chain with a few threads to check that it was working; of course you’ll want several chains for serious work.
# Testing the Stan code ----
library(cmdstanr)

# Simulate some data ----
# Number of individuals first captured on each release occasion
nind_per_occasion <- 50
n_occasions <- 5
nind <- nind_per_occasion * (n_occasions - 1)
# Occasion of first capture (releases happen on occasions 1..T-1)
first <- rep(seq_len(n_occasions - 1), each = nind_per_occasion)
# Sample first states (sites A/B/C, coded 1..3) uniformly at random
first_state <- sample(1:3, nind, replace = TRUE)

# Set parameter values ----
# Survival probabilities by site
phiA <- 0.8
phiB <- 0.5
phiC <- 0.6
# Recapture probabilities by site
pA <- 0.7
pB <- 0.9
pC <- 0.6
# Movement probabilities (destinations A/B/C) from each site
psiA <- c(0.8, 0.1, 0.1)
psiB <- c(0.1, 0.8, 0.1)
psiC <- c(0.1, 0.1, 0.8)
# Calculate state transition and observation probabilities ----
# Rows index the current state (alive at A/B/C, dead); columns index the
# next state (for transitions) or the observation (for detection).
trans_mat <- rbind(
  c(phiA * psiA, 1.0 - phiA), # survive at A and move, or die
  c(phiB * psiB, 1.0 - phiB),
  c(phiC * psiC, 1.0 - phiC),
  c(0.0, 0.0, 0.0, 1.0) # death is absorbing
)
obs_mat <- rbind(
  c(pA, 0.0, 0.0, 1.0 - pA), # seen at true site or not seen
  c(0.0, pB, 0.0, 1.0 - pB),
  c(0.0, 0.0, pC, 1.0 - pC),
  c(0.0, 0.0, 0.0, 1.0) # dead individuals are never seen
)
# Replicate the (constant) matrices across individuals and intervals to
# match the array layout expected by the Stan program.
ps <- array(dim = c(4, nind, n_occasions - 1, 4))
po <- array(dim = c(4, nind, n_occasions - 1, 4))
for (s in 1:4) {
  for (k in 1:4) {
    ps[s, , , k] <- trans_mat[s, k]
    po[s, , , k] <- obs_mat[s, k]
  }
}
# Simulate true latent states
# z[i, t] is the true state (1-3 = alive at site A/B/C, 4 = dead);
# NA before first capture.
z <- matrix(NA, nrow = nind, ncol = n_occasions)
for (i in 1:nind) {
  z[i, first[i]] <- first_state[i] # state is known at first capture
  for (t in first[i]:(n_occasions - 1)) {
    # Draw the next state from the transition row of the current state
    z[i,t+1] <- sample(1:4, 1, prob = ps[z[i,t], i, t, ])
  }
}
# Simulate the observed states
# y[i, t] is the observation (1-3 = seen at site A/B/C, 4 = not seen).
y <- matrix(4, nrow = nind, ncol = n_occasions) # '4' is not-captured code
for (i in 1:nind) {
  y[i, first[i]] <- z[i, first[i]] # captured in true state at first capture
  for (t in (first[i]+1):n_occasions) {
    # Draw the observation from the detection row of the true state
    y[i,t] <- sample(1:4, 1, prob = po[z[i,t], i, t-1, ]) # time index for 'po' is one behind y
  }
}
# Fit the model using within-chain parallelisation ----
# Fix: removed the duplicate library(cmdstanr) call (the package is already
# loaded at the top of this script).
file <- "multistate_parallel.stan"
# stan_threads = TRUE is required for reduce_sum to run multi-threaded
mod <- cmdstan_model(file, cpp_options = list(stan_threads = TRUE))
stan_data <- list(nind = nind, n_occasions = n_occasions, y = y)
# One chain with 4 threads as a quick check; use more chains for real work
fit <- mod$sample(data = stan_data, chains = 1, threads_per_chain = 4)
BEWARE! On the small dataset generated by the above R code, the reduce_sum() implementation was actually slower than the original code without it. Multi-threading carries some per-call overhead (e.g. reduce_sum copies its shared arguments for each slice), which may explain this; for larger datasets, within-chain parallelisation should still be profitable. The other possible explanation is that I’ve done something silly — so use with caution!