Multivariate Hidden Markov Model

I would like to find out whether there are examples of applying HMMs with multiple continuous observations in Stan, or even a case study of the new HMM functions in Stan.

Thanks,
Reg

1 Like

Here’s one for animal movement models: https://arxiv.org/abs/1806.10639

Thank you very much.

Hi,
Returning to this question: I was able to make some progress. Right now, I am not sure whether my use of the hmm_marginal function is the best approach, because my results are all over the place.
Below is my Stan code:

// HMM with K latent states and NY correlated continuous outcomes per time
// point; each state emits a multivariate normal with its own mean vector and
// covariance (Cholesky-parameterized).
// NOTE(review): uses the pre-2.26 array syntax (e.g. `simplex[K] A[K];`),
// which newer Stan releases no longer accept; port to `array[K] simplex[K] A;`
// when moving to a current compiler.
data {
  int<lower = 1> T;  // number of time points (observations)
  int<lower = 1> NY; // number of observed variables per time point
  int<lower = 1> K;  // number of latent states
  matrix[T, NY] Y;   // observed outcomes, one row per time point
}

parameters {
  simplex[K] pi1;  // initial state probabilities
  simplex[K] A[K]; // A[i][j] = P(z_t = j | z_{t-1} = i)

  // continuous observation model
  positive_ordered[K] mu[NY];       // mu[y][k]: mean of variable y in state k (ordered over k)
  vector<lower = 0>[NY] sigma[K];   // sigma[k][y]: SD of variable y in state k
  cholesky_factor_corr[NY] Y_ch[K]; // per-state correlation Cholesky factors
}

transformed parameters {
  matrix[NY, K] mu_ord;              // column k is the mean vector of state k
  matrix[K, K] tpm;                  // transition matrix, row i = A[i]
  cholesky_factor_cov[NY] sig_ch[K]; // per-state covariance Cholesky factors
  matrix[T, K] Ys;                   // Ys[t, k] = log p(Y[t] | z_t = k)

  for (y in 1:NY) {
    mu_ord[y, ] = mu[y]';
  }

  for (k in 1:K) {
    tpm[k] = A[k]';
    sig_ch[k] = diag_pre_multiply(sigma[k, ], Y_ch[k]);
  }

  // The emission log-densities feed both the likelihood and the latent-state
  // draws, so they are computed once here instead of in both blocks.
  for (t in 1:T) {
    for (k in 1:K) {
      Ys[t, k] = multi_normal_cholesky_lpdf(Y[t, ] | mu_ord[, k], sig_ch[k]);
    }
  }
}

