Continuous time multi-state model w/ Hawkes detection process

I’m trying to fit a multi-state model where both the state process and detection process are modeled in continuous time but am running to some indentifiability issues that I am hoping there is a solution for.

The following example code assumes two live states (e.g., breeding/non-breeding, susceptible/infected) plus a dead state, the transition rate from state 1 to state 2 increases with time (modeled as a Weibull distribution governed by parameters g and k, as described here), and state-specific hazard rates (h[1] & h[2]) are constant. The detection process is governed by a exponential Hawkes process, with state-specific parameters (mu, alpha, and beta).

data {
  int<lower = 0> n_ind;
  int<lower = 0> n_occasions;
  int<lower = 1> max_N;
  int<lower = 1> max_N_plus_1;
  int<lower = 0> N[n_ind];
  int<lower = 0> N_plus_1[n_ind];
  int<lower = 0> N_1[n_ind];
  int<lower = 0> N_2[n_ind];
  real<lower = 0, upper = n_occasions> delta[n_ind, max_N_plus_1];
  real<lower = 0, upper = n_occasions> delta_occ[n_ind, max_N];
  int<lower = 0, upper = n_occasions> d[n_ind, max_N];
  real<lower = 0, upper = n_occasions> delta_upper[n_ind, max_N];
  real<lower = 0, upper = n_occasions> delta_lower[n_ind, max_N];
  real<lower=0> occ[n_occasions];
  matrix[1, 3] f;
  matrix[3, 1] ones;
}

parameters {
  real<lower=0> h[2];
  real<lower=0> g;
  real<lower=0> k;
  vector<lower=0>[2] mu;
  vector<lower=0>[2] beta;
  vector<lower=0>[2] alpha;
}

transformed parameters {
  real<lower=0> nu[n_occasions];
  matrix[3, 3] Q[n_occasions];

  // Define IG matrix   
  for(t in 1:n_occasions){
    nu[t] = k * occ[t] ^ (k - 1)/g ^ k;
    
    Q[t] = rep_matrix(0, 3, 3);
    
    Q[t, 1, 1] = -(nu[t] + h[1]);
    Q[t, 1, 2] = nu[t]; 
    Q[t, 1, 3] = h[1]; 
  
    Q[t, 2, 2] = -h[2];
    Q[t, 2, 3] = h[2];
  }
}

