Mixture model running slow

Hi,

I am looking at time-dependent centroids of clusters in my data.

The data consists of many 2-D points measured over time (the dataset is quite large).

In addition, for each point I'm generating the probability of membership in its cluster over time.

However, running this model is enormously slow. Is there any way I can make it more efficient? I run the model from Python like this:

sm.vb(data=data, iter=5000, algorithm="meanfield", grad_samples=10, seed=42, verbose=True)

The Stan model:

data {
    int<lower=0> D;             // number of dimensions
    int<lower=0> N;             // number of samples per point in time
    int<lower=0> T;             // number of points in time
    matrix[T, N] X[D];          // observed data
    int<lower=0> K;             // number of clusters
    int<lower=0> miss_ixs[T,N]; // missing-data indicators (0 = observed)
}
parameters {
    simplex[K] theta[T];         // mixing proportions
    vector[D] mu[K,T];           // mixture component means
    corr_matrix[D] Omega[K,T];   // correlation matrices
    vector<lower=0>[D] tau[K,T]; // marginal scales

    vector<lower=0>[K] alpha;    // random-walk scale for mu
    vector<lower=0>[K] beta;     // random-walk scale for tau
}
model {
    mu[:,1] ~ normal(0, 10);
    tau[:,1] ~ cauchy(0, 10);
    Omega[:,1] ~ lkj_corr(5.0);

    alpha[:] ~ normal(0.02, 1);
    beta[:] ~ normal(0.05, 1);

    for (t in 2:T) {
        for (k in 1:K) {
            mu[k,t] ~ normal(mu[k,t-1], alpha[k]);
            // tau[k,t] ~ cauchy(0, 10);
            tau[k,t] ~ normal(tau[k,t-1], beta[k]);
            Omega[k,t] ~ lkj_corr(5.0);
        }
    }

    for (t in 1:T) {
        vector[K] log_theta = log(theta[t]);  // cache the log calculation

        for (n in 1:N) {
            if (miss_ixs[t,n] == 0) {
                vector[K] lps = log_theta;
                for (k in 1:K)
                    lps[k] += multi_normal_lpdf(to_vector(X[:,t,n]) | mu[k,t], quad_form_diag(Omega[k,t], tau[k,t]));
                target += log_sum_exp(lps);
            }
        }
    }
}
generated quantities {
    matrix[T, N] log_p_X; // log marginal density of each observation (log-sum-exp over clusters)

    for (t in 1:T) {
        vector[K] prob; // per-cluster log of (mixing proportion * component density)
        vector[K] log_theta = log(theta[t]);
        for (n in 1:N) {
            if (miss_ixs[t,n] == 0) {
                for (k in 1:K)
                    prob[k] = log_theta[k] + multi_normal_lpdf(to_vector(X[:,t,n]) | mu[k,t], quad_form_diag(Omega[k,t], tau[k,t]));
                log_p_X[t,n] = log_sum_exp(prob);
            }
        }
    }
}

Thanks in advance!

If you rewrite the bit

for (k in 1:K){
    mu[k,1] ~ normal(0, 10);
    tau[k,1] ~ cauchy(0, 10);
    Omega[k,1] ~ lkj_corr(5.0);   
}

as

mu[:,1] ~ normal(0, 10);
tau[:,1] ~ cauchy(0, 10);
Omega[:,1] ~ lkj_corr(5.0);   

you'll get a little bit of speedup from vectorization (see chapter 28 of the Stan manual). You can also do the same thing for the line Omega[k,t] ~ lkj_corr(5.0); in the other for loop.
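For example, the time loop in your posted model could keep the per-component statements for mu and tau but pull the Omega prior out of the inner loop (a sketch of that suggestion, not tested on your data):

for (t in 2:T) {
    for (k in 1:K) {
        mu[k,t] ~ normal(mu[k,t-1], alpha[k]);
        tau[k,t] ~ normal(tau[k,t-1], beta[k]);
    }
    Omega[:,t] ~ lkj_corr(5.0);  // one vectorized statement over the K components, as with Omega[:,1]
}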

Can you explain your data and problem a bit more? So at every point in time you have N 2D observations. Those observations come from a mixture of 2D Gaussians, and the locations of the modes in the mixture depend on where the modes were at the previous point in time. Is that all correct? How many modes are you using, and what are T and N?

Mixture models are usually pretty hard because of identifiability issues. The good thing about Stan is that it'll let you know if you're running into these identifiability issues. For 1D mixtures you can order the locations using an ordered vector in Stan, but I'm not sure what you can do in 2D. @betanalpha has a nice notebook here that you can take a look at; it talks more about these identifiability issues in mixture models.
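For reference, the 1D ordering trick looks roughly like this (a minimal standalone sketch with made-up names y, N, K, sigma, not your time-dependent 2D model):

data {
    int<lower=1> N;
    int<lower=1> K;
    vector[N] y;
}
parameters {
    ordered[K] mu;             // increasing locations break the label-switching symmetry
    vector<lower=0>[K] sigma;
    simplex[K] theta;
}
model {
    mu ~ normal(0, 10);
    sigma ~ cauchy(0, 5);
    for (n in 1:N) {
        vector[K] lps = log(theta);
        for (k in 1:K)
            lps[k] += normal_lpdf(y[n] | mu[k], sigma[k]);
        target += log_sum_exp(lps);
    }
}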

@arya Thanks for the feedback! I have modified the model in the post and added your suggestions.

The dataset is observations of temperature and salinity at different depths and at different locations over time. I have removed the location "dimension" in the data by concatenating, so that the data matrix consists of 2-D points (temperature, salinity) in matrix[time-dates, depths * locations]. I hope you get what I mean; otherwise I can try a different explanation.

On this data I wish to fit a Gaussian mixture model that is time dependent as well, by which I mean that the clusters will move over time, since each date in time brings an entire new set of data points. Getting the right number of clusters is a bit difficult.

If you have input on a different approach, I would be happy to hear it.