Speedup of independent factor analysis model

I have tried implementing a so-called independent factor analysis model, which is a standard factor model (K factors) with an independent univariate mixture (C components) on the weights of each factor. I have not run into anything catastrophic yet, but it is quite slow.
Both mixtures and factor models are … problematic in the Bayesian setting, so for now I just want to do a sanity check by cross-validating the MAP solution on some heldout data.

The speed issues are likely just a consequence of applying a multi-normal likelihood to vectors with lengths in the 200s, but is there anything obvious I could do to speed this up further? Perhaps some pre-computation or vectorization to simplify the multi-normal likelihood? The test block is also made costlier by having to marginalize over a product of mixtures, which generates a combinatorial number of components; should I maybe switch to computing that via sampling instead?

Note: there are some includes in the definition, but they are mostly just a lot of array manipulation needed to form the features matrix, which is a combination of parameters (missing values) and data.

functions {
    int[] num2baseB(int num, int bits, int B) {
        /* convert `num` to its base-B representation using `bits` digits,
           most significant digit first */
        int base[bits];
        int factor = num;
        for (brev in 1:bits) {
            int b = bits + 1 - brev; // fill digits from the least significant end
            base[b] = factor % B;    // current digit
            factor = factor / B;     // integer division moves to the next digit
        }
        return base;
    }
}

data {
#include data_template.stan
    int include_prior;
    int include_likelihood;
    int K; // number of factors
    int C; // number of clusters per factor
}


transformed data {
    int P = P_cog + P_eeg + P_mri + P_dti;
    int combinations = 1;
    int holdoutsize = 0;
    for (k in 1:K) {
        combinations *= C; //count codes of length K with C symbols
    }
    for (n in 1:N) {
        holdoutsize += heldout[n]; //increment if true
    }

}

parameters {
#include parameter_template.stan
    matrix[K,C] mu;
    matrix<lower=0>[K,C] sigma;
    simplex[C] weight[K];
    
    matrix[N,K] W;
    matrix[K, P] B;
    
    vector<lower=0>[P] noise_std;    
}

transformed parameters {
    matrix[N,P] features;
    real testllk = 0;
    
    {
#include transformed_parameters_declaration_template.stan
#include transformed_parameters_template.stan
    features = append_col(cog,append_col(eeg,append_col(mri,dti))); //concatenation of feature blocks
    }
    
    //calculate marginal probability of heldout rows, integrating over W 
    if (holdoutsize>0) {
        vector[combinations] lps[holdoutsize];
        for (index in 1:combinations) {
            int cluster_indices[K] = num2baseB(index - 1, K, C);
            row_vector[K] mu_index; 
            row_vector[K] sigma_index; 
            real log_weight_index = 0;
            vector[P] mu_combination;
            matrix[P,P] cov_combination;
            matrix[P,P] cov_chol_combination;
            int test_index = 1;
            for (k in 1:K) {
                mu_index[k] = mu[k, cluster_indices[k] + 1];
                sigma_index[k] = sigma[k, cluster_indices[k] + 1];
                log_weight_index += log(weight[k, cluster_indices[k] + 1]);
            }
            mu_combination = to_vector(mu_index * B);
            cov_combination = crossprod(diag_pre_multiply(sigma_index, B)) + diag_matrix(square(noise_std)); // noise variance (not std) on the diagonal
            cov_chol_combination = cholesky_decompose(cov_combination);
            
            for (n in 1:N) {
                if (heldout[n]) {
                    lps[test_index, index] = log_weight_index + multi_normal_cholesky_lpdf(to_vector(features[n]) | mu_combination, cov_chol_combination);
                    test_index += 1;
                }
            }
        }
        //add all mixture likelihoods together
        for (ntest in 1:holdoutsize) {
            testllk += log_sum_exp(lps[ntest]);
        }      
    }
      
}


model { 
    if (include_prior) {
        noise_std ~ normal(0,1);
        to_vector(B) ~ normal(0,1);
        to_vector(mu) ~ normal(0,1);
        to_vector(sigma) ~ normal(0,1);
        for (k in 1:K) {
            weight[k] ~ dirichlet(rep_vector(1,C));
        }
    }
    //generate each column of W[,k] from univariate mixture with means mu[k,] and std-devs sigma[k,]  
    for (n in 1:N) {
        for (k in 1:K) {
            vector[C] lps;
            for (c in 1:C) {
                lps[c] = log(weight[k,c]) + normal_lpdf(W[n,k] | mu[k,c], sigma[k,c]);
            }
            target += log_sum_exp(lps);
        }
    }
    if (include_likelihood) {
        for (n in 1:N) {
            if (!heldout[n]) { 
                target += normal_lpdf(features[n] | W[n] * B, to_row_vector(noise_std));
            }
        }
    } 
}

edit: to be slightly more specific, the model is X = WB + E, where each element of W is drawn from a mixture specific to its column, and B either has no prior or a simple Gaussian prior. E is Gaussian noise.
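
Written out, integrating W out of a heldout row x_n (which is what the transformed parameters block computes) gives a mixture over all C^K cluster combinations c = (c_1, …, c_K); writing η for noise_std:

$$
p(x_n) \;=\; \sum_{c \in \{1,\dots,C\}^K} \Big(\prod_{k=1}^{K} w_{k, c_k}\Big)\, \mathcal{N}\!\Big(x_n \,\Big|\, B^\top \mu_c,\; B^\top \operatorname{diag}(\sigma_c^2)\, B + \operatorname{diag}(\eta^2)\Big),
$$

where $\mu_c = (\mu_{1,c_1}, \dots, \mu_{K,c_K})$ and $\sigma_c$ is defined analogously.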

edit2: okay, one obvious thing to do is of course not to do the testing in the transformed parameters block… Is there any way to run something only after optimizing, or do I have to use different code for testing in the sampling and MAP regimes?

Hi -

First you should check out the loo package for doing cross-validation. There’s a whole methodology for it in Bayesian analysis in which you only need to compute the pointwise log-likelihood in generated quantities (which is probably where you should move all of the transformed parameters stuff).
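
For this model that could look roughly like the sketch below (a sketch only, reusing the data and parameter names and the num2baseB marginalization from your program above; untested). log_lik then has one entry per heldout row, which is the format the loo package works with:

generated quantities {
    // pointwise heldout log-likelihood, one entry per heldout row; this is
    // just the marginalization from the transformed parameters block, moved
    // here so it is only computed once per saved draw and easy to extract
    vector[holdoutsize] log_lik;
    if (holdoutsize > 0) {
        vector[combinations] lps[holdoutsize];
        for (index in 1:combinations) {
            int cluster_indices[K] = num2baseB(index - 1, K, C);
            row_vector[K] mu_index;
            row_vector[K] sigma_index;
            real log_weight_index = 0;
            vector[P] mu_combination;
            matrix[P, P] cov_chol_combination;
            int test_index = 1;
            for (k in 1:K) {
                mu_index[k] = mu[k, cluster_indices[k] + 1];
                sigma_index[k] = sigma[k, cluster_indices[k] + 1];
                log_weight_index += log(weight[k, cluster_indices[k] + 1]);
            }
            mu_combination = to_vector(mu_index * B);
            // noise variance on the diagonal of the marginal covariance
            cov_chol_combination = cholesky_decompose(
                crossprod(diag_pre_multiply(sigma_index, B)) + diag_matrix(square(noise_std)));
            for (n in 1:N) {
                if (heldout[n]) {
                    lps[test_index, index] = log_weight_index
                        + multi_normal_cholesky_lpdf(to_vector(features[n]) | mu_combination, cov_chol_combination);
                    test_index += 1;
                }
            }
        }
        for (ntest in 1:holdoutsize)
            log_lik[ntest] = log_sum_exp(lps[ntest]);
    }
}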

To test this model, you should really be feeding it simulated data. Otherwise it is hard to know why it is or isn’t working (and it is also a way to keep the dataset small). The general advice is to start with a basic model fit to simulated data (such as a model without the weights), and then build up from there. Maybe also start with known weights and then make the weights a parameter. I wouldn’t be surprised if the Dirichlet prior is causing trouble, given how subtle its influence on the model will be.
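
For example, something along these lines (a standalone sketch to run with the Fixed_param algorithm; the unit-scale priors on the "true" values and the sizes N, K, C, P are just placeholders, not your actual setup) would draw one synthetic dataset from X = W B + E with mixture-distributed W:

data {
    int N;  // rows to simulate
    int K;  // factors
    int C;  // mixture components per factor
    int P;  // features
}
generated quantities {
    // "true" parameter values, drawn from simple placeholder distributions
    matrix[K, C] mu_true;
    matrix[K, C] sigma_true;
    vector[C] weight_true[K];
    matrix[N, K] W_true;
    matrix[K, P] B_true;
    vector[P] noise_std_true;
    matrix[N, P] features;  // the simulated data to feed back into the model

    for (k in 1:K) {
        weight_true[k] = dirichlet_rng(rep_vector(1, C));
        for (c in 1:C) {
            mu_true[k, c] = normal_rng(0, 1);
            sigma_true[k, c] = fabs(normal_rng(0, 1));
        }
        for (p in 1:P)
            B_true[k, p] = normal_rng(0, 1);
    }
    for (p in 1:P)
        noise_std_true[p] = fabs(normal_rng(0, 1));

    for (n in 1:N) {
        // each weight W[n, k] comes from factor k's univariate mixture
        for (k in 1:K) {
            int c = categorical_rng(weight_true[k]);
            W_true[n, k] = normal_rng(mu_true[k, c], sigma_true[k, c]);
        }
        // X = W B + E, with independent Gaussian noise per feature
        for (p in 1:P)
            features[n, p] = normal_rng(dot_product(W_true[n], col(B_true, p)), noise_std_true[p]);
    }
}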

I’m a bit intrigued by this model and its use of mixtures over the factor weights; what is the motivation for that? It would be interesting to see a comparison of this to standard factor analysis, such as in a Stan case study! :D

You are quite right that I should move the transformed parameters stuff to generated quantities or, as I have now done, into its own separate file so that I can run it once at the end of the optimization procedure. I still wonder whether there is anything that can be done to speed up the mixture likelihoods though.

Is IS-LOO applicable in this setting, given that I want the marginal likelihood with W_{test} integrated out? I could of course just use the conditional likelihood, but that seems less than ideal.

Why do you think the Dirichlet will be a problem? I think it is even uniform at the moment.

The IFA is a variation on ICA and other factor models with non-Gaussian densities on the factor weights: by modeling each weight with a mixture you can emulate different non-Gaussian densities, such as heavy-tailed or multimodal ones.
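
The implied marginal prior on a single weight is

$$
p(W_{nk}) \;=\; \sum_{c=1}^{C} w_{k,c}\, \mathcal{N}\!\left(W_{nk} \mid \mu_{k,c}, \sigma_{k,c}^2\right),
$$

so components sharing a mean but differing in scale give heavy tails, while well-separated means give multimodality.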

If the parameter for a Dirichlet is constant, the gamma functions don’t get computed. But it’s still more efficient to just drop it.
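
Concretely, with all concentrations equal to 1 the density reduces to a constant on the simplex,

$$
\operatorname{Dir}(w \mid 1, \dots, 1) \;=\; \frac{\Gamma(C)}{\prod_{c=1}^{C}\Gamma(1)} \prod_{c=1}^{C} w_c^{\,0} \;=\; (C-1)!,
$$

so the sampling statement only shifts the target by a constant.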

But all of that nested Cholesky factorization is expensive!

Tried a non-marginalized version. MAP is a lot faster, but now it can barely sample. I get E-BFMI warnings, very high Rhat, very low n_eff, saturated max_treedepth, the whole shebang.

WARNING:pystan:Rhat above 1.1 or below 0.9 indicates that the chains very likely have not mixed
WARNING:pystan:200 of 200 iterations saturated the maximum tree depth of 10 (100.0%)
WARNING:pystan:Run again with max_treedepth larger than 10 to avoid saturation
WARNING:pystan:Chain 1: E-BFMI = 0.012576852673384556
WARNING:pystan:Chain 2: E-BFMI = 0.019297702165597518
WARNING:pystan:E-BFMI below 0.2 indicates you may need to reparameterize your model

Now, there are some non-identifiable bits which will obviously have terrible Rhat (permutation of factors), but I also get bad Rhat for values that should be identifiable, like the reconstruction X.

The model is as follows, changed to be a bit more universal.

data {
    int include_prior;
    int N; //observed rows
    int K; // number of factors
    int C; // number of clusters
    int P; // number of features

    matrix[N,P] features;

    real<lower=0> alpha;
    real<lower=0> sigma_sigma_noise;
}

parameters {
    positive_ordered[C] sigma[K];
    simplex[C] weight[K];

    matrix[N,K] z;
    
    matrix[P,K] B;
    vector<lower=0>[P] sigma_noise;
}

transformed parameters {
    matrix[P,K] B_unit = diag_post_multiply(B, inv_sqrt(columns_dot_self(B))'); // columns of B rescaled to unit norm
    matrix[P,N] X = B_unit * z'; // reconstruction of the (transposed) data matrix
}   

model {
    if (include_prior) {
        sigma_noise ~ normal(0,sigma_sigma_noise);
        to_vector(B) ~ normal(0,1);

        for (k in 1:K) {
            sigma[k] ~ normal(0,1);
            weight[k] ~ dirichlet(rep_vector(alpha,C));
        }
    }

    for (n in 1:N) {
        for (k in 1:K) {
            vector[C] lps;
            for (c in 1:C) {
                lps[c] = normal_lpdf(z[n,k] | 0, sigma[k,c]);
            }
            target += log_sum_exp(log(weight[k]) + lps);
        }
        features[n] ~ normal(B_unit * z[n]', sigma_noise);
    } 
}