model {
  vector[3] lambda = rep_vector(0, 3);
  vector[3] Lambda = rep_vector(0, 3);
  matrix[1, 3] gamma;
  matrix[3, 3] Gamma;
  matrix[3, 3] Qt;
  vector[2] B;
    
  // Priors
  h ~ gamma(0.25, 10); // Hazard rate
  
  g ~ gamma(2, 0.5); // Transition parameters
  k ~ gamma(2, 0.5);
  
  mu ~ gamma(1, 2); // Hawkesprocess parameters
  beta ~ gamma(1, 2);
  alpha ~ gamma(1, 2);

  // Likelihood
  for(i in 1:n_ind) {
    vector[2] A[N_plus_1[i]];
    matrix[3, 3] Omega[N_plus_1[i]];
   
    A[1] = rep_vector(0, 2);
    B = rep_vector(0, 2);
   
   if(N[i] > 0){ // if detected
      for(j in 1:N_1[i]){
        if(j > 1) B[1] = (1 + B[1]) * exp(-beta[1] * delta[i, j]); // Crowley 2015 Eq. 4
          
        A[j + 1, 1] = 1 + exp(-beta[1] * delta[i, j]) * A[j, 1]; // Crowley 2015 Eq. 21
        A[j + 1, 2] = 0;
        
        Lambda[1:2] = delta[i, j] * mu + (alpha ./ beta) .* (1 - exp(-beta * delta[i, j])) .* A[j]; // Crowley 2015 Eq. 23
        lambda[1:2] = mu + alpha .* B; // Crowley 2015 Eq. 17
        
        if(j == 1){
         if(delta_occ[i, 1] == 0){
           Gamma = diag_post_multiply(matrix_exp(Q[1] * delta[i, 1] - diag_matrix(Lambda)), lambda);
         }else{
          Qt = Q[1];
            
          if(delta_occ[i, 1] > 1){
            for(o in 2:(d[i, 1] - 1)){
              Qt += Q[o];
            }
          }

          Qt += Q[d[i, 1]] * delta_lower[i, 1];
          
          Gamma = diag_post_multiply(matrix_exp(Qt - diag_matrix(Lambda)), lambda);
         }
        }else{
          if(delta_occ[i, j] == 0){
            Gamma = diag_post_multiply(matrix_exp(Q[d[i, j]] * delta[i, j] - diag_matrix(Lambda)), lambda);
          }else{
            Qt = Q[d[i, j - 1]] * delta_upper[i, j - 1];
            
            if(delta_occ[i, j] > 1){
              for(o in (d[i, j - 1] + 1):(d[i, j] - 1)){
                Qt += Q[o];
              }
            }

            Qt += Q[d[i, j]] * delta_lower[i, j] ;
            
            Gamma = diag_post_multiply(matrix_exp(Qt - diag_matrix(Lambda)), lambda);
          }
        }
        
        Omega[j] = rep_matrix(0, 3, 3);
        Omega[j, 1, 1] = Gamma[1, 1]; // likelihood of detection (state 1 to state 1)
      } // j
    
    // Detections in state 2
    if(N_2[i] > 0){ // if detected in state 2
      for(j in (N_1[i] + 1):N[i]){
        if(j > (N_1[i] + 1)){
          B[1] = 0;
          B[2] = (1 + B[2]) * exp(-beta[2] * delta[i, j]);
        } 
           
        A[j + 1, 1] = A[j, 1];
        A[j + 1, 2] = 1 + exp(-beta[2] * delta[i, j]) * A[j, 2];
        
        Lambda[1:2] = delta[i, j] * mu + (alpha ./ beta) .* (1 - exp(-beta * delta[i, j])) .* A[j];
        lambda[1:2] = mu + alpha .* B;
        
          if(delta_occ[i, j] == 0){
            Gamma = diag_post_multiply(matrix_exp(Q[d[i, j]] * delta[i, j] - diag_matrix(Lambda)), lambda);
          }else{
            Qt = Q[d[i, j - 1]] * delta_upper[i, j - 1];
            
            if(delta_occ[i, j] > 1){
              for(o in (d[i, j - 1] + 1):(d[i, j] - 1)){
                Qt += Q[o];
              }
            }
            
            Qt += Q[d[i, j]] * delta_lower[i, j];
            
            Gamma = diag_post_multiply(matrix_exp(Qt - diag_matrix(Lambda)), lambda);
          }
         
        Omega[j] = rep_matrix(0, 3, 3);
        if(j == (N_1[i] + 1)){
          Omega[j, 1, 2] = Gamma[1, 2]; // likelihood of detection (state 1 to state 2)
        }else{
          Omega[j, 2, 2] = Gamma[2, 2]; // likelihood of detection (state 2 to state 2)
        }
      } 
    }
    
     // last detection to T
     Lambda[1:2] = delta[i, N_plus_1[i]] * mu + (alpha ./ beta) .* (1 - exp(-beta * delta[i, N_plus_1[i]])) .* A[N_plus_1[i]];
       
     Qt = Q[d[i, N[i]]] * delta_upper[i, N[i]];
   
     if(d[i, N[i]] < n_occasions){
       for(o in (d[i, N[i]] + 1):n_occasions){
          Qt += Q[o];
        }
      }
     
      Gamma = matrix_exp(Qt - diag_matrix(Lambda));
     
      Omega[N_plus_1[i]] = rep_matrix(0, 3, 3);
      if(N_2[i] == 0){
        Omega[N_plus_1[i], 1, 1:3] = Gamma[1, 1:3]; // likelihood of last detection in state 1
      }else{
        Omega[N_plus_1[i], 2, 2:3] = Gamma[2, 2:3]; // likelihood of last detection in state 2
      }
  
      gamma = f * Omega[1];
      for(j in 2:N_plus_1[i]){
        gamma *= Omega[j];
      }
      target += log(gamma * ones);
   }else{
     // Never detected
       Lambda[1:2] = mu * n_occasions;
       Qt = Q[1];
       
       for(o in 2:n_occasions){
         Qt += Q[o];
       }
       
       Gamma = matrix_exp(Qt - diag_matrix(Lambda));

       Omega[1] = rep_matrix(0, 3, 3);
       Omega[1, 1, 1:3] = Gamma[1, 1:3];
       
       target += log(f * Omega[1] * ones);
     }
  }// i loop
} // end model"

