I need to fit a mixture model to a (quite large) data set. I know Stan can be slower at mixtures than Gibbs-type samplers, and I also know there is currently no built-in way to vectorize a mixture likelihood in Stan.
To that end, I took matters into my own hands and wrote my own (somewhat) vectorized likelihood. To my surprise, it is actually slower than the traditional "loop over n and K" code.
I'll preface this by saying I'm a statistician, not a computer scientist, so I apologize if my logic is poor.
My code is below:
```stan
functions {
  // Scalar log density of a K-component normal mixture over all n observations
  real mixture_normal_lpdf(vector y, vector logp, vector mu, vector sigma) {
    real NEGHALFLOG2PI = -0.918938533204673;  // -0.5 * log(2 * pi)
    int n = size(y);
    int K = size(mu);
    matrix[n, K] logComp;
    real logProb = 0;
    // Loop over components; compute each component's log density for all obs. at once
    for (k in 1:K) {
      logComp[, k] = logp[k] + NEGHALFLOG2PI
                     - log(sigma[k])
                     - 0.5 * square((y - mu[k]) ./ sigma[k]);
    }
    // Loop over observations and apply the log-sum-exp trick row-wise
    for (i in 1:n)
      logProb += log_sum_exp(logComp[i, ]);
    // Return the log likelihood as a scalar
    return logProb;
  }
}
data {
  int<lower=1> n;                   // number of observations
  int<lower=1> K;                   // number of mixture components
  vector[n] y;                      // outcome variable
  int<lower=0, upper=1> vectorize;  // whether to use the custom (vectorized) likelihood
}
parameters {
  vector[K] mu;
  vector<lower=0>[K] sigma;
  simplex[K] p;
}
model {
  vector[K] logf;
  vector[K] logp = log(p);
  mu ~ std_normal();
  sigma ~ std_normal();
  // Likelihood
  if (vectorize) {
    y ~ mixture_normal(logp, mu, sigma);
  } else {
    for (i in 1:n) {
      for (k in 1:K)
        logf[k] = logp[k] + normal_lpdf(y[i] | mu[k], sigma[k]);
      target += log_sum_exp(logf);
    }
  }
}
```
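For reference, both branches target the same mixture log likelihood; they only organize the computation differently:

$$
\log L(\mu, \sigma, p) = \sum_{i=1}^{n} \log \sum_{k=1}^{K} p_k \, \mathcal{N}(y_i \mid \mu_k, \sigma_k),
$$

with the inner sum evaluated on the log scale via `log_sum_exp` for numerical stability.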
I would have bet that my code would be at least as fast as the traditional version for the following reasons:
- My code loops over n + K iterations in total, whereas the traditional approach loops over n * K.
- The traditional code makes n * K separate calls to the `normal_lpdf` function.
- It should be more efficient to have a single sampling statement.
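In case it helps, here is a minimal sketch of how the timing comparison can be reproduced (the filename, simulated-data settings, and sampler arguments are illustrative, not my exact setup):

```python
# Benchmark sketch: fit the same simulated data with vectorize = 0 and 1,
# assuming the Stan program above is saved as "mixture.stan".
import time

import numpy as np
from cmdstanpy import CmdStanModel

rng = np.random.default_rng(1)
n, K = 10_000, 3
mu_true = np.array([-2.0, 0.0, 2.0])
z = rng.integers(0, K, size=n)    # latent component labels
y = rng.normal(mu_true[z], 0.5)   # simulated mixture outcomes

model = CmdStanModel(stan_file="mixture.stan")
for vectorize in (0, 1):
    data = {"n": n, "K": K, "y": y, "vectorize": vectorize}
    start = time.perf_counter()
    model.sample(data=data, chains=4, seed=1)
    print(f"vectorize={vectorize}: {time.perf_counter() - start:.1f}s")
```

Wall-clock time is a crude measure since warmup adaptation can differ across runs, but it is enough to show the direction of the gap.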