Hidden Markov Model: Trouble With Forward Algorithm

Hi Stan community,

I’m working on developing a hidden Markov model with time-varying transition probabilities, and I am currently implementing the forward algorithm (and eventually the backward algorithm). The trouble is that I fail to get valid simplices for the filtered state probabilities $\boldsymbol{\alpha}_t = \big(p(S_t = 1 \mid \mathbf{y}_{1:t}),\ p(S_t = 2 \mid \mathbf{y}_{1:t})\big)$.
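For reference, the normalized forward recursion can be written as follows. This assumes, as in the comments of the Stan code in this post, that $\Gamma_t[j,n] = P(S_t = n \mid S_{t-1} = j)$ (rows index the previous state), with $f_n(\mathbf{y}_t)$ the state-dependent density and $\boldsymbol{\delta}$ the initial distribution of $S_1$:

$$
\tilde{\alpha}_1(n) = \delta_n\, f_n(\mathbf{y}_1), \qquad
\tilde{\alpha}_t(n) = f_n(\mathbf{y}_t) \sum_{j=1}^{K} \alpha_{t-1}(j)\, \Gamma_t[j,n] \quad (t \ge 2),
$$
$$
c_t = \sum_{n=1}^{K} \tilde{\alpha}_t(n), \qquad
\alpha_t(n) = \frac{\tilde{\alpha}_t(n)}{c_t}, \qquad
\log p(\mathbf{y}_{1:T}) = \sum_{t=1}^{T} \log c_t .
$$

Each $\boldsymbol{\alpha}_t$ is a simplex by construction, and the sum over $j$ runs down column $n$ of $\Gamma_t$. On the log scale, $\log c_t$ is a single log-sum-exp over all $K$ unnormalized entries, evaluated before any entry is rescaled.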

In my implementation of the forward algorithm, I first compute unnormalized log-probabilities and then normalize them on the log scale. For this, I took inspiration from another discussion: Getting access to “internal” log probs in HMM code (& how they behave across time) - Modeling - The Stan Forums

The forward log-probabilities no longer decrease rapidly over time, so normalizing in the log domain seems to be working as intended.
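A minimal pure-Python sketch of log-space forward filtering with per-step normalization can be handy for checking a Stan implementation against known answers. The function names and argument layout are hypothetical; transitions are assumed to be stored as `log_gammas[t][j][n] = log P(S_t = n | S_{t-1} = j)`, matching the indexing in the post's Stan comments, and the initial distribution is taken to describe $S_1$ (as in the model block). Summing the per-step log normalizers gives the log-likelihood.

```python
import math
from typing import List, Sequence, Tuple


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def forward_filter(
    log_init: Sequence[float],
    log_gammas: Sequence[Sequence[Sequence[float]]],
    log_dens: Sequence[Sequence[float]],
) -> Tuple[List[List[float]], float]:
    """Log-space forward algorithm with per-step normalization (sketch).

    log_init[n]         : log P(S_1 = n)
    log_gammas[t][j][n] : log P(S_t = n | S_{t-1} = j); row j is the previous state
                          (log_gammas[0] is unused, since log_init describes S_1)
    log_dens[t][n]      : log f(y_t | S_t = n)

    Returns the normalized log filtered probabilities and the log-likelihood.
    """
    K = len(log_init)
    lalpha: List[List[float]] = []
    loglik = 0.0
    for t, dens in enumerate(log_dens):
        if t == 0:
            unnorm = [log_init[n] + dens[n] for n in range(K)]
        else:
            prev = lalpha[-1]
            # Sum over the previous state j, i.e. down column n of the t.p.m.
            unnorm = [
                dens[n] + log_sum_exp([prev[j] + log_gammas[t][j][n] for j in range(K)])
                for n in range(K)
            ]
        # The normalizer is computed ONCE, from the unmodified vector.
        log_c = log_sum_exp(unnorm)
        lalpha.append([u - log_c for u in unnorm])
        loglik += log_c  # log p(y_t | y_{1:t-1})
    return lalpha, loglik
```

For a short series, the accumulated `loglik` can be compared with a brute-force log-sum-exp over all $K^T$ state paths, and each row of `exp(lalpha)` should sum to 1 up to rounding.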

However, when I look at the forward probabilities on the probability scale, I find that they don’t sum to one, and that their sum sometimes falls well short of one. I see this, for example, when summarizing $\alpha_t(1) + \alpha_t(2)$ for $t = 1, \dots, T$ over one iteration of my HMC sampler.

Below is my Stan code. I’m assuming there’s a mistake in my implementation, but for the life of me I can’t find where it is. I should also mention that in my testing, parameter estimation seems to be fine: I’m able to recover the parameters used to produce the synthetic data, albeit with label switching being an issue.

functions{

  // Multivariate Power Exponential (MPE) distribution, log density.
  // beta is the kurtosis (shape) parameter; beta = 1 recovers the multivariate normal.
  real MVPE_lpdf(vector x, vector mu, matrix sigma, real beta){

    // Dimension of random vector/size of covariance matrix;
    int p = num_elements(x);

    matrix[p,p] sigma_inverse = inverse(sigma);

    vector[p] B = x - mu;

    real quad_term = quad_form_sym(sigma_inverse, B)^beta;

    return log(p) + lgamma(p * 0.5) - 0.5 * log_determinant(sigma)
           - 0.5 * p * log(pi()) - lgamma(1 + p / (2 * beta))
           - log(2) - (log(2) * p) / (2 * beta) - 0.5 * quad_term;
  }

}
data{
  
  // Number of observations
  int<lower=1> T;
  
  // Number of distinct states/regimes
  int<lower=1> K;
  
  // Dimension of dependent (response) variable
  int<lower=1> p;
  
  // Dimension of covariates/regressors
  int<lower=1> k;
  
  // Number of covariates pertaining to the state process
  int<lower=1> nCovs;
  
  // Covariates for the state process-- +1 for the intercept
  array[T] vector[nCovs+1] covs;
  
  // Covariates/independent variables
  array[T] vector[k] x;

  // Response/dependent variables
  array[T] vector[p] y;
  
}
transformed data{

   // Below we construct a uniform simplex over the initial states;
   row_vector[K] initdist;
   
   initdist = rep_row_vector(1.0/K, K);
   
}
parameters{

   // Switching in mean vector mu;
   array[K] vector[p] mu;
   
   // Kurtosis parameter;
   real<lower=0> kappa;
   
   // Noise (Co)variance matrix sigma (just a vector for now)
   vector<lower=0>[p] V;
   
   // Loadings matrix;
   matrix[p, k] F;
   
   // Coefficients that dictate whether you stay or leave S1.
   // This contains the intercept coefficient and the slope coefficient(s)
   vector[nCovs+1] b1;
   // Coefficients that dictate whether you stay or leave S2;
   // This contains the intercept coefficient and the slope coefficient(s)
   vector[nCovs+1] b2;
   
}
transformed parameters{

// Nothing, for now....

  
}
model{

// The log of the t.p.m
array[T] matrix[K, K] log_gamma;

vector[K] lp;
vector[K] lp_p1;


   // Priors;
   for(i in 1:K){
    
   // mu is a 1D array with K entries, each of which is an intercept vector that we allow to switch.
   mu[i] ~ normal(0, 10);
   
   }

   kappa ~ normal(0, 10) T[0,];
   
   V ~ normal(0, 10) T[0,];   
   
   for(i in 1:k){
   
   F[,i] ~ double_exponential(0, 10);
   
   }
   
   b1 ~ normal(0, 5);
   b2 ~ normal(0, 5);
   
// Construct the log of the t.p.m for each time point;
// row j holds P(S_t = . | S_t-1 = j), so each row sums to one.
// Note: the hard-coded indexing below assumes K = 2.
   for(t in 1:T){

     // P(S_t = 1 | S_t-1 = 1)
     log_gamma[t][1,1] = inv_logit( dot_product(covs[t], b1) );

     // P(S_t = 2 | S_t-1 = 1) = 1 - P(S_t = 1 | S_t-1 = 1)
     log_gamma[t][1,2] = 1 - log_gamma[t][1,1];

     // P(S_t = 2 | S_t-1 = 2)
     log_gamma[t][2,2] = inv_logit( dot_product(covs[t], b2) );

     // P(S_t = 1 | S_t-1 = 2) = 1 - P(S_t = 2 | S_t-1 = 2)
     log_gamma[t][2,1] = 1 - log_gamma[t][2,2];

     // Convert to the log scale;
     log_gamma[t][1] = log(log_gamma[t][1]);
     log_gamma[t][2] = log(log_gamma[t][2]);
   }

  // likelihood computation (forward algorithm on the log scale)
  for(n in 1:K){  // first observation
    lp[n] = log(initdist[n]) + MVPE_lpdf(y[1] | mu[n] + F * x[1], diag_matrix(V), kappa);
  }

  for(t in 2:T){   // looping over observations
    for(n in 1:K){ // looping over states
      // Sum over the previous state j: log_gamma[t][j, n] = log P(S_t = n | S_t-1 = j),
      // so this needs column n of log_gamma[t] (row n would sum over the wrong index).
      lp_p1[n] = log_sum_exp(col(log_gamma[t], n) + lp)
                 + MVPE_lpdf(y[t] | mu[n] + F * x[t], diag_matrix(V), kappa);
    }
    lp = lp_p1;
  }

  // Marginalize over the final state (also correct when T = 1).
  target += log_sum_exp(lp);
}
generated quantities{

 // State-dependent observation densities;
 array[T] vector[K] SD_OD;
 
 // Transpose of the log of the time-varying t.p.m;
 array[T] matrix[K,K] log_gamma_tr;

 // log-forward probabilities, normalized;
 array[T] vector[K] lalpha; 
 
 // forward probabilities, normalized, and on the probability scale;
 array[T] vector[K] alpha; 

// Make the transposed log-gamma matrix, as well as the entries of SD_OD;
   for(t in 1:T){

     // P(S_t = 1 | S_t-1 = 1)
     log_gamma_tr[t][1,1] = inv_logit( dot_product(covs[t], b1) );

     // P(S_t = 2 | S_t-1 = 1) = 1 - P(S_t = 1 | S_t-1 = 1)
     log_gamma_tr[t][1,2] = 1 - log_gamma_tr[t][1,1];

     // P(S_t = 2 | S_t-1 = 2)
     log_gamma_tr[t][2,2] = inv_logit( dot_product(covs[t], b2) );

     // P(S_t = 1 | S_t-1 = 2) = 1 - P(S_t = 2 | S_t-1 = 2)
     log_gamma_tr[t][2,1] = 1 - log_gamma_tr[t][2,2];

     // Convert to the log scale;
     log_gamma_tr[t][1] = log(log_gamma_tr[t][1]);
     log_gamma_tr[t][2] = log(log_gamma_tr[t][2]);

     // Transpose, so that row i holds log P(S_t = i | S_t-1 = j) for j = 1..K;
     log_gamma_tr[t] = log_gamma_tr[t]';

     for(i in 1:K){
       SD_OD[t][i] = MVPE_lpdf(y[t] | mu[i] + F * x[t], diag_matrix(V), kappa);
     }
   }


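  // t=1: the model block treats initdist as the distribution of S_1, so no
  // transition is applied here (keeps the filter consistent with the likelihood).
  for(i in 1:K){
    lalpha[1][i] = SD_OD[1][i] + log(initdist[i]);
  }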

  // Constructing the unnormalized, log-forward probabilities, and then scaling/normalizing them on the log scale
  // seems to be the most numerically stable way. 
   for(i in 1:K){

   lalpha[1][i] = lalpha[1][i] - log_sum_exp(lalpha[1]);

  }
  
  // Then convert to the probability scale
  alpha[1] = exp(lalpha[1]);
  
  print("sum(alpha[1])", sum(alpha[1]));
  
  for(t in 2:T){
  
  for(i in 1:K){

  lalpha[t][i] = SD_OD[t][i] +  log_sum_exp(  log_gamma_tr[t][i] + to_row_vector(lalpha[t-1]) );
 
    }
	
	// Then normalize;
  for(i in 1:K){
  
   lalpha[t][i] = lalpha[t][i] - log_sum_exp( lalpha[t] );

    }
	
	 alpha[t] = exp(lalpha[t]);
  
      print("sum(alpha[t])", sum(alpha[t]));
  
  }
 
}
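As a separate sanity check on the density (independent of the forward algorithm), here is a hedged Python transcription of MVPE_lpdf for the diagonal case used in the post, $\Sigma = \mathrm{diag}(V)$. With $\beta = 1$ the normalizing constant reduces to that of the multivariate normal, so the two log densities should agree; in one dimension the density should also integrate to 1 for other values of $\beta$. The helper names (`mvpe_logpdf_diag`, `mvn_logpdf_diag`, `integrate_1d`) are hypothetical.

```python
import math
from typing import Sequence


def mvpe_logpdf_diag(x: Sequence[float], mu: Sequence[float],
                     var: Sequence[float], beta: float) -> float:
    """MPE log density with diagonal scale matrix Sigma = diag(var),
    using the same formula as the Stan MVPE_lpdf in the post."""
    p = len(x)
    quad = sum((xi - mi) ** 2 / vi for xi, mi, vi in zip(x, mu, var))
    log_det = sum(math.log(vi) for vi in var)  # log|Sigma| for a diagonal Sigma
    return (math.log(p) + math.lgamma(0.5 * p) - 0.5 * log_det
            - 0.5 * p * math.log(math.pi) - math.lgamma(1 + p / (2 * beta))
            - math.log(2) - math.log(2) * p / (2 * beta)
            - 0.5 * quad ** beta)


def mvn_logpdf_diag(x: Sequence[float], mu: Sequence[float],
                    var: Sequence[float]) -> float:
    """Multivariate normal log density with diagonal covariance, for comparison."""
    p = len(x)
    quad = sum((xi - mi) ** 2 / vi for xi, mi, vi in zip(x, mu, var))
    log_det = sum(math.log(vi) for vi in var)
    return -0.5 * p * math.log(2 * math.pi) - 0.5 * log_det - 0.5 * quad


def integrate_1d(beta: float, var: float, lo: float = -10.0,
                 hi: float = 10.0, n: int = 20000) -> float:
    """Trapezoid-rule integral of the p = 1 MPE density; should be close to 1."""
    h = (hi - lo) / n
    total = 0.0
    for i in range(n + 1):
        weight = 0.5 if i in (0, n) else 1.0
        total += weight * math.exp(mvpe_logpdf_diag([lo + i * h], [0.0], [var], beta))
    return total * h
```

If the Stan function and a check like this disagree on simulated inputs, that points at the density rather than the HMM recursion.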

A related question: would it be valid to obtain the last entry of $\boldsymbol{\alpha}_t$ by defining it in terms of its other entries? To illustrate, in my model block you’ll see that I define the entries of my time-varying transition probability matrix in this way (each off-diagonal entry is one minus the diagonal entry in the same row).
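On the numerical side of that question: for $K = 2$ the identity $\alpha_t(2) = 1 - \alpha_t(1)$ holds exactly, but computing the last entry as a complement can discard information that normalizing every entry keeps. The hypothetical Python sketch below contrasts the two approaches on the log scale; in Stan the complement would be written with the built-in `log1m_exp`.

```python
import math


def log_sum_exp2(a: float, b: float) -> float:
    """Stable log(exp(a) + exp(b))."""
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def log1m_exp(a: float) -> float:
    """Stable log(1 - exp(a)) for a <= 0 (returns -inf at a == 0)."""
    if a >= 0.0:
        return -math.inf
    if a > -math.log(2.0):
        return math.log(-math.expm1(a))
    return math.log1p(-math.exp(a))


def normalize_both(u1: float, u2: float):
    """Normalize both unnormalized log entries by the same log-sum-exp."""
    log_c = log_sum_exp2(u1, u2)
    return u1 - log_c, u2 - log_c


def complement_second(u1: float, u2: float):
    """Normalize only the first entry, then set log alpha_2 = log(1 - alpha_1)."""
    la1 = u1 - log_sum_exp2(u1, u2)
    return la1, log1m_exp(la1)
```

For moderate inputs such as `(-1.0, -2.0)` the two functions agree to rounding error. For `(0.0, -40.0)`, however, `normalize_both` returns `(0.0, -40.0)` while `complement_second` returns `(0.0, -inf)`: once $\alpha_t(1)$ rounds to exactly 1, the complement cannot recover the small probability of the other state.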

In case you’re curious: I’m using CmdStan 2.36.0, and my OS is Red Hat 8 (x86_64). I can also share an R script that creates the synthetic data and writes it to a .json file.

Thanks for reading!

Update:

My mistake was in the normalization step. I failed to notice that the normalizing constant changed on every iteration of this loop:

  for(i in 1:K){
  
   lalpha[t][i] = lalpha[t][i] - log_sum_exp( lalpha[t] );

    }

As a result, each element of ‘lalpha’ was in fact being scaled by a different quantity. I’ve since changed the normalization step to compute the constant once, before the loop:

real normt;
// ...

normt = log_sum_exp( lalpha[t] );
for(i in 1:K){
  lalpha[t][i] = lalpha[t][i] - normt;
}
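To see why the original loop produced sums below one, the hypothetical Python sketch below replays both versions on a two-state example with unnormalized probabilities 0.3 and 0.2. In Stan, the fixed step could likely also be written as one vectorized statement, `lalpha[t] = lalpha[t] - log_sum_exp(lalpha[t]);` (or with the built-in `log_softmax`), since the right-hand side is evaluated in full before the assignment.

```python
import math
from typing import List, Sequence


def log_sum_exp(xs: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def normalize_in_loop(lalpha: Sequence[float]) -> List[float]:
    """Mimics the original loop: log_sum_exp is re-evaluated on a vector
    whose earlier entries have already been rescaled."""
    v = list(lalpha)
    for i in range(len(v)):
        v[i] = v[i] - log_sum_exp(v)
    return v


def normalize_once(lalpha: Sequence[float]) -> List[float]:
    """The fix: compute the normalizing constant once, then subtract it."""
    normt = log_sum_exp(lalpha)
    return [x - normt for x in lalpha]
```

With inputs `[log(0.3), log(0.2)]`, the loop version yields probabilities 0.6 and 0.25 (sum 0.85): the first entry is normalized correctly, but the second is divided by 0.8 instead of 0.5. The fixed version yields 0.6 and 0.4.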

I now find that the corresponding probabilities, exp(lalpha[t]), sum to 1 most of the time, with much better precision and accuracy.

So, a silly mistake on my part. Thanks again for reading, and sorry for taking up precious space on the discussion forum.


Thanks for following up with a solution. This is a really tricky algorithm to implement given Stan’s lack of debugging tools.
