Label Switching for 3 Components and 6 Classes - unimodal curves

Hello everyone,

I’m having issues with label switching in a model with 3 components and 6 classes that I want to infer. My dataset has N data points (rows) with 4 columns each, and every data point belongs to one of 6 classes: jjjj, bbbb, cccc, bbjj, ccjj, and bbcc. Here, j, b, and c are the individual components.

For inference, I’d like to enforce certain relationships across specific bins to mitigate label switching. For example, in bin 5, I’d like to ensure that p_j[5] > p_c[5] and p_j[5] > p_b[5], meaning that the j component in bin 5 should be greater than the other two components in that same bin. I want to impose similar relationships, alternating components across strategic bins where these relationships should be most evident.

This is my Stan code to infer the 6 classes from the 3 individual components, each modelled with unimodal curves. It is a bit complex, but the important thing I want to discuss is how to differentiate the individual components p_j, p_c, p_b. One approach I tried was to use ordered[2] vectors in each bin to enforce these conditions.

my_mixture_4d = """
functions {
  real partial_sum(array[,] int score_slice,
                   int start, int end,
                   vector yj,
                   vector yb,
                   vector yc,
                   vector theta)
  {
    real permutation_factor = 1.0/6;
    real partial_target = 0;
    
    vector[6] lp;   // Six distinct classes, in this order: jjjj, bbjj (+perm), jjcc (+perm), bbcc (+perm), cccc, bbbb.
    vector[6] lp2;  // To handle permutations for bbjj
    vector[6] lp3;  // To handle permutations for bbcc
    vector[6] lp4;  // To handle permutations for jjcc.
    
    int slice_length = end - start + 1;

    //jjjj, bbjj, ccjj, ccbb, cccc, bbbb is the order of the classes
    for (k in 1:slice_length) {
      // Permutations for bbjj
      lp2[1] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp2[2] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp2[3] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
      lp2[4] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp2[5] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
      lp2[6] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]];

      // Assign probabilities for the six classes
      lp[1] = log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]]; // jjjj
      lp[2] = log_sum_exp(lp2);    // bbjj + permutations
      lp[6] = log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]]; // bbbb

      // Permutations for bbcc
      lp3[1] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp3[2] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp3[3] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp3[4] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp3[5] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp3[6] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
      lp[4] = log_sum_exp(lp3);    // bbcc + permutations

      // Permutations for jjcc
      lp4[1] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp4[2] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp4[3] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
      lp4[4] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
      lp4[5] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
      lp4[6] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
      lp[3] = log_sum_exp(lp4);    // jjcc + permutations

      lp[5] = log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]]; // cccc

      partial_target += log_mix(theta, lp);
    }

    return partial_target;
  }
}

data {
  int<lower=1> m;  // steps in the discretization
  int<lower=1> N;  // data points
  array[N,4] int<lower=1, upper=m> score;  // b-tagging score bin for each of the 4 jets of a data point

  vector[m] prior_w_j;
  vector[m] prior_w_c;
  vector[m] prior_w_b;
  vector[6] prior_theta;

}

parameters {
  simplex[6] theta;  // Mixture coefficients of 6 classes as a simplex
  
  simplex[m] w_j_mode;  // Dirichlet weights for j
  vector<lower=0>[m-1] a_j;  // Normal parameters for j
  
  simplex[m] w_b_mode;  // Dirichlet weights for b
  vector<lower=0>[m-1] a_b;  // Normal parameters for b
  
  simplex[m] w_c_mode;  // Dirichlet weights for component c
  vector<lower=0>[m-1] a_c;  // Normal parameters for component c
  

  
}

transformed parameters {

  ordered[2] y_label_switch_1;   // this is the part I would like to change or improve
  ordered[2] y_label_switch_2;
  ordered[2] y_label_switch_3;
  ordered[2] y_label_switch_4;
  ordered[2] y_label_switch_5;
  ordered[2] y_label_switch_6;
  ordered[2] y_label_switch_7;


  vector[m] p_j;  // Approximate distribution for j
  vector[m-1] sign_j;  // Signs for combination
  vector[m] logp_j;  // Log-probabilities
  
  sign_j = rep_vector(-1.0, m-1);
  p_j = rep_vector(0.0, m);

  for (k in 1:m) {
    if (k > 1) {
      for (j in 1:(k-1)) {
        sign_j[j] = 1.0;
      }
    }
    logp_j[1] = 0.0;
    for (j in 2:m) {
      logp_j[j] = logp_j[j-1] + sign_j[j-1] * a_j[j-1];
    }
    p_j += softmax(logp_j) * w_j_mode[k];
  }

  vector[m] p_b;  // Approximate distribution for b
  vector[m-1] sign_b;  // Signs for combination
  vector[m] logp_b;  // Log-probabilities

  sign_b = rep_vector(-1.0, m-1);
  p_b = rep_vector(0.0, m);

  for (k in 1:m) {
    if (k > 1) {
      for (j in 1:(k-1)) {
        sign_b[j] = 1.0;
      }
    }
    logp_b[1] = 0.0;
    for (j in 2:m) {
      logp_b[j] = logp_b[j-1] + sign_b[j-1] * a_b[j-1];
    }
    p_b += softmax(logp_b) * w_b_mode[k];
  }
  
  vector[m] p_c;  // Approximate distribution for c
  vector[m-1] sign_c;  // Signs for combination
  vector[m] logp_c;  // Log-probabilities
  
  sign_c = rep_vector(-1.0, m-1);
  p_c = rep_vector(0.0, m);

  for (k in 1:m) {
    if (k > 1) {
      for (j in 1:(k-1)) {
        sign_c[j] = 1.0;
      }
    }
    logp_c[1] = 0.0;
    for (j in 2:m) {
      logp_c[j] = logp_c[j-1] + sign_c[j-1] * a_c[j-1];
    }
    p_c += softmax(logp_c) * w_c_mode[k];
  }

  y_label_switch_1[1] = p_c[5];
  y_label_switch_1[2] = p_j[5];
  
  y_label_switch_2[1] = p_c[4];
  y_label_switch_2[2] = p_j[4];
  
  y_label_switch_3[1] = p_b[5];
  y_label_switch_3[2] = p_j[5];

  y_label_switch_4[1] = p_b[8];
  y_label_switch_4[2] = p_c[8];

  y_label_switch_5[1] = p_j[8];
  y_label_switch_5[2] = p_c[8];
  
  y_label_switch_6[1] = p_c[17];
  y_label_switch_6[2] = p_b[17];

  y_label_switch_7[1] = p_j[17];
  y_label_switch_7[2] = p_b[17];

}

model {
  int grainsize = 1;  // parallel performance

  // Priors
  theta ~ dirichlet(prior_theta);

  a_j ~ normal(0, 0.5);
  a_b ~ normal(0, 0.5);
  a_c ~ normal(0, 0.5);

  w_j_mode ~ dirichlet(prior_w_j);
  w_b_mode ~ dirichlet(prior_w_b);
  w_c_mode ~ dirichlet(prior_w_c);

  target += reduce_sum(partial_sum, score, grainsize, p_j, p_b, p_c, theta);
}

"""

This method essentially involves setting up two ordered[2] vectors in each bin to enforce the relationships I want. It has been relatively effective but I’m pretty sure that this could be greatly improved.
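For readers following the transformed parameters block, here is a plain-Python sketch of how each of p_j, p_b, and p_c is built. For every candidate peak bin k, the log-probabilities step up by a[j] before the peak and down after it; the softmax of that curve is then weighted by w_mode[k]. The function name `unimodal_mixture` and the toy inputs are illustrative and not part of the model. Indices are 0-based here and 1-based in Stan.

```python
import math

def softmax(x):
    """Numerically stable softmax of a list of floats."""
    mx = max(x)
    e = [math.exp(v - mx) for v in x]
    total = sum(e)
    return [v / total for v in e]

def unimodal_mixture(a, w_mode):
    """Mirror of the p_j / p_b / p_c loops in the Stan model.

    a      : m-1 non-negative step sizes (a_j, a_b or a_c)
    w_mode : m simplex weights over the peak position (w_*_mode)
    """
    m = len(w_mode)
    p = [0.0] * m
    for k in range(m):  # k is the 0-based peak bin (Stan's k - 1)
        # Steps before the peak go up (+1); steps from the peak onward go down (-1).
        sign = [1.0 if j < k else -1.0 for j in range(m - 1)]
        logp = [0.0]
        for j in range(m - 1):
            logp.append(logp[-1] + sign[j] * a[j])
        curve = softmax(logp)
        p = [pi + w_mode[k] * ci for pi, ci in zip(p, curve)]
    return p

# Toy example: all weight on a peak at the third bin (index 2).
p = unimodal_mixture(a=[0.5, 0.3, 0.2], w_mode=[0.0, 0.0, 1.0, 0.0])
```

With the weight concentrated on one bin, the result rises up to that bin and falls afterwards. When w_mode is spread over several bins, p is the weighted average of curves peaked at different bins.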

I also considered using a different approach, such as combining components in an ordered[2] vector like this:

y_label_switch_1[1] = p_b[5] + p_c[5];  // p_b + p_c <  p_j in bin 5
y_label_switch_1[2] = p_j[5];

y_label_switch_2[1] = p_j[10] + p_b[10];  //etc
y_label_switch_2[2] = p_c[10];

// etc.

However, this approach also didn’t work well, as the conditions don’t always hold exactly.
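One way to see which conditions fail, and how often, is to check them on draws after the fact. The sketch below lists the seven pairwise orderings from y_label_switch_1 to y_label_switch_7 as (bin, smaller, larger) triples and reports the fraction of draws in which each holds. The layout of `draws`, the helper `toy_curve`, and the toy values are assumptions made for illustration only.

```python
# (bin, smaller, larger) triples copied from y_label_switch_1..7.
# Bins are 1-based, as in the Stan model.
CONDITIONS = [
    (5, "c", "j"),
    (4, "c", "j"),
    (5, "b", "j"),
    (8, "b", "c"),
    (8, "j", "c"),
    (17, "c", "b"),
    (17, "j", "b"),
]

def condition_rates(draws):
    """Fraction of draws in which each ordering holds.

    draws: list of dicts {"j": p_j, "b": p_b, "c": p_c}, each value a
    list of m bin probabilities (hypothetical layout for this sketch).
    """
    rates = []
    for bin_, lo, hi in CONDITIONS:
        i = bin_ - 1  # convert the 1-based bin to a 0-based index
        held = sum(1 for d in draws if d[lo][i] < d[hi][i])
        rates.append(held / len(draws))
    return rates

def toy_curve(peaks, m=17):
    """Put half the mass on the listed 1-based bins, spread the rest evenly."""
    rest = 0.5 / (m - len(peaks))
    return [0.5 / len(peaks) if (i + 1) in peaks else rest for i in range(m)]

j, c, b = toy_curve([4, 5]), toy_curve([8]), toy_curve([17])
draws = [
    {"j": j, "b": b, "c": c},  # labels as intended
    {"j": c, "b": b, "c": j},  # same curves with the j and c labels swapped
]
rates = condition_rates(draws)
```

In this toy case the swapped draw still satisfies three of the seven conditions, so those three alone would not rule out a j/c swap.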

Would anyone have suggestions on how to effectively enforce these relationships to help with label switching in my model?

Thanks so much for any insights!

First, why are you concerned about label switching? Most of the posterior inferences you will want to make with a mixture model will wind up marginalizing out the mixture components. I wrote a section in the User’s Guide discussing this. For example, if you look at posterior predictive inference for $p(\widetilde{y} \mid y)$, you’ll find that the labels don’t appear; they get marginalized out of the likelihood.
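The marginalization point can be checked numerically. The sketch below is a simplified Python version of the mixture likelihood for the six classes, using toy values and 0-based bins. It shows that swapping the j and b labels, together with the matching entries of theta, leaves the log-likelihood unchanged. The function names and numbers are illustrative assumptions, and the priors are left out.

```python
import math
from itertools import permutations

# Class names written with sorted letters so relabelled names can be matched.
CLASSES = ["jjjj", "bbjj", "ccjj", "bbcc", "cccc", "bbbb"]

def class_lik(obs, cls, y):
    """Likelihood of four scores under a class, averaged over the
    distinct orderings of the class's labels (1 or 6 of them)."""
    arrangements = set(permutations(cls))
    total = sum(
        math.prod(y[label][s] for label, s in zip(arr, obs))
        for arr in arrangements
    )
    return total / len(arrangements)

def log_lik(data, y, theta):
    """Log-likelihood with the class membership summed out."""
    return sum(
        math.log(sum(theta[c] * class_lik(obs, c, y) for c in CLASSES))
        for obs in data
    )

def swap_labels(cls, a, b):
    """Rename label a to b and b to a inside a class name."""
    table = {a: b, b: a}
    return "".join(sorted(table.get(ch, ch) for ch in cls))

# Toy component distributions over 3 bins and toy class weights.
y = {"j": [0.7, 0.2, 0.1], "b": [0.1, 0.2, 0.7], "c": [0.2, 0.6, 0.2]}
theta = {"jjjj": 0.3, "bbjj": 0.1, "ccjj": 0.1,
         "bbcc": 0.2, "cccc": 0.2, "bbbb": 0.1}
data = [(0, 0, 1, 2), (2, 2, 1, 0), (1, 1, 1, 1)]

# The same model with the names "j" and "b" exchanged everywhere.
y_swapped = {"j": y["b"], "b": y["j"], "c": y["c"]}
theta_swapped = {c: theta[swap_labels(c, "j", "b")] for c in CLASSES}

ll_original = log_lik(data, y, theta)
ll_swapped = log_lik(data, y_swapped, theta_swapped)
```

Both parameter settings give the same likelihood, so anything computed from the likelihood alone cannot tell the labellings apart.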

What label switching will mess up is convergence analysis, effective sample size, etc., at the parameter level. But these are not usually things you care about directly in a mixture model; the labels are just a nuisance.

The other thing to do (when @andrewgelman isn’t watching) is to run a single chain. Or to run a single chain and then use a draw to initialize other chains. This will work when the modes are well separated enough to not switch within a single chain.
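A minimal sketch of the second variant: take one draw from the first chain and turn it into a dictionary of initial values for the remaining chains. Only the names from the parameters block are needed, since the transformed parameters are recomputed. The layout of `draws` and the helper name `draw_to_inits` are assumptions. How initial values are passed depends on the interface, so check its documentation for the exact argument.

```python
# Names taken from the parameters block of the model above.
PARAMETER_NAMES = ["theta", "w_j_mode", "a_j",
                   "w_b_mode", "a_b", "w_c_mode", "a_c"]

def draw_to_inits(draws, index=-1):
    """Pick one posterior draw and return it as a {name: values} dict.

    draws: dict mapping each parameter name to a list of per-draw
    value lists (hypothetical layout; adapt to what your interface returns).
    """
    return {name: list(draws[name][index]) for name in PARAMETER_NAMES}

# Toy draws from a single chain with m = 2 bins (two draws per parameter).
toy_draws = {
    "theta": [[0.2, 0.1, 0.1, 0.2, 0.2, 0.2], [0.3, 0.1, 0.1, 0.2, 0.2, 0.1]],
    "w_j_mode": [[0.6, 0.4], [0.7, 0.3]],
    "a_j": [[0.4], [0.5]],
    "w_b_mode": [[0.5, 0.5], [0.4, 0.6]],
    "a_b": [[0.2], [0.3]],
    "w_c_mode": [[0.3, 0.7], [0.2, 0.8]],
    "a_c": [[0.1], [0.2]],
}
inits = draw_to_inits(toy_draws)  # uses the last draw by default
```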

I would also hope there’d be some logic in the “Permutations for bbcc” section that you could lean on to make this more readable. Also, doing things like putting lp[3] and lp[4] in order would improve readability.

I do not think you’re going to be able to fix this with fancy footwork in the transformed parameters block. It also won’t make the underlying parameters converge, so it won’t solve the underlying problems you can run into in, say, systems that do cross-chain warmup.

P.S. In your Stan code, there’s a lot of expensive recomputation going on for things like log(permutation_factor) and log(yc). If you pull those out and assign them to local variables, you’ll find everything will be a lot faster.
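The readability and recomputation points above can be sketched together. The Python mirror of partial_sum below takes each log once per call and replaces the three hand-written permutation blocks with a single helper that enumerates which two of the four positions hold the first component. In Stan, the hoisting would be along the lines of declaring `vector[rows(yj)] log_yj = log(yj);` before the loop. The helper names and toy values are assumptions, and bins are 0-based here.

```python
import math
from itertools import combinations

LOG_SIXTH = math.log(1.0 / 6)  # computed once, not once per term

def log_sum_exp(xs):
    mx = max(xs)
    return mx + math.log(sum(math.exp(x - mx) for x in xs))

def log_pure(la, obs):
    """Four scores all drawn from one component (jjjj, bbbb, cccc)."""
    return sum(la[s] for s in obs)

def log_pair(la, lb, obs):
    """Two scores from component a and two from b, averaged over
    the 6 ways of choosing which positions hold component a."""
    terms = []
    for a_slots in combinations(range(4), 2):
        terms.append(LOG_SIXTH + sum(la[s] if i in a_slots else lb[s]
                                     for i, s in enumerate(obs)))
    return log_sum_exp(terms)

def partial_sum(score_slice, yj, yb, yc, theta):
    # Logs of the component vectors and of theta are taken once per call.
    lj = [math.log(v) for v in yj]
    lb = [math.log(v) for v in yb]
    lc = [math.log(v) for v in yc]
    log_theta = [math.log(t) for t in theta]
    total = 0.0
    for obs in score_slice:
        lp = [                      # same class order as the Stan code
            log_pure(lj, obs),      # jjjj
            log_pair(lj, lb, obs),  # bbjj
            log_pair(lj, lc, obs),  # jjcc
            log_pair(lb, lc, obs),  # bbcc
            log_pure(lc, obs),      # cccc
            log_pure(lb, obs),      # bbbb
        ]
        # Equivalent of log_mix(theta, lp).
        total += log_sum_exp([lt + l for lt, l in zip(log_theta, lp)])
    return total

# Toy inputs: 3 bins, one observation.
yj, yb, yc = [0.7, 0.2, 0.1], [0.1, 0.2, 0.7], [0.2, 0.6, 0.2]
theta = [0.3, 0.1, 0.1, 0.2, 0.2, 0.1]
obs = (0, 2, 1, 0)
example = partial_sum([obs], yj, yb, yc, theta)
```

By a rough count, the original loop evaluates log on a whole component vector 84 times per data point (24 in each of the three permutation blocks plus 4 in each of the three pure classes), against three times per slice here.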