Hi,
I am looking at time-dependent centroids of clusters in my data.
The data consist of many 2-D points measured over time (the dataset is quite large).
In addition, for each point at each time step, I’m computing the probability that it belongs to its cluster.
However, running this model is enormously slow. Is there any way I can make the model more efficient? Running the model from Python:
sm.vb(data=data, iter=5000, algorithm="meanfield", grad_samples=10, seed=42, verbose=True)
The Stan model:
data {
// Sizes and observations for the time-varying Gaussian mixture.
// Lower bounds are 1 where the rest of the program cannot work with 0.
int<lower=1> D; // number of dimensions (a point needs at least one coordinate)
int<lower=0> N; // number of samples per point in time
int<lower=1> T; // number of points in time; the model block indexes time step 1, so T >= 1
matrix[T, N] X[D]; // observed data: X[d][t, n] is coordinate d of sample n at time t
int<lower=1> K; // number of clusters; the simplex[K] mixing proportions need K >= 1
int<lower=0> miss_ixs[T,N]; // missingness flags: sample (t, n) is used only where miss_ixs[t, n] == 0
}
parameters {
// Per-time-step mixture parameters plus the per-cluster step sizes of the
// random walks that link consecutive time steps (see the model block).
simplex[K] theta[T]; // mixing proportions: one simplex over the K clusters per time step
vector[D] mu[K,T]; // mixture component means, per cluster and time step
corr_matrix[D] Omega[K,T]; // correlation matrix of each component (not the covariance itself)
vector<lower=0>[D] tau[K,T]; // per-dimension scales; quad_form_diag(Omega, tau) gives the component covariance
vector<lower=0>[K] alpha; // per-cluster sd of the step from mu[k,t-1] to mu[k,t]
vector<lower=0>[K] beta; // per-cluster sd of the step from tau[k,t-1] to tau[k,t]
}
model {
  // Priors on the random-walk step sizes (alpha and beta are declared with
  // lower=0, so these act as normals restricted to the positive half-line).
  alpha ~ normal(0.02, 1);
  beta ~ normal(0.05, 1);

  // Priors for the first time step, written per cluster.
  for (k in 1:K) {
    mu[k,1] ~ normal(0, 10);
    tau[k,1] ~ cauchy(0, 10);
    Omega[k,1] ~ lkj_corr(5.0);
  }

  // Random-walk priors linking each time step to the previous one.
  for (t in 2:T) {
    for (k in 1:K) {
      mu[k,t] ~ normal(mu[k,t-1], alpha[k]);
      //tau[k,t] ~ cauchy(0, 10);
      tau[k,t] ~ normal(tau[k,t-1], beta[k]);
      Omega[k,t] ~ lkj_corr(5.0);
    }
  }

  // Mixture likelihood, marginalising the cluster label of every observed point.
  for (t in 1:T) {
    vector[K] log_theta = log(theta[t]); // cache log calculation
    matrix[D, D] L_Sigma[K]; // Cholesky factor of each component covariance at time t

    // The covariance depends only on (k, t), so factor it once per time step
    // instead of rebuilding and decomposing it for every one of the N points.
    // diag(tau) * chol(Omega) is the Cholesky factor of
    // quad_form_diag(Omega, tau) = diag(tau) * Omega * diag(tau).
    for (k in 1:K)
      L_Sigma[k] = diag_pre_multiply(tau[k,t], cholesky_decompose(Omega[k,t]));

    for (n in 1:N) {
      if (miss_ixs[t,n] == 0) { // skip entries flagged as missing
        vector[D] x_tn = to_vector(X[:,t,n]); // gather the D coordinates once, not once per cluster
        vector[K] lps = log_theta;
        for (k in 1:K)
          lps[k] += multi_normal_cholesky_lpdf(x_tn | mu[k,t], L_Sigma[k]);
        target += log_sum_exp(lps);
      }
    }
  }
}
generated quantities {
  // log_p_X[t, n] is the log mixture density of point (t, n): the log_sum_exp
  // over clusters of log(theta[t][k]) + log N(x | mu[k,t], Sigma[k,t]).
  // It is not the largest per-cluster membership log-probability; that would
  // be max(prob) - log_sum_exp(prob) using the `prob` vector computed below.
  // Entries flagged as missing (miss_ixs[t, n] != 0) are never assigned.
  matrix[T, N] log_p_X;
  for (t in 1:T) {
    vector[K] prob; // unnormalised log-probability of the current point under each cluster
    vector[K] log_theta = log(theta[t]);
    matrix[D, D] L_Sigma[K]; // Cholesky factor of each component covariance at time t

    // Factor each covariance once per (k, t) rather than once per point:
    // diag(tau) * chol(Omega) is the Cholesky factor of quad_form_diag(Omega, tau).
    for (k in 1:K)
      L_Sigma[k] = diag_pre_multiply(tau[k,t], cholesky_decompose(Omega[k,t]));

    for (n in 1:N) {
      if (miss_ixs[t,n] == 0) {
        vector[D] x_tn = to_vector(X[:,t,n]); // gather the D coordinates once per point
        for (k in 1:K)
          prob[k] = log_theta[k] + multi_normal_cholesky_lpdf(x_tn | mu[k,t], L_Sigma[k]);
        log_p_X[t,n] = log_sum_exp(prob);
      }
    }
  }
}
Thanks in advance!