model {
  for (y in 1:NY) {
    // NOTE(review): gamma shape 0.1 < 1 puts unbounded density at 0; a prior
    // on the scale of the data may sample more easily.
    mu[y] ~ gamma(0.1, 0.01);
    sigma[, y] ~ std_normal(); // half-normal because of the lower = 0 bound
  }

  for (k in 1:K) {
    Y_ch[k] ~ lkj_corr_cholesky(1);
    A[k] ~ dirichlet(rep_vector(1., K));
  }

  pi1 ~ dirichlet(rep_vector(1., K));

  // hmm_marginal expects a K x T matrix of log emission densities.
  target += hmm_marginal(Ys', tpm, pi1);
}

generated quantities {
  int<lower = 1, upper = K> latent_state[T];

  latent_state = hmm_latent_rng(Ys', tpm, pi1);
}

I have also included my data simulation code here:

# Draw one point uniformly from the probability simplex with X components.
#
# Normalizing i.i.d. Exp(1) draws (generated as -log(U), U ~ Unif(0, 1))
# yields a Dirichlet(1, ..., 1) sample, i.e. a uniform draw on the simplex.
#
# @param X Number of components of the returned vector (scalar, >= 1).
# @return Numeric vector of length X with positive entries summing to 1.
runif_simplex <- function(X) {
  stopifnot(length(X) == 1L, X >= 1)
  pb <- -log(runif(X, min = 0, max = 1))
  # Return the value visibly; the previous version ended on an assignment,
  # so a console call printed nothing.
  pb / sum(pb)
}

# Simulate state-specific emission means for the HMM.
#
# Row k holds the N_Y means for latent state k, each drawn uniformly from
# [scale * k, scale * (k + 1)], so states are separated and increase with k
# in every column.
#
# @param K Number of latent states (rows), >= 0.
# @param N_Y Number of observed variables (columns).
# @param scale Multiplier applied to the U(k, k + 1) draws; the default of 10
#   reproduces the original behaviour.
# @return A K x N_Y numeric matrix of means.
mean_fn <- function(K, N_Y, scale = 10) {
  # Fill row by row (byrow = TRUE) so random numbers are consumed in the same
  # order as the former per-row loop; seeded simulations stay reproducible.
  # k + U(0, 1) equals runif(min = k, max = k + 1). seq_len() makes K = 0
  # return an empty matrix instead of failing on row 1.
  draws <- matrix(runif(K * N_Y), nrow = K, ncol = N_Y, byrow = TRUE)
  # Adding a length-K vector to a K-row matrix adds k to every entry of row k.
  scale * (draws + seq_len(K))
}

library(mvtnorm)

# Simulate a K-state hidden Markov model with NY-dimensional Gaussian
# emissions (identity covariance in every state).
#
# @param K Number of latent states, >= 1.
# @param NY Number of observed variables per time point, >= 1.
# @param NT Number of time points, >= 1.
# @return A list with
#   Y      NT x NY matrix of observations,
#   latent numeric vector of length NT with latent states in 1..K,
#   theta  list of generating parameters: init_prob (length K),
#          trns_prob (K x K, row i = P(z_t = . | z_{t-1} = i)) and
#          mean (K x NY, row k = means of state k).
hmm_sim <- function(K, NY, NT) {
  stopifnot(K >= 1, NY >= 1, NT >= 1)

  # Parameters
  pi1 <- runif_simplex(X = K) # initial state probabilities
  # replicate() stores one simplex per column; transpose so each row is a
  # transition distribution.
  A <- t(replicate(K, runif_simplex(X = K)))
  mu <- mean_fn(K = K, N_Y = NY)

  # Hidden path. sample.int(K, ...) draws exactly what sample(1:K, ...) did.
  z <- numeric(NT)
  z[1] <- sample.int(K, size = 1, prob = pi1)
  # seq_len(NT)[-1] is empty when NT == 1; the former 2:NT became c(2, 1)
  # there and indexed z[0].
  for (t in seq_len(NT)[-1]) {
    z[t] <- sample.int(K, size = 1, prob = A[z[t - 1], ])
  }

  # Observations
  sigma <- diag(1, nrow = NY) # loop-invariant emission covariance
  y <- matrix(nrow = NT, ncol = NY)
  for (t in seq_len(NT)) {
    y[t, ] <- mvtnorm::rmvnorm(1, mean = mu[z[t], ], sigma = sigma)
  }

  list(Y = y, latent = z,
       theta = list(init_prob = pi1, trns_prob = A, mean = mu))
}

1 Like

Returning to the question above: even if I assume the observed outcomes are conditionally independent given the state, with my Stan model defined as

// HMM with K latent states and NY continuous outcomes that are conditionally
// independent normals given the state; state means get a hierarchical prior.
// NOTE(review): uses the pre-2.26 array syntax (e.g. `simplex[K] A[K];`),
// which newer Stan releases no longer accept; port to `array[K] simplex[K] A;`
// when moving to a current compiler.
data {
  int<lower = 1> T;  // number of time points (observations)
  int<lower = 1> NY; // number of observed variables per time point
  int<lower = 1> K;  // number of latent states
  matrix[T, NY] Y;   // observed outcomes, one row per time point
}

parameters {
  simplex[K] pi1;  // initial state probabilities
  simplex[K] A[K]; // A[i][j] = P(z_t = j | z_{t-1} = i)

  // continuous observation model
  vector[K] ovr_mu;               // ovr_mu[k]: prior location for all means of state k
  positive_ordered[K] mu[NY];     // mu[y][k]: mean of variable y in state k (ordered over k)
  vector<lower = 0>[NY] sigma[K]; // sigma[k][y]: SD of variable y in state k
}

transformed parameters {
  matrix[K, K] tpm; // transition matrix, row i = A[i]
  matrix[K, T] Ys;  // Ys[k, t] = log p(Y[t] | z_t = k), layout expected by hmm_*

  for (k in 1:K) {
    tpm[k] = A[k]';
  }

  // Sum the independent normal log-densities over all NY outcomes (the
  // previous version hard-coded exactly three columns of Y).
  for (k in 1:K) {
    vector[NY] mu_k = to_vector(mu[, k]); // means of state k across outcomes
    for (t in 1:T) {
      Ys[k, t] = normal_lpdf(Y[t, ]' | mu_k, sigma[k]);
    }
  }
}

model {
  for (k in 1:K) {
    mu[, k] ~ normal(ovr_mu[k], 2);
    sigma[k] ~ std_normal(); // half-normal because of the lower = 0 bound
    A[k] ~ dirichlet(rep_vector(1, K));
  }

  pi1 ~ dirichlet(rep_vector(1, K));
  ovr_mu ~ normal(20, 10);

  target += hmm_marginal(Ys, tpm, pi1);
}

generated quantities {
  int<lower = 1, upper = K> latent_state[T];

  latent_state = hmm_latent_rng(Ys, tpm, pi1);
}

the chains do not converge. The simulation call and the resulting posterior summary are below:

# Reproducible simulated data: 400 time points, 3 outcomes, 4 latent states.
set.seed(1212)
hmm_data <- hmm_sim(NT = 400, NY = 3, K = 4)
   variable  mean median   sd  mad    q5   q95 rhat ess_bulk ess_tail
 pi1[1]      0.21   0.17 0.17 0.17  0.01  0.54 1.00     2528      987
 pi1[2]      0.19   0.15 0.16 0.15  0.01  0.52 1.00     2503      969
 pi1[3]      0.30   0.27 0.21 0.23  0.02  0.70 1.15       19      237
 pi1[4]      0.30   0.27 0.21 0.22  0.02  0.71 1.16       17      141
 tpm[1,1]    0.20   0.13 0.18 0.12  0.02  0.58 1.08       36     1219
 tpm[2,1]    0.20   0.03 0.24 0.04  0.00  0.60 1.84        5      125
 tpm[3,1]    0.04   0.01 0.07 0.01  0.00  0.18 1.53        7       33
 tpm[4,1]    0.18   0.01 0.31 0.01  0.00  0.76 1.56        7       30
 tpm[1,2]    0.28   0.29 0.17 0.19  0.02  0.56 1.06       46      998
 tpm[2,2]    0.12   0.09 0.12 0.04  0.04  0.41 1.27       20       34
 tpm[3,2]    0.21   0.09 0.25 0.12  0.00  0.68 2.80        4       27
 tpm[4,2]    0.25   0.17 0.24 0.22  0.00  0.68 2.55        4       30
 tpm[1,3]    0.26   0.27 0.17 0.16  0.02  0.60 1.22     2060     1447
 tpm[2,3]    0.38   0.33 0.18 0.12  0.09  0.67 1.59        7       36
 tpm[3,3]    0.32   0.32 0.21 0.30  0.05  0.58 2.08        5       31
 tpm[4,3]    0.24   0.25 0.14 0.18  0.02  0.43 2.66        4       32
 tpm[1,4]    0.26   0.24 0.18 0.15  0.02  0.62 1.22     2789     1068
 tpm[2,4]    0.30   0.29 0.24 0.38  0.00  0.67 2.01        5       31
 tpm[3,4]    0.42   0.29 0.29 0.10  0.17  0.93 2.37        4       27
 tpm[4,4]    0.33   0.34 0.24 0.33  0.02  0.61 2.50        4       33
 mu[1,1]     9.63   8.94 5.06 6.15  2.12 16.59 1.61        6       36
 mu[2,1]     8.27   8.74 3.54 4.90  1.92 12.30 1.59        6       31
 mu[3,1]     8.77   8.96 4.13 5.75  1.70 13.77 1.59        6       31
 mu[1,2]    16.73  16.52 4.73 0.39  8.32 23.42 2.39        4       30
 mu[2,2]    15.77  12.24 7.78 0.31  8.05 29.11 2.36        4       28
 mu[3,2]    15.51  13.71 5.57 0.33  7.89 24.67 2.37        4       31
 mu[1,3]    21.67  23.35 2.98 0.15 16.43 23.54 1.58        6       37
 mu[2,3]    24.87  29.03 7.29 0.16 12.15 29.22 1.64        6       38
 mu[3,3]    21.90  24.59 4.73 0.13 13.64 24.76 1.58        6       36
 mu[1,4]    35.15  37.24 7.64 7.93 23.39 44.16 2.84        4       31
 mu[2,4]    34.80  35.55 3.72 3.77 29.08 39.40 2.84        4       28
 mu[3,4]    33.47  34.92 5.77 5.87 24.63 40.71 2.84        4       30
 ovr_mu[1]   9.06   8.95 4.15 5.26  2.36 15.20 1.54        7       38
 ovr_mu[2]  16.08  14.33 5.90 2.70  8.22 26.49 2.17        5       31
 ovr_mu[3]  22.78  25.13 5.06 1.75 13.30 27.27 1.53        7       30
 ovr_mu[4]  34.28  35.18 5.75 6.01 24.65 41.85 2.78        4       30
 sigma[1,1]  0.86   0.89 0.55 0.47  0.08  1.92 1.29      470      883
 sigma[2,1]  4.03   0.98 5.43 0.18  0.23 13.79 1.94        6       33
 sigma[3,1]  4.18   1.02 5.53 0.12  0.89 14.09 1.75        6       30
 sigma[4,1]  3.91   2.48 3.48 2.45  0.82  9.73 2.62        4       29
 sigma[1,2]  0.86   0.93 0.51 0.44  0.09  1.76 1.25      236      593
 sigma[2,2]  2.90   1.03 3.44 0.19  0.26  9.19 1.92        6       30
 sigma[3,2]  2.87   1.01 3.27 0.09  0.90  8.80 1.55        7       27
 sigma[4,2]  2.52   1.90 1.73 1.33  0.93  5.42 2.57        4       33
 sigma[1,3]  0.84   0.89 0.55 0.51  0.08  1.82 1.28      342      865
 sigma[2,3]  3.62   0.99 4.69 0.18  0.25 12.10 1.96        6       29
 sigma[3,3]  3.60   0.94 4.67 0.11  0.82 12.01 1.75        6       36
 sigma[4,3]  3.47   2.55 2.80 2.46  0.84  7.98 2.73        4       30
1 Like

This is indeed a good strategy in general: simplify the model to find the smallest version that still has issues. I don't see anything immediately problematic with the model, but I think further simplifications could help pinpoint the problem. What if you treat mu and/or sigma as known (i.e., pass them as data)? Can you make a univariate version of the model work? What if you avoid the hierarchical structure and give mu[,k] a fixed prior (mu[,k] ~ something_fixed)?

You might also want to check "Divergent transitions - a primer", where I discuss some other debugging strategies for models.

Best of luck! HMMs are cool, but sometimes challenging.

Thank you very much for the link on addressing divergent transitions. I will start with the univariate version of the model. Hopefully, I am able to get it to work.