Handling Permutation-Invariance in HMMs

Hi,

I am currently experimenting with HMMs in Stan. Following the relevant section in Stan’s manual and this tutorial, I put together a model that seems to work fairly well (I’m not using the new HMM functions in Stan because they don’t seem to support time-varying transition probabilities).

In my model, observations are distributed according to a (Bernoulli) logistic regression with state-dependent weights (full model below):

bernoulli_logit_lpmf(y[n][t] | x[n][t] * betas[k]);

However, I run into the problem that the states do not agree across multiple chains. Even when individual chains seem to converge, they do not necessarily converge to a particular order of states, which is obviously problematic when trying to extract information on parameters, since the posteriors appear to be multimodal. I have read the case study on permutation invariance in mixture models, and the preferred approach seems to be to enforce ordering of the state parameters by using an “ordered” type.

In the multidimensional case (multiple weights per state), things seem a bit more complicated. I tried to enforce the order for only one (the first) of the state-dependent parameters, but that seems to break the whole model (Stan complains about divergent transitions after warmup, and when running on simulated data the posteriors don’t recover the true parameters):

parameters {
    ...
    ordered[K] betas_1;
    vector[M-1] betas[K]; // per-state regression coefficients
}

transformed parameters {
    ...
    vector[M] tbetas[K];
    // evidently, I haven't figured out Stan's indexing syntax yet
    for (k in 1:K) {
      tbetas[k, 1] = betas_1[k];
      for (m in 2:M) {
        tbetas[k, m] = betas[k, m-1];
      }
    }
    ...
}
2: There were 31 divergent transitions after warmup. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them. 
3: Examine the pairs() plot to diagnose sampling problems
 
4: The largest R-hat is 1.73, indicating chains have not mixed.

I’m a bit reluctant to use non-exchangeable priors, partly because in the end I want to run the model on a fairly large dataset (~50k+ observations), so the priors would have to be fairly strong.

Is there any other way of addressing this? I thought about inferring the most likely sequences of the latent states (Viterbi) and then matching the states across chains. Has anyone tried something like this? Would this even be possible within Stan code (i.e., can I access results from multiple chains in a Stan program)?

Edit: Matching the most likely state sequences seems to be relatively robust. However, since I can’t do this until sampling is complete, the stanfit object (I use rstan) is not very useful for running all the fancy diagnostic tools available.
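
For reference, this is roughly the kind of generated quantities block I have in mind for computing the per-draw most likely sequences, following the Viterbi pattern from the Stan user’s guide (a sketch against the model below; z_star is just a placeholder name and I haven’t verified this exact block). The matching of labels across chains then has to happen after sampling, e.g. in R.

generated quantities {
  // z_star: most likely (Viterbi) state sequence per observation sequence
  int z_star[N, T];

  for (n in 1:N) {
    real best_logp[T, K];  // log prob of the best path ending in state k at time t
    int back_ptr[T, K];    // state at t-1 on that best path

    for (k in 1:K)
      best_logp[1, k] = log(pi[n, k])
                        + bernoulli_logit_lpmf(y[n][1] | x[n][1] * betas[k]);

    for (t in 2:T) {
      for (k in 1:K) {
        best_logp[t, k] = negative_infinity();
        for (j in 1:K) {
          real logp = best_logp[t - 1, j] + log(A[j, k])
                      + bernoulli_logit_lpmf(y[n][t] | x[n][t] * betas[k]);
          if (logp > best_logp[t, k]) {
            best_logp[t, k] = logp;
            back_ptr[t, k] = j;
          }
        }
      }
    }

    // backtrack from the best final state
    z_star[n, T] = 1;
    for (k in 2:K)
      if (best_logp[T, k] > best_logp[T, z_star[n, T]])
        z_star[n, T] = k;
    for (t in 1:(T - 1))
      z_star[n, T - t] = back_ptr[T - t + 1, z_star[n, T - t + 1]];
  }
}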

Here is my full model:

functions {
  // forward algorithm
  vector[] forward(int K, int T, vector logpi, vector[] logA, vector[] loglike) {
    vector[T] logalpha[K];
    real acc[K];

    for(j in 1:K) {
      logalpha[j, 1] = logpi[j] + loglike[j, 1];
    }

    for (t in 2:T) {
      for (j in 1:K) {
        for (i in 1:K) {
          acc[i] = logalpha[i, t-1] + logA[i, j] + loglike[j, t];
        }
        logalpha[j, t] = log_sum_exp(acc);
      }
    }
    return logalpha;
  }
}

data {
    int<lower = 1> T;   // number of observations
    int<lower = 1> K;   // number of hidden states
    int<lower = 1> R;   // dimensionality of observations
    int<lower = 1> M;   // number of predictors for observations
    int<lower = 1> N;   // number of sequences

    matrix[T, M] x[N];  // predictors for observations
    int y[N, T, R];     // observations
}

parameters {
    simplex[K] pi[N];   // initial state probabilities (one per sequence necessary?)
    simplex[K] A[K];    // transition probabilities
    vector[M] betas[K]; // per-state regression coefficients
}

transformed parameters {
    vector[T] logalpha[N, K];

    // for each observation sequence
    for(n in 1:N) {
        vector[T] loglike[K];
        vector[K] logpi;
        vector[K] logA[K];

        logpi = log(pi[n]);
        logA = log(A);

        for (t in 1:T) {
            // essentially doing logistic regression for each state
            for (k in 1:K) {
                loglike[k][t] = bernoulli_logit_lpmf(y[n][t] | x[n][t] * betas[k]);
            }
        }

        // calculate forward values
        logalpha[n] = forward(K, T, logpi, logA, loglike);
    }
}

model {
    for (k in 1:K) {
        // Dirichlet priors for transition probabilities
        A[k, ] ~ dirichlet(rep_vector(1, K));
        // Gaussian priors for weights
        betas[k] ~ normal(0, 2);
    }

    for(n in 1:N) {
      target += log_sum_exp(logalpha[n][, T]); 
    }
}

Best,
Marc

This is great! Do you have the full code that includes the generated quantities, along with data (or simulated data)?

I’ll mess around and see if I can get the chains identified

Sure! I’m still struggling with understanding the different ways of generating predictive checks (Viterbi vs. FFBS vs. just sampling), but I certainly appreciate any help with the label-switching problem.

I will post the code and some sample+real data later today!

To be completely unhelpful, I don’t think we have a strong enough understanding of the degeneracies inherent to Hidden Markov Models quite yet. Label switching is just one problem out of many, and the others are often far more insidious and problematic once the label switching has been eliminated.

One of the problems with trying to match to Viterbi is that it makes it very easy to overfit when there are more degeneracies than pure label switching, taking one chain that seems “close enough” and ignoring many other modeling configurations that are also consistent with the data. In other words trying to filter out chains by how similar they are to the Viterbi sequence works only when the posterior concentrates around a small neighborhood of sequences around Viterbi, and if that were true then you wouldn’t be seeing such chain-by-chain variation.

With that said, when trying to eliminate label switching one has to keep in mind that label switching in HMMs happens at the observational level. If the component observational models for each state are identical, then there’s nothing that can distinguish between the latent states; it’s a fundamental flaw with the experimental design. Consequently the key is to differentiate those observational processes somehow, for example with priors on the auxiliary parameters of those observational models that don’t overlap, so that each observational model captures a unique set of behaviors.
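
For example, in your first model something along these lines (an untested sketch with made-up prior locations that would have to come from actual domain knowledge) would give each state’s observational model its own region of parameter space:

model {
  // sketch for K = 2 and M > 1: give one coefficient per state a prior
  // centred in a different region (the locations -2 and 2 are placeholders)
  betas[1, 1] ~ normal(-2, 0.5);
  betas[2, 1] ~ normal(2, 0.5);

  // exchangeable priors for the remaining coefficients
  for (k in 1:K)
    betas[k, 2:M] ~ normal(0, 2);
}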

If the transitions are rigid enough then even occasional measurements that can distinguish between the states can be enough to identify everything at intermediate times, although a time-dependent transition matrix will make this tough.

Thanks Michael,

While this is a bit daunting, it really is a learning experience for me, so I appreciate your input!

For my dataset, there seems to be one “dominant” state and one (or two) much less prevalent states. From my limited understanding, the states seem to be sufficiently different to actually produce good (identifiable) results, although I fear that this won’t be the case anymore once I introduce time-varying transition matrices.

In any case, reading about and dealing with these things is extremely helpful in getting a better understanding of what is actually going on!

One of the problems with trying to match to Viterbi is that it makes it very easy to overfit when there are more degeneracies than pure label switching, taking one chain that seems “close enough” and ignoring many other modeling configurations that are also consistent with the data. In other words trying to filter out chains by how similar they are to the Viterbi sequence works only when the posterior concentrates around a small neighborhood of sequences around Viterbi, and if that were true then you wouldn’t be seeing such chain-by-chain variation.

I don’t have a plot at hand, but at least for the two-state case, the model seems “identifiable” (not sure if that’s the right word to use here) in that the chains end up in the same neighbourhoods (just mirrored) and only between-chain label switching appears to be an issue. Assuming that I use some sort of crude “label de-switching” (like Viterbi matching), shouldn’t at least Rhat catch the problems you’ve described?

With that said, when trying to eliminate label switching one has to keep in mind that label switching in HMMs happens at the observational level. If the component observational models for each state are identical, then there’s nothing that can distinguish between the latent states; it’s a fundamental flaw with the experimental design.

If I’m not mistaken, this is the real problem with identifiability, right? Label switching is just annoying but mostly harmless, whereas extensive overlap in state parameters will make it impossible to fit the model.

If the transitions are rigid enough

What does “rigid” mean in this context?

@spinkney: This is the model I’m currently using.

And here is some sample data:
sample_data.csv (195.7 KB)

functions {
  // forward algorithm
  vector[] forward(matrix log_b, matrix[] A, vector pi) {
   
    int K = dims(log_b)[1];
    int T = dims(log_b)[2];

    matrix[K,K] log_A[T];
    vector[K] log_pi;
    
    for (i in 1:K) {
      log_pi[i] = log(pi[i]);

      for (j in 1:K) {
        for (t in 1:T) {
          log_A[t][i,j] = log(A[t][i,j]);
        }
      }
    } 

    vector[T] log_alpha[K];
    real acc[K];

    for(j in 1:K) {
      log_alpha[j, 1] = log_pi[j] + log_b[j, 1];
    }

    for (t in 2:T) {
      for (j in 1:K) {
        for (i in 1:K) {
          acc[i] = log_alpha[i, t-1] + log_A[t][i, j] + log_b[j, t];
        }
        log_alpha[j, t] = log_sum_exp(acc);
      }
    }

    return log_alpha;
  }
}

data {
    // sequences
    int<lower = 1> N;   // number of sequences
    int T[N];           // length of each sequence

    int<lower = 1> K;   // number of hidden states

    // observation model
    int<lower = 0> I;   // number of (flattened/total) observations
    int<lower = 1> R;   // dimensionality of observations (NOT USED atm) 
    int<lower = 1> M;   // number of predictors for observations

    int y[I];        // observations
    matrix[I, M] x;     // predictors for observations
}

parameters {
    simplex[K] pi;   // initial state probabilities (one per sequence necessary?)
    vector[M] betas_x[K]; // per-state regression coefficients for observation model

    simplex[K] A[K];  // per-state transition probabilities (time-invariant here)
}

transformed parameters {

  vector[N] log_like_sess;

  {
    int pos = 1;
    for(n in 1:N) {
      // compute forward values
      int T_ = T[n];
      matrix[T_, M] x_ = block(x, pos, 1, T_, M);

      matrix[K, K] A_[T_];

      int y_[T_] = y[pos:pos+T_-1];
      matrix[K, T_] logb_;

      for (t in 1:T_) {
        for (i in 1:K) {
          A_[t][i,] = to_row_vector(A[i,]);
          logb_[i][t] = bernoulli_logit_lpmf(y_[t] | x_[t] * betas_x[i]);
        }
      }

      // this should also work
      //log_like_sess[n] = hmm_marginal(logb_, A_[1], pi);
      log_like_sess[n] = log_sum_exp( forward(logb_, A_, pi)[, T_] );

      pos = pos + T_;
    }
  }

}

model {
    for (k in 1:K) {
        A[k, ] ~ dirichlet(rep_vector(1, K));
        // Gaussian priors for weights
        betas_x[k] ~ normal(0, 2);
    }
    
    for(n in 1:N) {
      target += log_like_sess[n];
    }
}

generated quantities {

  int zpred[I];
  int ypred[I];

   {
    int pos = 1;
    for(n in 1:N) {

      int T_ = T[n];
      matrix[T_, M] x_ = block(x, pos, 1, T_, M);
      matrix[K, K] A_[T_];

      int y_[T_] = y[pos:pos+T_-1];
      matrix[K, T_] logb_;

      for (t in 1:T_) {
        for (i in 1:K) {
          A_[t][i,] = to_row_vector(A[i,]);
          logb_[i][t] = bernoulli_logit_lpmf(y_[t] | x_[t] * betas_x[i]);
        }
      }

      zpred[pos:pos+T_-1]  = hmm_latent_rng(logb_, A_[1], pi);

      for (t in 1:T_) {
        // sample from observation posterior
        ypred[pos+t-1] = bernoulli_logit_rng(x_[t] * betas_x[zpred[pos+t-1]]);
      }
      pos = pos + T_;
    }
  }

}

… and some R code:

library(tidyverse)
library(abind)
library(cmdstanr)

# load data from csv
df <- read.csv("sample_data.csv")


x <- df %>% group_by(session) %>%
  group_split() %>% 
  lapply(select, "stimulus", "bias") %>%
  lapply(as.matrix) %>%
  abind(along=1)

y <- df %>% group_by(session) %>%
  mutate(choice = 1 - choice) %>%
  group_split() %>% 
  lapply(select, "choice") %>%
  lapply(as.matrix) %>%
  abind(along=1)

T <- df %>% group_by(session) %>%
  count() %>%
  ungroup() %>%
  select(n) %>%
  as.matrix() %>%
  as.vector()


# input data for model
data <- list(
  x = x,
  y = drop(y),
  T = T,
  K = 3,
  R = 1,
  M = 4,
  N = length(T),
  I = sum(T)
)

# the model
model <- cmdstan_model("./stan-models/glm-hmm.stan")

# fit model
fit <- model$sample(
  data = data,            # named list of data
  chains = 1,             # number of Markov chains
  refresh = 5,             # print progress every 5 iterations
  iter_warmup = 1000,
  iter_sampling = 1000
  )

The complication is that the modes arising from the label-switching/permutation symmetry might not be the only modes, just the most prominent ones. A not uncommon occurrence is that only once the label switching has been resolved will your Markov chains start to settle into those less prominent modes, where you can actually see them (I have chains struggling with this extremely frustrating problem at this very second).

Label switching is an identifiability problem, and it can only happen if at least some of the observational models are the same. Remember that here we’re looking at label switching in the posterior conditioned on the observed data, which means that permuting the labels doesn’t change the likelihood function at all. If each state had very different observational consequences, then swapping the labels, which would swap which component observational model is associated with each component observation, would lead to a very different likelihood function and your posterior would no longer be permutation invariant.

It may help to note that identifiability is a property of the observational model, and non-identifiable models will be problematic no matter how many observations you have. In practice we often care more about the complex uncertainties that come with smaller data sets, which can arise for both identifiable and non-identifiable models. I discuss this in Identity Crisis.

By rigid I mean diagonally dominant, so that the states are less likely to transition. The more rigid the transition matrix, the longer a sequence of neighboring observations informs the same set of states, which usually leads to better-behaved inferences.
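
For example (an untested sketch with placeholder concentration values), you could nudge each row of the transition matrix towards self-transitions by replacing the symmetric Dirichlet prior with one that puts more mass on the diagonal:

model {
  // "sticky" prior: each row of the transition matrix favours self-transitions
  // (the concentration values 10 and 1 are arbitrary placeholders)
  for (k in 1:K) {
    vector[K] alpha = rep_vector(1, K);
    alpha[k] = 10;
    A[k] ~ dirichlet(alpha);
  }
}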