Code for simulating data and fitting is included below. Note that the likelihood for the multi-state model uses the forward algorithm and the likelihood for the Hawkes process is described here (specifically, equations 4, 17, 21, 23, and 27) . The likelihoods are, I think, written correctly. This is based on 1) a similar model w/ constant state-specific detection rates (e.g., \lambda_s) estimates h, g, and k correctly, and 2) a similar model w/ constant transition rate estimates the Hawkes process parameters (mu, alpha, beta) correctly.

So the problem seems to be that temporal variation in transition rate + Hawkes detection process leads to identifiability issues. Specifically, the above model estimates h[1] = 0 (or very close), h[2] is biased high, and mu[2]/beta[2] are biased low. The model seems to think individuals that are last detected in state 1 all transition to state 2 (h[1] = 0), are not detected (mu[2] is very low), and then die (h[2] is too high). That’s my best guess at what’s happening, anyway.

In discrete-time models, lack of identifiability like this can often be solved using a robust design framework. In continuous time, I’m not sure there’s a solution (assuming I have correctly diagnosed the problem, which maybe I haven’t). It does seem like the state 2 detections should be sufficient to estimate mu[2] (they are in simplified models w/o temporal variation in transition probability) but perhaps not.

Any thoughts/advice on what the problem might be or if there are ways to solve it would be much appreciated.


A few minor notes:

  1. As written, the model assumes that any individual detected in state 2 was previously detected in state 1 (just to simplify the model code a bit by leaving out a couple of if statements). If you use a different seed or simulate different data sets, you will get an error if individuals are detected in state 2 but not state 1

  2. The parameter values used in the code below are somewhat arbitrary but the issue outlined above doesn’t appear to be sensitive to the chosen values

  3. I’m relatively new to Stan (learning while trying to get this model working) and the code above is almost certainly inefficient in places. This model doesn’t take terribly long to run but any tips on ways to increase efficiency are also appreciated.


Code for simulating data and fitting the model:

library(rstan)
library(expm)
library(hawkes)

set.seed(326)

### Set parameters
n.ind <- 100
n.occ <- 100
h <- c(0.005, 0.01)
g <- 6
k <- 8
mu <- c(0.5, 0.25)
alpha <- c(0.3, 0.3)
beta <- c(0.6, 0.6)

## Transition probabilities
occ <- seq(from = 0, to = 10, length.out = n.occ)

nu <- k * occ ^ (k - 1)/g ^ k
plot((1 - exp(-nu)) ~ occ, type = "l")

## IG matrix
Q <- array(0, dim = c(3, 3, n.occ))

Q[1, 2, ] <- nu
Q[1, 3, ] <- h[1]
Q[1, 1, ] <- -apply(Q[1, 2:3,], 2, sum)

Q[2, 2, ] <- -h[2]
Q[2, 3, ] <- h[2]

## Simulate true states
s <- matrix(NA, nrow = n.ind, ncol = n.occ)
s[,1] <- 1

for(i in 1:n.ind){
  for(t in 2:n.occ){
    s[i, t] <- which(rmultinom(1, 1, prob = expm::expm(Q[,,t])[s[i, t - 1],]) == 1)
  }
}

