Enforcing order for mixture of transition matrices

I’m trying to fit a mixture of transition matrices, but I’m stuck on how to “order” the K matrices. I’ve basically adapted the normal mixture from the documentation here, but the ordering part is tripping me up. I’m sure this is not the optimal model specification; I could converge on a K=2 model with fake data, but I’m getting issues beyond that. For example, with the data below the fit puts all the data in one cluster.

// Pathways model -- Finite mixture of transition matrices

data {
  int<lower=1> K;          // number of clusters
  int<lower=1> N;          // number of datapoints (traversals)

  int<lower=1> S;          // number of observable states

  int<lower=0> X[N, S, S]; // 3-d array of transition counts (traversal, to-state, from-state)
}


parameters {
  simplex[K] theta;              // mixture weights
  simplex[S] transitions[K, S];  // each component's transition matrix: one simplex per from-state
}

model {

  vector[K] log_theta = log(theta);  // cache the log mixture weights

  for (n in 1:N) {
    vector[K] lps = log_theta;
    for (k in 1:K) {
      for (s in 1:S) {
        // counts of transitions out of from-state s under component k
        lps[k] += multinomial_lpmf(X[n, :, s] | transitions[k, s]);
      }
    }
    target += log_sum_exp(lps);  // marginalize over the mixture component
  }

}

Here’s the fake data (R):

## Input data
library(rstan)

n_clusters <- 3
n_states <- 6
mixture_dist <- c(0.6, 0.2, 0.2)

N <- 10000

# Cluster 1 -- uniform (same vector for all states for simplicity)
p1 <- rep(1 / n_states, n_states)

# Cluster 2 -- increasing
inc <- (1:n_states) / n_states
p2 <- inc / sum(inc)

# Cluster 3 -- decreasing
p3 <- rev(p2)

X <- array(NA_integer_, dim = c(N, n_states, n_states))

# cluster boundaries implied by mixture_dist (6000 / 2000 / 2000)
cuts <- cumsum(mixture_dist) * N

for (i in 1:cuts[1]) {
  X[i, , ] <- rmultinom(n_states, 100, prob = p1)  # one column per from-state
}

for (i in (cuts[1] + 1):cuts[2]) {
  X[i, , ] <- rmultinom(n_states, 100, prob = p2)
}

for (i in (cuts[2] + 1):cuts[3]) {
  X[i, , ] <- rmultinom(n_states, 100, prob = p3)
}


##############
input_data <- list(
  K=n_clusters,
  N=N,
  S=n_states,
  X=X
)


stanmod1 <- stan_model('pathways.stan', model_name = 'sample')

fit1 <- optimizing(stanmod1, data = input_data)

Hi, this is a good question!

Generally, mixtures of complex objects can be tricky, precisely because enforcing an ordering is a problem (e.g. I am not aware of a reliable method to fit a mixture of bivariate Gaussians). If your data are well behaved, you might be able to choose one element of the simplex and force an ordering across that element. This will work if that element is actually the main discriminant between the mixture components; if multiple components have similar values in the chosen element, it will not. There may be other quantities you can derive and order by (e.g. entropy), but the problem is the same: if multiple components are similar in that quantity, you are likely to have fitting issues.

If you go this way, you will however probably need to do the simplex transform yourself, as there is no built-in way to enforce this ordering (i.e. you would order some of the unconstrained values, then apply the transform and add the log of the Jacobian correction). The transform is described at 10.7 Unit simplex | Stan Reference Manual; a rough sketch follows below.
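
To make this concrete, here is a minimal, untested sketch of one way to anchor the ordering without doing the full stick-breaking transform by hand. It assumes, purely for illustration, that the [1, 1] element of each component’s matrix is the discriminating quantity; anchor_logit, first_row_rest, and other_rows are made-up names. The anchored element is split off, the rest of that row is a scaled smaller simplex, and an ordered vector on the logit scale enforces the ordering:

// Illustrative sketch only -- the anchor choice (element [1, 1]) and all
// parameter names are assumptions, not a tested recipe.
parameters {
  simplex[K] theta;                  // mixture weights
  ordered[K] anchor_logit;           // anchored element, ordered on the logit scale
  simplex[S - 1] first_row_rest[K];  // remaining mass of the anchored row
  simplex[S] other_rows[K, S - 1];   // rows 2..S are unrestricted simplices
}

transformed parameters {
  simplex[S] transitions[K, S];
  for (k in 1:K) {
    real a = inv_logit(anchor_logit[k]);  // in (0, 1), increasing in k
    transitions[k, 1] = append_row(a, (1 - a) * first_row_rest[k]);
    for (s in 2:S) {
      transitions[k, s] = other_rows[k, s - 1];
    }
  }
}

model {
  // Change-of-variables correction so the implied prior is flat on the
  // anchored probability itself rather than on its logit.
  for (k in 1:K) {
    target += log_inv_logit(anchor_logit[k]) + log1m_inv_logit(anchor_logit[k]);
  }
  // ... the mixture likelihood from the original model block goes here ...
}

If the anchored element does not actually separate the components, you would instead order by a derived quantity (e.g. entropy) and do the full transform with its Jacobian as the manual chapter describes.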

In some cases, the multiple modes implied by different orderings are well separated, so you don’t get divergences during sampling even without ordering (or with just a partial ordering), only low ESS and high Rhat. In that case you may be able to just post-process the samples to enforce a shared ordering (after this post-processing, check ESS and Rhat again; if they are good, it worked). There is even a package dedicated to this: https://cran.r-project.org/package=label.switching (I’ve never used it myself). A rough sketch of the manual version follows.
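
For the manual route, here is a minimal sketch, assuming the draws come from rstan::extract() on a sampling fit (so draws$transitions has dimensions iterations x K x S x S) and, again purely as an assumption, that element [1, 1] of each component’s matrix is a usable discriminant; relabel is a made-up helper name:

# Post-hoc relabeling sketch (assumed workflow, not tested).
# Reorders the components of every posterior draw by a scalar summary,
# here P(to-state 1 | from-state 1) of each component.
relabel <- function(draws) {
  n_iter <- dim(draws$transitions)[1]
  for (i in 1:n_iter) {
    o <- order(draws$transitions[i, , 1, 1])  # per-draw component order
    draws$theta[i, ] <- draws$theta[i, o]
    draws$transitions[i, , , ] <- draws$transitions[i, o, , ]
  }
  draws
}

Note that recomputing per-chain diagnostics like Rhat afterwards requires keeping the chain structure (e.g. extract(fit, permuted = FALSE)); if ESS and Rhat look good on the relabeled draws, the modes were separable and the relabeling worked.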

Best of luck with your model!


Thanks for the great answer. I’ll look into these suggestions and see what is most viable. Picking an element that differentiates the clusters to build the simplex around is indeed tricky. My intuition is that the diagonal of the transition matrix is the place that would provide the most discrimination, but where to go from there is not obvious.

I can always fall back on EM!