Posterior predictive distribution for HMMs

I am trying to build an HMM-based model using the forward algorithm, and I am able to get it working reasonably well and passing SBC. Now I would like to test its fit against real data, and for that I would like to do posterior predictive checks.

This, however, seems non-trivial: in all the resources here on HMMs, people either use the Viterbi algorithm to compute the most likely sequence of states for each sample, or use the forward-backward algorithm to compute marginal posterior probabilities for each state at each time point. If I understand correctly, neither of these lets me draw a sample from the posterior predictive distribution of state trajectories, and hence from the posterior predictive of the observed trajectories. (I might also be missing something basic; I only rediscovered HMMs a few days ago and I am trying to stitch together the info online with what I remember from lectures years ago.) Is there a reasonably easy way to do this in generated quantities? The only thing I can come up with is to implement particle filtering in generated quantities, and that seems quite complex.

The state space of my model is reasonably small-ish (3–7 states, ~20 time steps).

Or are people just happy with the most likely paths from Viterbi or forward-backward smoothed probabilities for their posterior predictive needs?

I know that @vianeylb and @betanalpha did some cool stuff on HMMs recently, though I wasn’t able to find it on arXiv or something, so maybe you’ve solved/sidestepped this kind of problem?

Thanks for any pointers!


The forward-filtering backward-sampling (FFBS) algorithm (which also uses the forward and backward variables) samples from the joint posterior distribution of the underlying state process that generated the observed data. I assume this is what you’re looking for?
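
(In the usual notation, the idea is the factorization

p(z_{1:T} | y_{1:T}) = p(z_T | y_{1:T}) * prod_{t=1}^{T-1} p(z_t | z_{t+1}, y_{1:t}),

so you draw z_T from the filtered distribution produced by the forward pass and then draw each earlier state conditional on the state you just sampled.)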


I should also mention that @charlesm93 is putting all of these functionalities into Stan right now!


As @vianeylb notes, @charlesm93 will be adding an HMM PRNG once the efficient HMM code is merged into Stan. The FFBS algorithm is pretty straightforward to implement once you’ve implemented the forward and backward algorithms effectively. For example, here’s some C++ code implementing it with a running normalization to ensure numerical stability:

#include <vector>
#include <Eigen/Dense>
#include <boost/random/mersenne_twister.hpp>
#include <boost/random/discrete_distribution.hpp>

std::vector<int> hidden_state_rng_norm(int n_states, int n_transitions,
                                       const Eigen::MatrixXd& log_omegas,
                                       const Eigen::MatrixXd& Gamma,
                                       const Eigen::VectorXd& rho,
                                       boost::random::mt19937& prng) {
  Eigen::MatrixXd omegas = log_omegas.array().exp();
  
  // Forward pass with running normalization
  Eigen::MatrixXd alphas(n_states, n_transitions + 1);
  
  alphas.col(0) = omegas.col(0).cwiseProduct(rho);
  alphas.col(0) /= alphas.col(0).maxCoeff();
  
  for (int n = 0; n < n_transitions; ++n) {
    alphas.col(n + 1) = omegas.col(n + 1).cwiseProduct(Gamma * alphas.col(n));
    alphas.col(n + 1) /= alphas.col(n + 1).maxCoeff();
  }

  // Backwards pass with running normalization
  Eigen::VectorXd beta = Eigen::VectorXd::Ones(n_states);
  
  // Sample last hidden state
  std::vector<int> hidden_states(n_transitions + 1);
    
  Eigen::VectorXd probs_vec =  alphas.col(n_transitions) 
                             / alphas.col(n_transitions).sum(); 
  std::vector<double> probs(probs_vec.data(), probs_vec.data() + n_states);
  boost::random::discrete_distribution<> cat_hidden(probs);
  hidden_states[n_transitions] = cat_hidden(prng);
  
  for (int n = n_transitions - 1; n >= 0; --n) {
    // Sample nth hidden state conditional on (n+1)st hidden state
    int last_hs = hidden_states[n + 1];
    
    Eigen::VectorXd probs_vec(n_states);
    for (int k = 0; k < n_states; ++k) {
      probs_vec[k] =   alphas(k, n) * omegas(last_hs, n + 1) 
                     * Gamma(last_hs, k) * beta(last_hs);
    }
    probs_vec /= probs_vec.sum();
    std::vector<double> probs(probs_vec.data(), probs_vec.data() + n_states);
    boost::random::discrete_distribution<> cat_hidden(probs);
    hidden_states[n] = cat_hidden(prng);
    
    // Update backwards state    
    beta = Gamma.transpose() * (omegas.col(n + 1).cwiseProduct(beta));
    beta /= beta.maxCoeff();
  }
  
  return hidden_states;
}

Thank you both, I was indeed missing something basic :-)

So thanks again. I got around to implementing this, and I have a few follow-up questions if you have some more time to spare.

First, to be clear that I got things right:

  • omegas holds the likelihoods of the observed data (one row per state, one column per time point; see the sketch after this list for how I would build it)
  • Gamma is the transition matrix
  • rho is the distribution of states at the start
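
For concreteness, this is roughly how I would fill in log_omegas; the normal emission model and the names (y, mu, sigma) are just placeholders for whatever the actual emission distribution is:

matrix build_log_omegas(vector y, vector mu, real sigma, int n_states) {
  int T = num_elements(y);
  matrix[n_states, T] log_omegas;
  // log_omegas[k, t] = log emission density of observation y[t] under state k
  for (t in 1:T)
    for (k in 1:n_states)
      log_omegas[k, t] = normal_lpdf(y[t] | mu[k], sigma);
  return log_omegas;
}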

The follow-up questions:

  1. What is the advantage of the running normalization over working on the log scale and using softmax to get the probabilities for the RNG?
  2. If I didn’t care about the running normalization, couldn’t I simply compute probs_vec[k] = alphas(k, n) * Gamma(last_hs, k)?

Also, I found the treatment of the math at https://stats.stackexchange.com/a/376975/73129 helpful, so I’m leaving it here for future travelers.

Thanks again for your time!

Yes. And here Gamma is the transpose of what it usually is in the literature.

It is much faster. Exponentials are expensive and division is cheap, and empirically we saw no loss of numerical precision.

Without the running normalization the alpha and beta variables quickly underflow to zero.
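
(As a rough back-of-the-envelope illustration, with invented numbers: if each emission likelihood were on the order of 1e-4, the unnormalized alphas would already be down to roughly 1e-80 after 20 steps; with longer series or smaller likelihoods they fall below the smallest representable double, around 1e-308, and every entry of probs_vec becomes 0/0.)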


For those of us who are (a) impatient for this functionality and (b) C++ semi-literate, would the following work as a Stan function to generate the posterior distribution of hidden states?

int[] hidden_states_normalized_rng(int n_states, int n_transitions, 
                                  matrix log_omegas, 
                                  matrix Gamma,
                                  vector rho){

  matrix[n_states, n_transitions + 1]     omegas = exp(log_omegas);
  matrix[n_states, n_transitions + 1]     alphas;                          
  int                                     hidden_states[n_transitions + 1];
  vector[n_states]                        beta = rep_vector(1, n_states);
  vector[n_states]                        probs_vec;

  // Forward filtering
  alphas[ , 1]  = omegas[ , 1] .* rho;
  alphas[ , 1] /= max(alphas[, 1]);

  for (n in 1:n_transitions){
    alphas[ , n + 1]   = omegas[ , n + 1] .* (Gamma * alphas[ , n]);
    alphas[ , n + 1]  /= max(alphas[ , n + 1]);
  }

  //Backwards sampling
  // sample last hidden states
  probs_vec = alphas[ , n_transitions + 1] / sum(alphas[ , n_transitions + 1]);

  hidden_states[n_transitions + 1] = categorical_rng(probs_vec);

  for(n_rev in 1:n_transitions){
    int n = n_transitions + 1 - n_rev;
    int last_hs = hidden_states[n + 1];
  
    for (k in 1:n_states) {
      probs_vec[k] =   alphas[k, n] * omegas[last_hs, n + 1] * 
      Gamma[last_hs, k] * beta[last_hs];
    }
    probs_vec /= sum(probs_vec);

    hidden_states[n] = categorical_rng(probs_vec);
  
    // Update backwards state
    beta  = Gamma' * (omegas[, n + 1] .* beta);
    beta /= max(beta);
  }

  return hidden_states;
}
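
For completeness, here is a sketch of how I would then use it in generated quantities to get posterior predictive draws of the observations. The normal emission model and the names mu, sigma, and y_rep are just illustrative stand-ins for my actual setup; n_states, n_transitions, log_omegas, Gamma, and rho are assumed to be available from earlier blocks:

generated quantities {
  int z_rep[n_transitions + 1];
  vector[n_transitions + 1] y_rep;

  // draw one hidden-state trajectory from its joint posterior
  z_rep = hidden_states_normalized_rng(n_states, n_transitions, log_omegas, Gamma, rho);

  // then draw observations from the emission model given the sampled states
  for (t in 1:(n_transitions + 1))
    y_rep[t] = normal_rng(mu[z_rep[t]], sigma);
}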

If you are impatient, then just use the current develop snapshot of CmdStan; that should have everything, if I am not mistaken. Or just wait until we tag the 2.24 release candidate in the next few days, which will have all of that.

And here Gamma is the transpose of what it usually is in the literature.

Does that mean that the rows of Gamma each sum to 1? Or the columns?

Regarding the running normalization, why /= max() instead of /= sum()?

“Impatient” was meant more as a knock on myself than a complaint about the developers.

As for your suggestion, I may just do that. But I’m also trying to understand the math underlying the code, hence the exercise of trying to duplicate it. Thanks!

The code merged into Stan uses the common convention where the rows sum to one (i.e. the transpose of the matrix that actually updates the hidden state probabilities at each iteration).

The normalization is only for numerical stability and is just trying to keep the magnitude of each forward state probability around unity.
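
(To make that concrete, writing c_n and d_n for the running maxima used to rescale the alphas and beta:

probs_vec[k] ∝ (alphas[k, n] / c_n) * omegas[last_hs, n + 1] * Gamma[last_hs, k] * (beta[last_hs] / d_n)
            ∝  alphas[k, n] * omegas[last_hs, n + 1] * Gamma[last_hs, k] * beta[last_hs],

so the rescaling cancels when probs_vec is renormalized before each categorical draw; it only keeps the intermediate products at a magnitude that doubles can represent.)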