s[s == 3] <- 0

## Simulate detections
dets <- states <- vector(mode = "list", length = n.ind)

l1 <- apply(s, 1, function(x) sum(x == 1))
l2 <- apply(s, 1, function(x) sum(x == 2))

for(i in 1:n.ind){
  dets[[i]] <- simulateHawkes(lambda0 = mu[1], alpha = alpha[1], beta = beta[1], horizon = l1[i])[[1]]
  states[[i]] <- rep(1, length(dets[[i]]))

  dets2 <- simulateHawkes(lambda0 = mu[2], alpha = alpha[2], beta = beta[2], horizon = l2[i])[[1]]
  dets2 <- dets2 + max(which(s[i,] == 1))
  dets[[i]] <- c(dets[[i]], dets2)
  states[[i]] <- c(states[[i]], rep(2, length(dets2)))
}

N <- unlist(lapply(dets, length))

det <- state  <- matrix(0, nrow = n.ind, ncol = max(N))
for(i in 1:n.ind){
  if(N[i] > 0){
    det[i, 1:N[i]] <- dets[[i]] 
    state[i, 1:N[i]] <- states[[i]]
  }
}

N_1 <- apply(state, 1, function(x) sum(x == 1))
N_2 <- apply(state, 1, function(x) sum(x == 2))


d <- ceiling(det) # Occasion of detection

## Time between detections
delta <- matrix(0, nrow = n.ind, ncol = max(N) + 1)
delta[,1] <- det[,1]

for(i in 1:n.ind){
  if(N[i] > 1){
    for(j in 2:N[i]){
      delta[i, j] <- det[i, j] - det[i, j - 1]
    }
  }
  delta[i, N[i] + 1] <- n.occ - max(det[i,])
}

# Number of occasions between detections; upper/lower intervals
delta_occ <- delta_lower <-  delta_upper <- matrix(0, nrow = n.ind, ncol = max(N))

for(i in 1:n.ind){
  if(N[i] > 0){
    delta_occ[i, 1] <- d[i, 1] - 1
    delta_lower[i, 1] <- det[i, 1] - d[i, 1] + 1
    delta_upper[i, 1] <- d[i, 1] - det[i, 1]
    if(N[i] > 1){
      for(j in 2:N[i]){
        delta_occ[i, j] <- d[i, j] - d[i, j - 1]
        delta_lower[i, j] <- det[i, j] - d[i, j] + 1
        delta_upper[i, j] <- d[i, j] - det[i, j]
      } 
    }
  }
}

## Data
stan_data <- list(n_ind = n.ind, 
                  n_occasions = n.occ, 
                  occ = occ,
                  d = d,
                  delta = delta, 
                  delta_occ = delta_occ,
                  delta_lower = delta_lower,
                  delta_upper = delta_upper,
                  N = N,
                  N_plus_1 = N + 1,
                  N_1 = N_1,
                  N_2 = N_2,
                  max_N = max(N),
                  max_N_plus_1 = max(N) + 1,
                  f = matrix(c(1, rep(0, 2)), 1, 3),
                  ones = matrix(rep(1, 3), 3, 1))

initf1 <- function() {
  list(h = runif(2, 0.001, 0.005),
       g = runif(1, 9, 11),
       k = runif(1, 9, 11),
       mu = runif(2, 0.2, 0.4),
       alpha = runif(2, 0.2, 0.4), 
       beta = runif(2, 0.4, 0.6))
}


fit <- rstan::stan(file = "hawkes_2_states_eg.stan",
                   data = stan_data,
                   pars = c("h", "g", "k", "mu", "alpha", "beta"),
                   init = initf1,
                   warmup = 250,
                   iter = 500, # May need to be set higher to achieve adequate sample 
                   chains = 4,
                   cores = 4,
                   thin = 1)
print(fit, digits_summary = 4)