I would like to find out whether there are examples of the application of HMMs with multiple continuous observations in Stan, or even a case study of the new hmm functions in Stan.
Thanks,
Reg
Here’s one for animal movement models: https://arxiv.org/abs/1806.10639
Thank you very much.
Hi,
Back to this question: I was able to make some progress. Right now, I am not sure if my implementation of the hmm_marginal
function is correct, because my results are all over the place.
Below is my Stan code:
data {
  int T; // number of obs
  int NY; // number of observed variables
  int K; // number of latent classes
  matrix[T, NY] Y; // observed Y
}
parameters {
  simplex[K] pi1; // initial state probs
  simplex[K] A[K]; // state transition probs
  // continuous observation model
  positive_ordered[K] mu[NY]; // NY by K states underlying means
  vector<lower = 0>[NY] sigma[K]; // K by NY underlying SDs
  cholesky_factor_corr[NY] Y_ch[K]; // per-state correlation Cholesky factors
}
transformed parameters {
  matrix[NY, K] mu_ord;
  matrix[K, K] tpm;
  cholesky_factor_cov[NY] sig_ch[K];
  for (y in 1:NY) {
    mu_ord[y, ] = mu[y]';
  }
  for (k in 1:K) {
    tpm[k] = A[k]';
    sig_ch[k] = diag_pre_multiply(sigma[k], Y_ch[k]);
  }
}
model {
  // not sure about this
  matrix[T, K] Ys;
  for (y in 1:NY) {
    mu[y] ~ gamma(0.1, 0.01);
    sigma[, y] ~ std_normal();
  }
  for (k in 1:K) {
    Y_ch[k] ~ lkj_corr_cholesky(1);
    A[k] ~ dirichlet(rep_vector(1., K));
  }
  pi1 ~ dirichlet(rep_vector(1., K));
  for (t in 1:T) {
    for (k in 1:K) {
      Ys[t, k] = multi_normal_cholesky_lpdf(Y[t, ] | mu_ord[, k], sig_ch[k]);
    }
  }
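  // hmm_marginal expects a K x T matrix of observation log densities,
  // hence the transpose of Ys (T x K) below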
  target += hmm_marginal(Ys', tpm, pi1);
}
generated quantities {
  int<lower = 1, upper = K> latent_state[T];
  matrix[T, K] Ys;
  for (t in 1:T) {
    for (k in 1:K) {
      Ys[t, k] = multi_normal_cholesky_lpdf(Y[t, ] | mu_ord[, k], sig_ch[k]);
    }
  }
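  // hmm_latent_rng draws one posterior realization of the hidden state path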
  latent_state = hmm_latent_rng(Ys', tpm, pi1);
}
Also, I have included my data simulation code here
runif_simplex <- function(X) {
  pb <- -log(runif(X, min = 0, max = 1))
  pb / sum(pb)
}
mean_fn <- function(K, N_Y) {
  Y_mu <- matrix(nrow = K, ncol = N_Y)
  for (k in 1:K) {
    Y_mu[k, ] <- 10 * abs(runif(N_Y, min = k, max = k + 1))
  }
  return(Y_mu)
}
library(mvtnorm)
hmm_sim <- function(K, NY, NT) {
  # K  = number of classes
  # NY = number of cols for MVN data
  # NT = number of time points
  # Parameters
  pi1 <- runif_simplex(X = K) # initial probability
  A <- t(replicate(K, runif_simplex(X = K))) # transition probabilities for the discrete classes
  mu <- mean_fn(K = K, N_Y = NY)
  # Hidden path
  z <- vector(mode = "numeric", length = NT)
  z[1] <- sample(1:K, size = 1, prob = pi1)
  for (t in 2:NT) {
    z[t] <- sample(1:K, size = 1, prob = A[z[t - 1], ])
  }
  # Observations
  y <- matrix(nrow = NT, ncol = NY)
  for (t in 1:NT) {
    y[t, ] <- mvtnorm::rmvnorm(1, mean = mu[z[t], ], sigma = diag(1, nrow = NY))
  }
  return(list(Y = y, latent = z,
              theta = list(init_prob = pi1, trns_prob = A, mean = mu)))
}
Back to the question above: even if I assume independence of the observed outcomes given the state, with my Stan model defined as
data {
  int T; // number of obs
  int NY; // number of observed variables
  int K; // number of latent classes
  matrix[T, NY] Y; // observed Y
}
parameters {
  simplex[K] pi1; // initial state probs
  simplex[K] A[K]; // state transition probs, i.e. A[i][j] = p(z_t = j | z_{t-1} = i)
  // continuous observation model
  vector[K] ovr_mu; // hyper-means for the K states
  positive_ordered[K] mu[NY]; // NY by K states underlying means
  vector<lower = 0>[NY] sigma[K]; // K by NY underlying SDs
}
transformed parameters {
  matrix[K, K] tpm;
  matrix[K, T] Ys;
  for (k in 1:K) {
    tpm[k] = A[k]';
  }
  // observation log density per state, assuming conditional independence
  // of the NY components given the state
  for (t in 1:T) {
    for (k in 1:K) {
      Ys[k, t] = 0;
      for (y in 1:NY) {
        Ys[k, t] += normal_lpdf(Y[t, y] | mu[y][k], sigma[k, y]);
      }
    }
  }
}
model {
  for (k in 1:K) {
    mu[, k] ~ normal(ovr_mu[k], 2);
    sigma[k] ~ std_normal();
    A[k] ~ dirichlet(rep_vector(1, K));
  }
  pi1 ~ dirichlet(rep_vector(1, K));
  ovr_mu ~ normal(20, 10);
  target += hmm_marginal(Ys, tpm, pi1);
}
generated quantities {
  int<lower = 1, upper = K> latent_state[T];
  latent_state = hmm_latent_rng(Ys, tpm, pi1);
}
the chains do not converge.
set.seed(1212)
hmm_data <- hmm_sim(K = 4, NY = 3, NT = 400)
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
pi1[1] 0.21 0.17 0.17 0.17 0.01 0.54 1.00 2528 987
pi1[2] 0.19 0.15 0.16 0.15 0.01 0.52 1.00 2503 969
pi1[3] 0.30 0.27 0.21 0.23 0.02 0.70 1.15 19 237
pi1[4] 0.30 0.27 0.21 0.22 0.02 0.71 1.16 17 141
tpm[1,1] 0.20 0.13 0.18 0.12 0.02 0.58 1.08 36 1219
tpm[2,1] 0.20 0.03 0.24 0.04 0.00 0.60 1.84 5 125
tpm[3,1] 0.04 0.01 0.07 0.01 0.00 0.18 1.53 7 33
tpm[4,1] 0.18 0.01 0.31 0.01 0.00 0.76 1.56 7 30
tpm[1,2] 0.28 0.29 0.17 0.19 0.02 0.56 1.06 46 998
tpm[2,2] 0.12 0.09 0.12 0.04 0.04 0.41 1.27 20 34
tpm[3,2] 0.21 0.09 0.25 0.12 0.00 0.68 2.80 4 27
tpm[4,2] 0.25 0.17 0.24 0.22 0.00 0.68 2.55 4 30
tpm[1,3] 0.26 0.27 0.17 0.16 0.02 0.60 1.22 2060 1447
tpm[2,3] 0.38 0.33 0.18 0.12 0.09 0.67 1.59 7 36
tpm[3,3] 0.32 0.32 0.21 0.30 0.05 0.58 2.08 5 31
tpm[4,3] 0.24 0.25 0.14 0.18 0.02 0.43 2.66 4 32
tpm[1,4] 0.26 0.24 0.18 0.15 0.02 0.62 1.22 2789 1068
tpm[2,4] 0.30 0.29 0.24 0.38 0.00 0.67 2.01 5 31
tpm[3,4] 0.42 0.29 0.29 0.10 0.17 0.93 2.37 4 27
tpm[4,4] 0.33 0.34 0.24 0.33 0.02 0.61 2.50 4 33
mu[1,1] 9.63 8.94 5.06 6.15 2.12 16.59 1.61 6 36
mu[2,1] 8.27 8.74 3.54 4.90 1.92 12.30 1.59 6 31
mu[3,1] 8.77 8.96 4.13 5.75 1.70 13.77 1.59 6 31
mu[1,2] 16.73 16.52 4.73 0.39 8.32 23.42 2.39 4 30
mu[2,2] 15.77 12.24 7.78 0.31 8.05 29.11 2.36 4 28
mu[3,2] 15.51 13.71 5.57 0.33 7.89 24.67 2.37 4 31
mu[1,3] 21.67 23.35 2.98 0.15 16.43 23.54 1.58 6 37
mu[2,3] 24.87 29.03 7.29 0.16 12.15 29.22 1.64 6 38
mu[3,3] 21.90 24.59 4.73 0.13 13.64 24.76 1.58 6 36
mu[1,4] 35.15 37.24 7.64 7.93 23.39 44.16 2.84 4 31
mu[2,4] 34.80 35.55 3.72 3.77 29.08 39.40 2.84 4 28
mu[3,4] 33.47 34.92 5.77 5.87 24.63 40.71 2.84 4 30
ovr_mu[1] 9.06 8.95 4.15 5.26 2.36 15.20 1.54 7 38
ovr_mu[2] 16.08 14.33 5.90 2.70 8.22 26.49 2.17 5 31
ovr_mu[3] 22.78 25.13 5.06 1.75 13.30 27.27 1.53 7 30
ovr_mu[4] 34.28 35.18 5.75 6.01 24.65 41.85 2.78 4 30
sigma[1,1] 0.86 0.89 0.55 0.47 0.08 1.92 1.29 470 883
sigma[2,1] 4.03 0.98 5.43 0.18 0.23 13.79 1.94 6 33
sigma[3,1] 4.18 1.02 5.53 0.12 0.89 14.09 1.75 6 30
sigma[4,1] 3.91 2.48 3.48 2.45 0.82 9.73 2.62 4 29
sigma[1,2] 0.86 0.93 0.51 0.44 0.09 1.76 1.25 236 593
sigma[2,2] 2.90 1.03 3.44 0.19 0.26 9.19 1.92 6 30
sigma[3,2] 2.87 1.01 3.27 0.09 0.90 8.80 1.55 7 27
sigma[4,2] 2.52 1.90 1.73 1.33 0.93 5.42 2.57 4 33
sigma[1,3] 0.84 0.89 0.55 0.51 0.08 1.82 1.28 342 865
sigma[2,3] 3.62 0.99 4.69 0.18 0.25 12.10 1.96 6 29
sigma[3,3] 3.60 0.94 4.67 0.11 0.82 12.01 1.75 6 36
sigma[4,3] 3.47 2.55 2.80 2.46 0.84 7.98 2.73 4 30
This is indeed a good strategy in general: simplify the model until you find the smallest one that still has issues. I don't see anything immediately problematic with the model, but I think further simplifications could help pinpoint the problem. What if you treat mu and/or sigma as known (i.e. pass them as data)? Can you make a univariate version of the model work? What if you avoid the hierarchical structure and use mu[,k] ~ something_fixed? See the sketch below.
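To make the univariate suggestion concrete, here is a minimal sketch with sigma passed in as data; the normal(20, 10) prior on mu is simply borrowed from your ovr_mu prior, so treat it as an untested starting point rather than a working model:

data {
  int T; // number of obs
  int K; // number of latent classes
  vector[T] Y; // single observed series
  vector<lower = 0>[K] sigma; // state SDs, fixed and passed as data
}
parameters {
  simplex[K] pi1; // initial state probs
  simplex[K] A[K]; // state transition probs
  positive_ordered[K] mu; // ordered state means
}
model {
  matrix[K, T] Ys; // observation log densities per state
  matrix[K, K] tpm;
  mu ~ normal(20, 10); // borrowed from the ovr_mu prior above
  for (k in 1:K) {
    A[k] ~ dirichlet(rep_vector(1, K));
    tpm[k] = A[k]';
  }
  pi1 ~ dirichlet(rep_vector(1, K));
  for (t in 1:T) {
    for (k in 1:K) {
      Ys[k, t] = normal_lpdf(Y[t] | mu[k], sigma[k]);
    }
  }
  target += hmm_marginal(Ys, tpm, pi1);
}

If that version recovers the simulated states cleanly, you can add the remaining outcomes and the hierarchical mean back one piece at a time.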
You might also want to check Divergent transitions - a primer, where I discuss some other debugging strategies for models.
Best of luck! HMMs are cool, but sometimes challenging.
Thank you very much for the link on addressing divergent transitions. I will start with the univariate version of the model and hopefully I can get it to work.