Artifacts of sign-fixing in models with sign ambiguity


#1

When working with a model that has a sign ambiguity, a common recommendation is to fix one of the signs to identify the model. In a (rank 1) factor model Y = b w^T, for instance, we can switch the signs of both b and w simultaneously without affecting Y, but we cannot switch them if b_1 is constrained to be positive. With K factors, this could be implemented as

parameters {
    row_vector<lower=0>[K] b1;
    matrix[S-1,K] B_tilde;
}
transformed parameters {
    matrix[S,K] B = append_row(b1, B_tilde);
}
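To make the ambiguity concrete, here is a small Python sketch (plain lists, arbitrary toy numbers; the helper names are hypothetical). It shows that flipping the signs of both b and w leaves Y = b w^T unchanged, and that a positivity constraint on the first element of b is exactly what rules out the flipped copy.

```python
def outer(b, w):
    """Rank-1 matrix Y = b w^T as a list of rows."""
    return [[bi * wj for wj in w] for bi in b]


def flip(v):
    """Negate every element of a vector."""
    return [-x for x in v]


# Toy factor b (length S=4) and loadings w (length N=3); values are arbitrary.
b = [0.8, -0.3, 0.5, 0.1]
w = [1.2, -0.7, 0.4]

Y = outer(b, w)
Y_flipped = outer(flip(b), flip(w))

# The data cannot distinguish (b, w) from (-b, -w) ...
same_fit = Y == Y_flipped
# ... but only one of the two satisfies the constraint b[0] > 0.
admissible = [v[0] > 0 for v in (b, flip(b))]
print(same_fit, admissible)  # True [True, False]
```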

Running a model with this constraint, I get two dissimilar chains; the per-element means of b (length 30) for the two chains look like this:
[Plot: per-element chain means of b for the two chains, unflipped]
Most of you probably share my sneaking suspicion that these two chains are not actually that different, and indeed, if we switch the signs of one of the two chains, the two chain means become very similar:
[Plot: per-element chain means of b for the two chains, after flipping the signs of one chain]
The only real difference is the first element, the one carrying the positivity constraint. It appears that the remaining elements of the two chains were initialized near different ones of the two sign-flipped modes. For the chain with the “wrong” signs, the first element would have to be negative to match the other elements; instead, the positivity constraint pushed it to around 0, because the fit is dominated by the remaining, sign-flipped elements.
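A noise-free least-squares caricature shows why the constrained element collapses to 0. This is an assumption-laden sketch: it ignores the prior and the noise, and it holds the loadings w fixed, imagining that the other 29 rows have already pulled w into the flipped mode. Row 1 of Y = b w^T depends only on b_1 and w, so when w is flipped the unconstrained optimum for b_1 is negative, and a lower bound at 0 clamps it to the boundary. The function name `best_b1` and all numbers are hypothetical.

```python
def best_b1(y_row, w, lower=None):
    """Least-squares b_1 for the first row of Y = b w^T, holding w fixed.

    If `lower` is given, the optimum is clamped to b_1 >= lower, which is what a
    lower-bound constraint does to a one-dimensional convex quadratic objective.
    """
    num = sum(yj * wj for yj, wj in zip(y_row, w))
    den = sum(wj * wj for wj in w)
    b1 = num / den
    return b1 if lower is None else max(lower, b1)


b1_true = 0.9               # hypothetical "true" first element, positive
w_true = [1.0, -0.5, 2.0]   # hypothetical loadings
y_row0 = [b1_true * wj for wj in w_true]  # first row of a noise-free Y

# Chain A: loadings in the "right" mode, so b_1 recovers its true value.
print(best_b1(y_row0, w_true, lower=0.0))

# Chain B: loadings stuck in the sign-flipped mode. The unconstrained optimum
# is -b_1, so the constraint pins b_1 at the boundary value 0.
w_flipped = [-wj for wj in w_true]
print(best_b1(y_row0, w_flipped), best_b1(y_row0, w_flipped, lower=0.0))
```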

When its sign agrees with the other elements, the first element has a large mean, so it doesn’t seem like I could avoid the issue simply by putting the constraint on another element. In other words, this seems to undermine the whole idea of fixing signs to combat sign ambiguity, unless the model can be made to mix well enough to escape the degenerate mode.

Is there another more robust way to avoid sign ambiguity?


#2

The typical approach is to use the exponential function, which maps monotonically from unconstrained values to positive values.

In Stan, you can also just constrain a variable to be positive, which will do this all for you under the hood and apply the appropriate Jacobian so that you can still put a distribution directly on the constrained value. (Otherwise, you need to put a distribution on the log value before exponentiating.)
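As a sketch of what the lower bound does under the hood: Stan's documented transform for a lower bound of 0 is b = exp(u), with sampling done on the unconstrained u, and the log density on u picks up the log-Jacobian log|db/du| = u. The check below, assuming a half-normal(0, 1) prior on b purely for illustration, integrates the resulting density over u numerically: with the Jacobian term it integrates to about 1, without it the result is clearly not a proper density.

```python
import math


def half_normal_logpdf(b, sigma=1.0):
    """Log density of a half-normal(0, sigma) distribution on b > 0."""
    return (math.log(2.0) - 0.5 * math.log(2.0 * math.pi)
            - math.log(sigma) - 0.5 * (b / sigma) ** 2)


def log_density_unconstrained(u):
    """Log density of u = log(b): density on b plus the log-Jacobian log|d exp(u)/du| = u."""
    return half_normal_logpdf(math.exp(u)) + u


# Left Riemann sum over a wide grid of u (from -20 to about 5).
du = 1e-3
grid = [-20.0 + i * du for i in range(25_000)]

total = sum(math.exp(log_density_unconstrained(u)) for u in grid) * du
no_jac = sum(math.exp(half_normal_logpdf(math.exp(u))) for u in grid) * du
print(total, no_jac)  # total is close to 1; no_jac is not
```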

You can’t expect Markov chains to look similar across runs.


#3

The plotted functions are the per-element sample means of b, not traces. If things mix, these means should be consistent across chains. The single discrepancy I describe above causes a host of R-hat warnings.
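To illustrate why a sign-flipped chain trips R-hat, here is a Python sketch using the classic Gelman-Rubin R-hat. Note that this is a simplified stand-in: Stan's current diagnostic is the rank-normalized split-R-hat. The draws are simulated toy values, not output from the model in this thread.

```python
import random


def rhat(chains):
    """Classic Gelman-Rubin potential scale reduction factor for one scalar.

    `chains` is a list of equal-length lists of draws. This is the basic
    (non-split, non-rank-normalized) version, used here only for illustration.
    """
    m, n = len(chains), len(chains[0])
    means = [sum(c) / n for c in chains]
    grand = sum(means) / m
    between = n / (m - 1) * sum((mu - grand) ** 2 for mu in means)
    within = sum(
        sum((x - mu) ** 2 for x in c) / (n - 1) for c, mu in zip(chains, means)
    ) / m
    var_plus = (n - 1) / n * within + between / n
    return (var_plus / within) ** 0.5


rng = random.Random(1)
# Two simulated chains for one element of b, each stuck in one sign mode.
chain_a = [rng.gauss(1.0, 0.1) for _ in range(1000)]
chain_b = [rng.gauss(-1.0, 0.1) for _ in range(1000)]

print(rhat([chain_a, chain_b]))                # far above 1
print(rhat([chain_a, [-x for x in chain_b]]))  # close to 1 after flipping chain B
```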

It’s only the first element that is constrained to be positive (using truncation), hoping that the remaining signs reorient themselves with respect to that first sign. This just doesn’t happen, as demonstrated.

Breaking it down: my original problem is bimodal, with mode 2 being the reflection of mode 1 that you get by switching all signs. I try to implement a constraint that removes the symmetric mode, but since I can only impose one truncation constraint, the symmetric mode does not disappear entirely. What remains is a weak mode that fails to mix with the “true” mode. Even if it did mix, the marginals would all be centered at zero, making them hard to interpret.


#4

Is b subject to a sum-to-zero constraint here, or are my eyes deceiving me? I’d be interested in seeing the full Stan model.

Is there some motivation to prefer one mode over the other?


#5

I normalize each factor (columns of B) to have unit length, but there are no additional constraints directly on b1.

The full model has a significant amount of bells and whistles and transformed priors, so I am pretty sure the small snippet above is more informative, but here it is:

functions { 
   //sequential ordinal likelihood from brms
   vector sratio_probit_vec(real mu, vector thres, real disc) {
     int ncat = num_elements(thres) + 1; 
     vector[ncat] p; 
     vector[ncat - 1] q; 
     for (k in 1:(ncat - 1)) { 
       q[k] = 1 - Phi(disc * (thres[k] - mu)); 
       p[k] = 1 - q[k]; 
       for (kk in 1:(k - 1)) p[k] = p[k] * q[kk]; 
     } 
     p[ncat] = prod(q); 
     return p;
   }

   /* sratio-probit log-PDF for a single response 
   * Args: 
   *   y: response category 
   *   mu: linear predictor 
   *   thres: ordinal thresholds 
   *   disc: discrimination parameter 
   * Returns: 
   *   a scalar to be added to the log posterior 
   */ 
   real sratio_probit_lpmf(int y, real mu, vector thres, real disc) { 
     int ncat = num_elements(thres) + 1; 
     vector[ncat] p = sratio_probit_vec(mu, thres, disc);
     return categorical_lpmf(y | p); 
   }
   
   real sratio_probit_rng(real mu, vector thres, real disc) {
     int ncat = num_elements(thres) + 1;
     vector[ncat] p = sratio_probit_vec(mu, thres, disc);
     return categorical_rng(p);
   }

   int sratio_probit_max(real mu, vector thres, real disc) {
     int ncat = num_elements(thres) + 1; 
     int catmax = 0;
     real pmax = 0;
     vector[ncat] p = sratio_probit_vec(mu, thres, disc);
     for (i in 1:ncat) {
         if (pmax < p[i]) {
             pmax =  p[i];
             catmax = i;
         }
     }
     return catmax; 
   }
} 

data {
    int<lower=1> D;
    int<lower=1> N;
    int<lower=1> S;
    int<lower=1> K;
    int testmin;
    int testmax;
    array[N] int<lower=1,upper=D> group;
    array[S,N] int<lower=testmin, upper=testmax> panss;
    real<lower=0> disc;  
    int<lower=0,upper=1> include_likelihood;
    real<lower=0> alpha;
}

transformed data {
    int ncat = testmax - testmin + 1;
    int factor_elements = S;
    int load_elements = K;
}

parameters {
    row_vector<lower=0>[K] b1;
    matrix[S-1,K] B_tilde;
    matrix[K,N] W_tilde;
    array[D] vector[testmax-1] thres;
    row_vector[N] person_bias;
    vector[S] symptom_bias;
    positive_ordered[K] sigma_tilde;
    real<lower=0> scale;
    row_vector<lower=0>[N] load_lambda; 
    matrix<lower=0>[K,N] local_load_lambda;
    real<lower=0,upper=1> split;

}

transformed parameters {
    matrix[K,N] W; //loadings/weights
    matrix[S,K] B = append_row(b1, B_tilde); //factors
    vector[K] sigma = sigma_tilde/sum(sigma_tilde); //factor strength
    matrix[S,N] F; //latent matrix
    matrix[S,N] F_bias; //bias

    matrix<lower=0,upper=1>[K,N] load_shrinkage; //sparsity weight
    row_vector<lower=0, upper=load_elements>[N] load_total;
    row_vector[N] load_jac_vec;
    
    array[D] vector[testmax] p;
    for (d in 1:D) {
        p[d] = sratio_probit_vec(0., thres[d], disc);
    }
    { //unit length factor
        vector[K] inv_blength;
        for (k in 1:K) inv_blength[k] = 1./(sqrt(sum(square(B[,k]))));
        B = diag_post_multiply(B, inv_blength);
    }
    {//Loadings
    matrix[K,N] lambda_prod;
    matrix[K,N] lambda_tilde;
    

    //regularized horseshoe
    real tau = 1.;
    real c2 = 1.;
    lambda_prod = local_load_lambda .* rep_matrix(load_lambda, K);
    lambda_tilde = sqrt( c2 * square(lambda_prod) ./ (c2 + square(tau) * square(lambda_prod) ));
    load_shrinkage = 1. - 1. ./ (1. + square(tau) * square(lambda_prod));
    load_total = rep_row_vector(1,K) * load_shrinkage;
    load_jac_vec = rep_row_vector(1,K) * (2. * load_shrinkage .* (1. - load_shrinkage));
    load_jac_vec = log(load_jac_vec) - log(load_lambda);
    W = tau * lambda_tilde .* W_tilde;
    }

    //bias terms
    F_bias = (rep_matrix(symptom_bias, N) + rep_matrix(person_bias, S));
    
    //latent matrix
    F = scale * ( split * B * diag_pre_multiply(sigma, W) + (1.-split) * F_bias );
    
}

model {
    b1 ~ normal(0,1);
    to_vector(B_tilde) ~ normal(0.,1.);
    to_vector(W_tilde) ~ normal(0.,1.);
    
    split ~ beta(1.,1.);
    sigma_tilde ~ gamma(alpha, 1.);
    person_bias ~ normal(0, 1.);
    symptom_bias ~ normal(0., 1.);
    
    //special sparsity prior
    to_vector(local_load_lambda) ~ cauchy(0., 1.);
    load_total ~ normal(0., load_elements/3.);
    target += sum((load_jac_vec)); 
    
    scale ~ normal(0., 1.);
    
    //special threshold prior
    for (d in 1:D) { 
        p[d] ~ dirichlet(rep_vector(1., ncat));
        target += sum(log(p[d,1:(ncat-1)])) + normal_lpdf(thres[d] | 0., 1./disc) - normal_lcdf(disc * thres[d] | 0., 1.);
    } 

    //ordinal likelihood
    if (include_likelihood) {
        for (n in 1:N) {
            for (s in 1:S) {
                target += sratio_probit_lpmf(panss[s,n] | F[s,n], thres[group[n]], disc);
            }
        }
    }
}

generated quantities {
    matrix[S,N] mode;
    matrix[S,N] ppc;
    for (n in 1:N) {
        for (s in 1:S) {
            mode[s,n] = sratio_probit_max(F[s,n], thres[group[n]], disc);
            ppc[s,n] = sratio_probit_rng(F[s,n], thres[group[n]], disc);
        }
    }
}
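The stopping-ratio probabilities are easy to sanity-check outside Stan. Below is a minimal Python mirror of `sratio_probit_vec` and `sratio_probit_max` from the model above; the thresholds and discrimination value are made up for illustration. It confirms that the category probabilities telescope to one, which the `categorical_lpmf` call relies on, and that the most probable category moves from the first to the last as mu increases.

```python
import math


def std_normal_cdf(x):
    """Standard normal CDF Phi(x), computed via the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))


def sratio_probit_vec(mu, thres, disc):
    """Python mirror of the Stan function of the same name.

    q[k] is the probability of continuing past category k; p[k] is the
    probability of stopping at k after passing all earlier categories, and
    the last category takes whatever mass remains.
    """
    ncat = len(thres) + 1
    q = [1.0 - std_normal_cdf(disc * (t - mu)) for t in thres]
    p = []
    for k in range(ncat - 1):
        pk = 1.0 - q[k]
        for kk in range(k):
            pk *= q[kk]
        p.append(pk)
    p.append(math.prod(q))
    return p


def sratio_probit_max(mu, thres, disc):
    """1-based index of the first most probable category, as in the Stan helper."""
    p = sratio_probit_vec(mu, thres, disc)
    return max(range(len(p)), key=lambda i: p[i]) + 1


# Hypothetical thresholds and discrimination, for illustration only.
thres = [-1.0, 0.0, 0.5, 1.5]
p = sratio_probit_vec(0.3, thres, 1.0)
print([round(x, 3) for x in p], sum(p))
```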

#6

Not for prediction. But if you want to quantify the uncertainty in each factor, letting the sampler jump between the two modes makes the posterior hard to analyze, because the two posterior modes blend together. To disentangle them, you would have to cluster the draws post hoc.
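For a rank-1 model, one simple form of that post hoc step is to align each draw to a reference by the sign of an inner product, flipping b and w together so that Y is unchanged. This is a sketch under my own assumptions, not a method proposed in the thread. With K > 1 factors it would have to be applied per column, and it does nothing about other rotational ambiguity. The function names and toy draws are hypothetical.

```python
def dot(u, v):
    """Inner product of two equal-length vectors."""
    return sum(x * y for x, y in zip(u, v))


def align_signs(draws_b, draws_w, reference):
    """Flip (b, w) jointly in every draw whose b points away from `reference`.

    Because Y = b w^T is unchanged by the joint flip, this only relabels draws
    between the two reflected modes; it does not change the fit.
    """
    aligned_b, aligned_w = [], []
    for b, w in zip(draws_b, draws_w):
        s = 1.0 if dot(b, reference) >= 0 else -1.0
        aligned_b.append([s * x for x in b])
        aligned_w.append([s * x for x in w])
    return aligned_b, aligned_w


# Toy draws: the second one sits in the reflected mode.
draws_b = [[0.9, 0.4, -0.6], [-0.88, -0.41, 0.62]]
draws_w = [[1.0, -0.5], [-1.02, 0.49]]
ab, aw = align_signs(draws_b, draws_w, reference=draws_b[0])
print(ab[1], aw[1])
```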


#7

It depends on what you want to do with the inferences. There’s some discussion of this in the manual. If you only want to make posterior predictions, then you usually don’t need to know things like the mixture component.


#8

I believe I have been through most of what’s available in the manual; I am basically trying to follow the guidelines laid out in the manual, Farouni’s factor model walkthrough, and Betancourt’s case studies on identifiability. I agree that it is not relevant for predictions, but I am interested in the factors themselves: I am trying to subtype diseases using factor models, under the assumption that the diagnosis falls on a spectrum, so mixtures/clustering are inappropriate. So I need the factors to be fully identifiable. I would be willing to go pretty far in changing the likelihood and priors to achieve this, as long as the factor structure remains.


#9

From the model it looks like you are normalizing Sigma to sum to zero, but not B?

If you want B to have unit length, you could compose it from unit_vectors, but then you’d need to make some modifications to keep the sign on the first element positive.

Also, as Ben says about this parameterization of the ordered simplex:

and some more discussion on this sort of thing in this thread:

Something to look into, though I’m not sure that is the root cause of the identifiability issue.
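For reference, the prior implied by the sigma construction in the model above (a positive_ordered sigma_tilde with iid gamma(alpha, 1) priors, then divided by its sum) can be simulated directly. Normalizing iid Gamma(alpha, 1) draws gives a symmetric Dirichlet(alpha) draw, and the ordering constraint corresponds to sorting, so, if I am reading the construction correctly, sigma a priori is a sorted symmetric Dirichlet. The function name and the values K=4, alpha=2 are hypothetical.

```python
import random


def ordered_simplex_draw(K, alpha, rng):
    """Mimic sigma = sigma_tilde / sum(sigma_tilde) with positive_ordered sigma_tilde.

    Draw K iid Gamma(alpha, 1) values (shape alpha, scale 1, which equals rate 1),
    sort them to match the positive_ordered constraint, and normalize.
    """
    g = sorted(rng.gammavariate(alpha, 1.0) for _ in range(K))
    total = sum(g)
    return [x / total for x in g]


rng = random.Random(0)
sigma = ordered_simplex_draw(K=4, alpha=2.0, rng=rng)
print(sigma)  # positive, ascending, sums to 1
```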


#10

sigma should sum to 1 and be positive (i.e., lie on the simplex). I haven’t had any noticeable problems with the ordered gamma parameterization so far. The example I gave above was for a rank 1 factor model, so it doesn’t even come into play there. B does have unit norm; it’s just handled in the transformed parameters block as:

    { //unit length factor
        vector[K] inv_blength;
        for (k in 1:K) inv_blength[k] = 1./(sqrt(sum(square(B[,k]))));
        B = diag_post_multiply(B, inv_blength);
    }

which I think ought to work fine; it just makes the likelihood a bit more complex. It seems to be the same transform the unit_vector type uses, although there might be a smarter way to code it. The implicit prior on each column of B should then be uniform on the half of the unit sphere where the first element is positive (a normalized standard Gaussian vector is uniform on the sphere, I believe).
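A quick Monte Carlo check of that last point, as a Python sketch of one column of B: a half-normal first element (what `<lower=0>` plus `normal(0,1)` gives) and standard-normal remaining elements, then scaled to unit length. S=3 is chosen only because the answer then has a closed form: for a uniform distribution on the upper hemisphere of the unit sphere in 3D, the first coordinate is uniform on [0, 1] (Archimedes' hat-box theorem), so its mean should be about 1/2.

```python
import math
import random


def hemisphere_direction(S, rng):
    """One column of B: half-normal first element, standard-normal rest, unit length."""
    v = [abs(rng.gauss(0.0, 1.0))] + [rng.gauss(0.0, 1.0) for _ in range(S - 1)]
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


rng = random.Random(42)
draws = [hemisphere_direction(3, rng) for _ in range(20000)]

# Under uniformity on the upper hemisphere of S^2, the first coordinate is
# uniform on [0, 1], so its Monte Carlo mean should be close to 0.5.
mean_first = sum(d[0] for d in draws) / len(draws)
print(round(mean_first, 3))
```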

Thanks for the links!