When is log_mix preferable to log_sum_exp for mixtures?

I’ve been working with some models that contain mixture components, and I’m trying to see if the new(ish) log_mix interfaces can squeeze out a bit more performance than the usual log_sum_exp approach. So far the results are mixed, which confuses me.

For example, this simple two-component mixture model written with log_sum_exp

data {
  int N;
  vector[N] y;
}

parameters {
  real<lower=0, upper=1> lambda;
}

model {
  profile("all") {
    for (i in 1:N) {
      target += log_sum_exp(
        log(lambda) + normal_lpdf(y[i] | -1, 2),
        log1m(lambda) + normal_lpdf(y[i] | 3, 1)
      );
    }
  }
}

is roughly 35 to 50% slower per gradient than its log_mix equivalent:

data {
  int N;
  vector[N] y;
}

parameters {
  real <lower=0, upper=1> lambda;
}

model {
  profile("all") {
    for (i in 1:N) {
      target += log_mix(lambda,
        normal_lpdf(y[i] | -1, 2),
        normal_lpdf(y[i] | 3, 1)
      );
    }
  }
}
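Both models compute the same quantity. For a mixing weight λ, log_mix(λ, lp1, lp2) is log(λ·exp(lp1) + (1 − λ)·exp(lp2)), which is exactly the log_sum_exp expression in the first model. A minimal Python sketch of that identity (plain floating-point math, not Stan's C++ implementation; `log_mix2`, `log_sum_exp` and `normal_lpdf` here are hypothetical helpers):

```python
import math

def log_sum_exp(a, b):
    # Numerically stable log(exp(a) + exp(b)): factor out the larger term.
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def log_mix2(theta, lp1, lp2):
    # Two-component mixture on the log scale:
    # log(theta * exp(lp1) + (1 - theta) * exp(lp2)).
    # math.log1p(-theta) plays the role of Stan's log1m(theta).
    return log_sum_exp(math.log(theta) + lp1, math.log1p(-theta) + lp2)

def normal_lpdf(y, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at y.
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

lam, y = 0.7, 0.5
a = normal_lpdf(y, -1, 2)
b = normal_lpdf(y, 3, 1)
direct = math.log(lam * math.exp(a) + (1 - lam) * math.exp(b))
print(abs(log_mix2(lam, a, b) - direct) < 1e-12)  # True
```

Because the two expressions are mathematically identical, the per-gradient difference presumably comes from how the derivatives are computed (log_mix can supply its partial derivatives directly instead of differentiating through separate log, exp and addition steps); that explanation is an assumption, not something measured here.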

This is great! However, when I try to extend these approaches to an arbitrary number of components, things change. For example, a similar model with log_sum_exp:

data {
  int N;
  vector[N] y;
  int k;
  vector[k] means;
  vector[k] sds;
}

parameters {
  simplex[k] lambda;
}

model {
  profile("all") {
    vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[j] = log(lambda[j]) + normal_lpdf(y[i] | means[j], sds[j]);
      }
      target += log_sum_exp(lp);
    }
  }
}

versus with log_mix:

data {
  int N;
  vector[N] y;
  int k;
  vector[k] means;
  vector[k] sds;
}

parameters {
  simplex[k] lambda;
}

model {
  profile("all") {
    vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[j] = normal_lpdf(y[i] | means[j], sds[j]);
      }
      target += log_mix(lambda, lp);
    }
  }
}

In this case the log_mix version takes about 30% longer per gradient than the log_sum_exp version.
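Again, the two k-component models compute the same per-observation term: with a simplex λ and a vector lp, log_mix(λ, lp) equals log_sum_exp(log(λ) + lp). A small Python check of that identity, using the simulation's means and sds at an arbitrary point y = 2 (hypothetical helper names; this mirrors the math, not Stan's code):

```python
import math

def log_sum_exp(xs):
    # Stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_mix(weights, lps):
    # Vector form: log(sum_j weights[j] * exp(lps[j])).
    return log_sum_exp([math.log(w) + lp for w, lp in zip(weights, lps)])

def normal_lpdf(y, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at y.
    z = (y - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2 * math.pi)

lam = [0.2, 0.3, 0.5]
means = [-1, 3, 8]
sds = [2, 1, 3]
y = 2.0
lp = [normal_lpdf(y, m, s) for m, s in zip(means, sds)]
direct = math.log(sum(w * math.exp(l) for w, l in zip(lam, lp)))
print(abs(log_mix(lam, lp) - direct) < 1e-12)  # True
```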

Now I consider it extremely likely that I’m missing something obvious in how to program this model, and my use of the vectorised log_mix is just silly, but I can’t quite figure it out and I think examples haven’t made it to the docs yet, per this issue. Would appreciate any pointers, or possibly a recommendation in the docs on when/how to apply this kind of log_mix approach. Happy to contribute something to the docs if so.

The R code I’m using to simulate data & fit models is below (sorry for all the 2-3 letter variable names):

library("cmdstanr")

mm <- cmdstan_model("mix.stan")
lm <- cmdstan_model("logmix.stan")

lambda <- 0.7
means <- c(-1, 3)
sds <- c(2, 1)
N <- 1000
# simulate from the same components the Stan models assume
y <- c(
    rnorm(N * lambda, means[1], sds[1]),
    rnorm(N * (1 - lambda), means[2], sds[2])
)
data <- list(
    N = N,
    y = y
)

fitm <- mm$sample(data = data)
fitl <- lm$sample(data = data)

pm <- do.call(rbind, fitm$profiles())
pm$per_gradient <- pm[["total_time"]] / pm[["autodiff_calls"]]
pl <- do.call(rbind, fitl$profiles())
pl$per_gradient <- pl[["total_time"]] / pl[["autodiff_calls"]]

pm
pl

pm$per_gradient / pl$per_gradient
# [1] 1.347329 1.458891 1.513658 1.365013


mmk <- cmdstan_model("mix_k.stan")
lmk <- cmdstan_model("logmix_k.stan")

k <- 3
lambdak <- c(0.2, 0.3, 0.5)
means <- c(-1, 3, 8)
sds <- c(2, 1, 3)
N <- 1000
y <- do.call(c, lapply(1:k, function(i) rnorm(N * lambdak[[i]], means[[i]], sds[[i]])))
datak <- list(
    N = N,
    y = y,
    means = means,
    sds = sds,
    k = k
)

fitmk <- mmk$sample(data = datak)
fitlk <- lmk$sample(data = datak)

pmk <- do.call(rbind, fitmk$profiles())
pmk$per_gradient <- pmk[["total_time"]] / pmk[["autodiff_calls"]]
plk <- do.call(rbind, fitlk$profiles())
plk$per_gradient <- plk[["total_time"]] / plk[["autodiff_calls"]]

pmk
plk

plk$per_gradient / pmk$per_gradient
# [1] 1.299544 1.320193 1.297496 1.290129

Thanks for posting this! Can you try a more fully vectorised log_mix:

data {
  int N;
  vector[N] y;
  int k;
  vector[k] means;
  vector[k] sds;
}

parameters {
  simplex[k] lambda;
}

model {
  profile("all") {
    array[N] vector[k] lp;
    for (i in 1:N) {
      for (j in 1:k) {
        lp[i][j] = normal_lpdf(y[i] | means[j], sds[j]);
      }
    }
    target += log_mix(lambda, lp);
  }
}

And post the profile results from this and the log_sum_exp implementation?
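As I read the Stan Functions Reference, when the second argument of log_mix is an array of vectors it returns the sum of the per-observation mixture log densities, which is why a single target += can replace the loop over observations. A Python sketch of that semantics (hypothetical helper names, arbitrary illustrative log densities, not Stan's implementation). In this form log(λ) only has to be computed once for all observations, which is one plausible, unverified source of a speedup:

```python
import math

def log_sum_exp(xs):
    # Stable log(sum(exp(x) for x in xs)).
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_mix(weights, lps):
    # Single observation: log(sum_j weights[j] * exp(lps[j])).
    return log_sum_exp([math.log(w) + lp for w, lp in zip(weights, lps)])

def log_mix_array(weights, lp_rows):
    # Array-of-vectors form: one mixture per row, summed over rows.
    # log(weights) is computed once and shared by every row.
    log_w = [math.log(w) for w in weights]
    return sum(log_sum_exp([lw + lp for lw, lp in zip(log_w, row)])
               for row in lp_rows)

lam = [0.2, 0.3, 0.5]
# Arbitrary per-component log densities for three observations.
rows = [[-2.1, -4.0, -3.3], [-1.7, -1.2, -5.9], [-6.0, -2.5, -1.1]]
looped = sum(log_mix(lam, row) for row in rows)
print(abs(log_mix_array(lam, rows) - looped) < 1e-12)  # True
```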

I just tested the more fully vectorised version I posted above, and it does appear to marginally outperform log_sum_exp on your time-per-gradient metric:

> plk2$per_gradient / pmk$per_gradient
[1] 0.9525492 0.9898638 0.9927668 0.9839802

But there isn’t a great deal of difference. Looking at the log_mix source, there are definitely some areas for optimisation, so I’ll do a bit of testing.

Thanks! You replied during my night, so I’ve only just seen this. Yes, this additional vectorisation seems maybe 10% faster.

I took the per-gradient measure from the cmdstanr docs. It makes sense to me because the number of autodiff calls seems to vary quite a lot between chains for some models.
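To make that normalisation concrete, here is the same per_gradient computation the R code above performs (total_time / autodiff_calls), sketched in Python with made-up numbers: a chain that needs twice as many gradient evaluations takes twice as long in total, but its cost per gradient is unchanged.

```python
def per_gradient(total_time, autodiff_calls):
    # Profiled time divided by the number of gradient evaluations, per chain.
    return [t / n for t, n in zip(total_time, autodiff_calls)]

# Hypothetical profiles (illustrative values only): chain 2 needed twice as
# many gradients as chain 1, so its total time doubles but not its per-gradient cost.
total_time = [2.0, 4.0]        # seconds
autodiff_calls = [1000, 2000]
print(per_gradient(total_time, autodiff_calls))  # [0.002, 0.002]
```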