Hi,
I am currently experimenting with HMMs in Stan. Following the relevant section in Stan’s manual and this tutorial, I put together a model that seems to work fairly well (I’m not using the new HMM functions in Stan because they don’t seem to support time-varying transition probabilities).
In my model, observations are distributed according to a (Bernoulli) logistic regression with state-dependent weights (full model below):
bernoulli_logit_lpmf(y[n][t] | x[n][t] * betas[k]);
However, I run into the problem that the state labels do not agree across chains. Even when individual chains seem to converge, they do not necessarily agree on the ordering of the states, which is obviously problematic when trying to extract information about the parameters, since the posteriors appear to be multimodal. I have read the case study on permutation invariance in mixture models, and the preferred approach seems to be to enforce an ordering of the state parameters by using an “ordered” type.
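In one dimension that seems straightforward, something like this minimal sketch for a state-dependent intercept (mu is just a placeholder name, not a parameter from my model):

parameters {
  ordered[K] mu; // state-dependent intercepts, constrained so that mu[1] < ... < mu[K]
}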
In the multidimensional case (multiple weights per state), things seem a bit more complicated. I tried to enforce the ordering for only one (the first) of the state-dependent coefficients, but that seems to break the whole model (Stan complains about divergent transitions after warmup, and when running on simulated data the posteriors don’t recover the true parameters):
parameters {
  ...
  ordered[K] betas_1;
  vector[M-1] betas[K]; // per-state regression coefficients
}
transformed parameters {
  ...
  vector[M] tbetas[K];
  // evidently, I haven't figured out Stan's indexing syntax yet
  for (k in 1:K) {
    tbetas[k, 1] = betas_1[k];
    for (m in 2:M) {
      tbetas[k, m] = betas[k, m-1];
    }
  }
  ...
}
2: There were 31 divergent transitions after warmup. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.
3: Examine the pairs() plot to diagnose sampling problems
4: The largest R-hat is 1.73, indicating chains have not mixed.
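(As an aside, I think the same tbetas construction could be written more compactly with append_row; I wouldn’t expect that to change the sampling behaviour, but here is the sketch for completeness:)

transformed parameters {
  ...
  vector[M] tbetas[K];
  for (k in 1:K) {
    // prepend the ordered first coefficient to the remaining M-1 coefficients
    tbetas[k] = append_row(betas_1[k], betas[k]);
  }
  ...
}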
I’m a bit reluctant to use non-exchangeable priors, partly because in the end I want to run the model on a fairly large dataset (~50k+ observations), so the priors would have to be fairly strong.
Is there any other way of addressing this? I thought about inferring the most likely sequences of the latent states (Viterbi) and then matching the states across chains. Has anyone tried something like this? Would this even be possible within Stan code (i.e., can I access results from multiple chains in a Stan program)?
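For reference, the decoder I have in mind would mirror the forward() function in my model below; here is a rough sketch (zstar, logdelta and backptr are just placeholder names). It could be called from generated quantities, recomputing logpi, logA and the per-state log-likelihoods there for each sequence:

functions {
  // most likely state sequence for one observation sequence (Viterbi)
  int[] viterbi(int K, int T, vector logpi, vector[] logA, vector[] loglike) {
    int zstar[T];          // most likely state sequence
    real logdelta[K, T];   // best log prob of any path ending in state k at time t
    int backptr[K, T];     // previous state on that best path
    for (k in 1:K) {
      logdelta[k, 1] = logpi[k] + loglike[k, 1];
      backptr[k, 1] = 0;   // never used
    }
    for (t in 2:T) {
      for (j in 1:K) {
        logdelta[j, t] = negative_infinity();
        for (i in 1:K) {
          real lp = logdelta[i, t-1] + logA[i, j] + loglike[j, t];
          if (lp > logdelta[j, t]) {
            logdelta[j, t] = lp;
            backptr[j, t] = i;
          }
        }
      }
    }
    // backtrack from the most likely final state
    {
      real best = negative_infinity();
      for (k in 1:K) {
        if (logdelta[k, T] > best) {
          best = logdelta[k, T];
          zstar[T] = k;
        }
      }
    }
    for (t in 1:(T-1)) {
      zstar[T-t] = backptr[zstar[T-t+1], T-t+1];
    }
    return zstar;
  }
}

The matching of state labels across chains would still have to happen outside of Stan, of course, since each chain only ever sees its own draws.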
Edit: Matching the most likely state sequences across chains seems to work relatively robustly. However, since I can only do this after sampling is complete, the stanfit object (I use rstan) is not very useful for running all the fancy diagnostic tools available.
Here is my full model:
functions {
  // forward algorithm
  vector[] forward(int K, int T, vector logpi, vector[] logA, vector[] loglike) {
    vector[T] logalpha[K];
    real acc[K];
    for (j in 1:K) {
      logalpha[j, 1] = logpi[j] + loglike[j, 1];
    }
    for (t in 2:T) {
      for (j in 1:K) {
        for (i in 1:K) {
          acc[i] = logalpha[i, t-1] + logA[i, j] + loglike[j, t];
        }
        logalpha[j, t] = log_sum_exp(acc);
      }
    }
    return logalpha;
  }
}
data {
  int<lower = 1> T;   // number of time points per sequence
  int<lower = 1> K;   // number of hidden states
  int<lower = 1> R;   // dimensionality of observations
  int<lower = 1> M;   // number of predictors for observations
  int<lower = 1> N;   // number of sequences
  matrix[T, M] x[N];  // predictors for observations
  int y[N, T, R];     // observations
}
parameters {
  simplex[K] pi[N];   // initial state probabilities (one per sequence necessary?)
  simplex[K] A[K];    // transition probabilities
  vector[M] betas[K]; // per-state regression coefficients
}
transformed parameters {
  vector[T] logalpha[N, K];
  // for each observation sequence
  for (n in 1:N) {
    vector[T] loglike[K];
    vector[K] logpi;
    vector[K] logA[K];
    logpi = log(pi[n]);
    logA = log(A);
    for (t in 1:T) {
      // essentially doing logistic regression for each state
      for (k in 1:K) {
        loglike[k][t] = bernoulli_logit_lpmf(y[n][t] | x[n][t] * betas[k]);
      }
    }
    // calculate forward values
    logalpha[n] = forward(K, T, logpi, logA, loglike);
  }
}
model {
  for (k in 1:K) {
    // Dirichlet priors for transition probabilities
    A[k, ] ~ dirichlet(rep_vector(1, K));
    // Gaussian priors for weights
    betas[k] ~ normal(0, 2);
  }
  for (n in 1:N) {
    target += log_sum_exp(logalpha[n][, T]);
  }
}
Best,
Marc