I’ve been working with some models that contain mixture components, and I’m trying to see if the new(ish) log_mix interfaces can squeeze out a bit more performance over the usual log_sum_exp approach. So far the results are mixed, which is confusing me.
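For reference, the two signatures should be computing exactly the same quantity; the function reference defines the two-component log_mix as

log_mix(theta, lp1, lp2) = log(theta * exp(lp1) + (1 - theta) * exp(lp2))
                         = log_sum_exp(log(theta) + lp1, log1m(theta) + lp2)

so any timing difference should come down to implementation, not what’s being computed.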
For example, this simple 2-component mixture model using log_sum_exp
data {
  int N;
  vector[N] y;
}
parameters {
  real<lower=0, upper=1> lambda;
}
model {
  profile("all") {
    for (i in 1:N) {
      target += log_sum_exp(
        log(lambda) + normal_lpdf(y[i] | -1, 2),
        log1m(lambda) + normal_lpdf(y[i] | 3, 1)
      );
    }
  }
}
is about 50% slower (per gradient) than the log_mix equivalent:
data {
  int N;
  vector[N] y;
}
parameters {
  real<lower=0, upper=1> lambda;
}
model {
  profile("all") {
    for (i in 1:N) {
      target += log_mix(lambda,
                        normal_lpdf(y[i] | -1, 2),
                        normal_lpdf(y[i] | 3, 1));
    }
  }
}
This is great! However, when I try to extend these approaches to an arbitrary number of components, things change. For example, a similar model with log_sum_exp:
data {
  int N;
  vector[N] y;
  int k;
  vector[k] means;
  vector[k] sds;
}
parameters {
  simplex[k] lambda;
}
model {
  profile("all") {
    vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[j] = log(lambda[j]) + normal_lpdf(y[i] | means[j], sds[j]);
      }
      target += log_sum_exp(lp);
    }
  }
}
versus with log_mix:
data {
  int N;
  vector[N] y;
  int k;
  vector[k] means;
  vector[k] sds;
}
parameters {
  simplex[k] lambda;
}
model {
  profile("all") {
    vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[j] = normal_lpdf(y[i] | means[j], sds[j]);
      }
      target += log_mix(lambda, lp);
    }
  }
}
In this case the log_sum_exp version is about 30% faster per gradient.
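As far as I can tell from the function reference, the vectorised signature is the same reduction again:

log_mix(lambda, lp) = log(sum(lambda .* exp(lp)))
                    = log_sum_exp(log(lambda) + lp)

so I’d have expected it to be at least competitive here too.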
Now I consider it extremely likely that I’m missing something obvious in how to program this model, and that my use of the vectorised log_mix is just silly, but I can’t quite figure it out, and I think examples haven’t made it to the docs yet, per this issue. I’d appreciate any pointers, or possibly a recommendation in the docs on when/how to apply this kind of log_mix approach. Happy to contribute something to the docs if so.
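One asymmetry I did notice while staring at the k-component versions: the log_sum_exp model recomputes log(lambda[j]) in the inner loop, N * k times per density evaluation, while log_mix is handed lambda directly. A minimal sketch of the log_sum_exp model block with that hoisted out (data and parameters blocks unchanged):

model {
  profile("all") {
    // take the log of the simplex once per density evaluation, not N * k times
    vector[k] log_lambda = log(lambda);
    vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[j] = log_lambda[j] + normal_lpdf(y[i] | means[j], sds[j]);
      }
      target += log_sum_exp(lp);
    }
  }
}

I haven’t re-run the timings with this variant, so treat it as a sketch rather than a third data point.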
The R code I’m using to simulate data & fit models is below (sorry for all the 2-3 letter variable names):
library("cmdstanr")
mm <- cmdstan_model("mix.stan")
lm <- cmdstan_model("logmix.stan")
lambda <- 0.7
means <- c(-1, 3)
sds <- c(2, 1)
N <- 1000
y <- c(
  rnorm(N * lambda, means[1], sds[1]),       # component 1: normal(-1, 2)
  rnorm(N * (1 - lambda), means[2], sds[2])  # component 2: normal(3, 1)
)
data <- list(
  N = N,
  y = y
)
fitm <- mm$sample(data = data)
fitl <- lm$sample(data = data)
# per-gradient cost = total time in the profile / number of autodiff calls
pm <- do.call(rbind, fitm$profiles())
pm$per_gradient <- pm[["total_time"]] / pm[["autodiff_calls"]]
pl <- do.call(rbind, fitl$profiles())
pl$per_gradient <- pl[["total_time"]] / pl[["autodiff_calls"]]
pm
pl
pm$per_gradient / pl$per_gradient # log_sum_exp relative to log_mix, one value per chain
# [1] 1.347329 1.458891 1.513658 1.365013
mmk <- cmdstan_model("mix_k.stan")
lmk <- cmdstan_model("logmix_k.stan")
k <- 3
lambdak <- c(0.2, 0.3, 0.5)
means <- c(-1, 3, 8)
sds <- c(2, 1, 3)
N <- 1000
y <- do.call(c, lapply(1:k, function(i) rnorm(N * lambdak[[i]], means[[i]], sds[[i]])))
datak <- list(
  N = N,
  y = y,
  means = means,
  sds = sds,
  k = k
)
fitmk <- mmk$sample(data = datak)
fitlk <- lmk$sample(data = datak)
pmk <- do.call(rbind, fitmk$profiles())
pmk$per_gradient <- pmk[["total_time"]] / pmk[["autodiff_calls"]]
plk <- do.call(rbind, fitlk$profiles())
plk$per_gradient <- plk[["total_time"]] / plk[["autodiff_calls"]]
pmk
plk
plk$per_gradient / pmk$per_gradient # note the flip: log_mix relative to log_sum_exp
# [1] 1.299544 1.320193 1.297496 1.290129