Hello everyone,
I’m having issues with label switching in a model with 3 components and 6 classes that I want to infer. My dataset consists of N data points with 4 columns each, and every data point belongs to one of 6 classes: jjjj, bbbb, cccc, bbjj, ccjj, and bbcc. Here, j, b, and c are the individual components.
For inference, I’d like to enforce certain relationships across specific bins to mitigate label switching. For example, in bin 5 I’d like to ensure that p_j[5] > p_c[5] and p_j[5] > p_b[5], meaning that the j component in bin 5 should be greater than the other two components in that same bin. I want to impose similar relationships, alternating components across strategic bins where these relationships should be most evident.
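Concretely, the full set of inequalities I currently impose (bins 4, 5, 8, and 17, which are the ones hard-coded in the transformed parameters block of the model below) is:

bin 4:  p_j[4]  > p_c[4]
bin 5:  p_j[5]  > p_c[5]  and  p_j[5]  > p_b[5]
bin 8:  p_c[8]  > p_b[8]  and  p_c[8]  > p_j[8]
bin 17: p_b[17] > p_c[17]  and  p_b[17] > p_j[17]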
One approach I tried was to use ordered[2] vectors in each bin to enforce these conditions. Below is my Stan code for inferring the 6 classes from the 3 individual components via unimodal curves. It is a bit complex, but the key thing I want to discuss is how to differentiate the individual components p_j, p_c, and p_b.
my_mixture_4d = """
functions {
real partial_sum(array[,] int score_slice,
int start, int end,
vector yj,
vector yb,
vector yc,
vector theta)
{
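// Each mixed class (two distinct pairs, e.g. bbjj) can occur in
// 4!/(2!*2!) = 6 equally likely orderings of the four columns,
// hence the 1/6 weight applied to each permutation term below.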
real permutation_factor = 1.0/6;
real partial_target = 0;
vector[6] lp; // Six distinct classes: jjjj, bbjj (+perm), bbbb, bbcc (+perm), jjcc (+perm), and cccc.
vector[6] lp2; // To handle permutations for bbjj
vector[6] lp3; // To handle permutations for bbcc
vector[6] lp4; // To handle permutations for jjcc.
int slice_length = end - start + 1;
//jjjj, bbjj, ccjj, ccbb, cccc, bbbb is the order of the classes
for (k in 1:slice_length) {
// Permutations for bbjj
lp2[1] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp2[2] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp2[3] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
lp2[4] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp2[5] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
lp2[6] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
// Assign probabilities for the six classes
lp[1] = log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]]; // jjjj
lp[2] = log_sum_exp(lp2); // bbjj + permutations
lp[6] = log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]]; // bbbb
// Permutations for bbcc
lp3[1] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp3[2] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp3[3] = log(permutation_factor) + log(yb)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp3[4] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp3[5] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yb)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp3[6] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yb)[score_slice[k,3]] + log(yb)[score_slice[k,4]];
lp[4] = log_sum_exp(lp3); // bbcc + permutations
// Permutations for jjcc
lp4[1] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp4[2] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp4[3] = log(permutation_factor) + log(yj)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
lp4[4] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yc)[score_slice[k,4]];
lp4[5] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yj)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
lp4[6] = log(permutation_factor) + log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yj)[score_slice[k,3]] + log(yj)[score_slice[k,4]];
lp[3] = log_sum_exp(lp4); // jjcc + permutations
lp[5] = log(yc)[score_slice[k,1]] + log(yc)[score_slice[k,2]] + log(yc)[score_slice[k,3]] + log(yc)[score_slice[k,4]]; // cccc
partial_target += log_mix(theta, lp);
}
return partial_target;
}
}
data {
int<lower=1> m; // steps in the discretization
int<lower=1> N; // data points
array[N,4] int<lower=1, upper=m> score; // b-tagging scores (bin indices) for the 4 jets
vector[m] prior_w_j;
vector[m] prior_w_c;
vector[m] prior_w_b;
vector[6] prior_theta;
}
parameters {
simplex[6] theta; // Mixture coefficients of 6 classes as a simplex
simplex[m] w_j_mode; // Dirichlet weights for j
vector<lower=0>[m-1] a_j; // Normal parameters for j
simplex[m] w_b_mode; // Dirichlet weights for b
vector<lower=0>[m-1] a_b; // Normal parameters for b
simplex[m] w_c_mode; // Dirichlet weights for component c
vector<lower=0>[m-1] a_c; // Normal parameters for component c
}
transformed parameters {
ordered[2] y_label_switch_1; // this is what I would like to change/improve
ordered[2] y_label_switch_2;
ordered[2] y_label_switch_3;
ordered[2] y_label_switch_4;
ordered[2] y_label_switch_5;
ordered[2] y_label_switch_6;
ordered[2] y_label_switch_7;
vector[m] p_j; // Approximate distribution for j
vector[m-1] sign_j; // Signs for combination
vector[m] logp_j; // Log-probabilities
sign_j = rep_vector(-1.0, m-1);
p_j = rep_vector(0.0, m);
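// Unimodal construction for component j: for each candidate mode k, the
// log-probabilities rise (steps +a_j) up to bin k and fall (steps -a_j)
// afterwards; the resulting softmax curves are mixed with the Dirichlet
// weights w_j_mode. The same construction is repeated for b and c below.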
for (k in 1:m) {
if (k > 1) {
for (j in 1:(k-1)) {
sign_j[j] = 1.0;
}
}
logp_j[1] = 0.0;
for (j in 2:m) {
logp_j[j] = logp_j[j-1] + sign_j[j-1] * a_j[j-1];
}
p_j += softmax(logp_j) * w_j_mode[k];
}
vector[m] p_b; // Approximate distribution for b
vector[m-1] sign_b; // Signs for combination
vector[m] logp_b; // Log-probabilities
sign_b = rep_vector(-1.0, m-1);
p_b = rep_vector(0.0, m);
for (k in 1:m) {
if (k > 1) {
for (j in 1:(k-1)) {
sign_b[j] = 1.0;
}
}
logp_b[1] = 0.0;
for (j in 2:m) {
logp_b[j] = logp_b[j-1] + sign_b[j-1] * a_b[j-1];
}
p_b += softmax(logp_b) * w_b_mode[k];
}
vector[m] p_c; // Approximate distribution for c
vector[m-1] sign_c; // Signs for combination
vector[m] logp_c; // Log-probabilities
sign_c = rep_vector(-1.0, m-1);
p_c = rep_vector(0.0, m);
for (k in 1:m) {
if (k > 1) {
for (j in 1:(k-1)) {
sign_c[j] = 1.0;
}
}
logp_c[1] = 0.0;
for (j in 2:m) {
logp_c[j] = logp_c[j-1] + sign_c[j-1] * a_c[j-1];
}
p_c += softmax(logp_c) * w_c_mode[k];
}
y_label_switch_1[1] = p_c[5];
y_label_switch_1[2] = p_j[5];
y_label_switch_2[1] = p_c[4];
y_label_switch_2[2] = p_j[4];
y_label_switch_3[1] = p_b[5];
y_label_switch_3[2] = p_j[5];
y_label_switch_4[1] = p_b[8];
y_label_switch_4[2] = p_c[8];
y_label_switch_5[1] = p_j[8];
y_label_switch_5[2] = p_c[8];
y_label_switch_6[1] = p_c[17];
y_label_switch_6[2] = p_b[17];
y_label_switch_7[1] = p_j[17];
y_label_switch_7[2] = p_b[17];
}
model {
int grainsize = 1; // parallel performance
// Priors
theta ~ dirichlet(prior_theta);
a_j ~ normal(0, 0.5);
a_b ~ normal(0, 0.5);
a_c ~ normal(0, 0.5);
w_j_mode ~ dirichlet(prior_w_j);
w_b_mode ~ dirichlet(prior_w_b);
w_c_mode ~ dirichlet(prior_w_c);
target += reduce_sum(partial_sum, score, grainsize, p_j, p_b, p_c, theta);
}
"""
This method essentially involves setting up two ordered[2] vectors per constrained bin to enforce the relationships I want. It has been relatively effective, but I’m pretty sure it could be greatly improved.
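For what it’s worth, the same seven constraints could probably be written more compactly along the lines of the untested sketch below (the p_comp array just stacks the three component vectors, and the constr_bin / constr_lo / constr_hi index arrays are hypothetical and would be passed in as data):

transformed parameters {
  // ... p_j, p_b, p_c built exactly as above ...
  array[3] vector[m] p_comp = {p_j, p_b, p_c};   // 1 = j, 2 = b, 3 = c
  array[7] ordered[2] y_label_switch;            // one ordered pair per inequality
  for (i in 1:7) {
    // e.g. constr_bin = {5, 4, 5, 8, 8, 17, 17}; constr_lo[1] = 3 (c), constr_hi[1] = 1 (j)
    y_label_switch[i][1] = p_comp[constr_lo[i]][constr_bin[i]];   // component that should be smaller in that bin
    y_label_switch[i][2] = p_comp[constr_hi[i]][constr_bin[i]];   // component that should be larger in that bin
  }
}

It is the same idea, just with the constraint bookkeeping in one place.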
I also considered using a different approach, such as combining components in an ordered[2] vector like this:
y_label_switch_1[1] = p_b[5] + p_c[5]; // p_b + p_c < p_j in bin 5
y_label_switch_1[2] = p_j[5];
y_label_switch_2[1] = p_j[10] + p_b[10]; //etc
y_label_switch_2[2] = p_c[10];
// etc.
However, this approach didn’t work well either, as these summed conditions don’t always hold exactly.
Would anyone have suggestions on how to effectively enforce these relationships to help with label switching in my model?
Thanks so much for any insights!