I want to increase the speed of sampling using the "log_sum_exp" function

I am a Japanese postgraduate student studying cognitive psychology.
I am looking for someone who can show me how to increase the computational speed of my Stan model. Details are as follows.

I am using CmdStan to perform parameter estimation in cognitive diagnostic models.
Because cognitive diagnostic models have discrete latent variables, and Stan does not support discrete parameters, the discrete latent variables must be marginalized out using the log_sum_exp function.
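For a single response y, summing out the binary latent class gives

  p(y) = eta * p(y | 1 - s) + (1 - eta) * p(y | g)

which is evaluated on the log scale as log_sum_exp(log(eta) + log p(y | 1 - s), log(1 - eta) + log p(y | g)); this is exactly what the double loop below accumulates.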

In my model, the log_sum_exp function needs to be applied to each entry of a 1000 x 15 matrix, which is currently handled with a double for loop. This takes a very long time to compute.

If anyone is familiar with how to speed up the processing of the log_sum_exp function, I would be very interested in your advice.

The actual code is as follows. The last part of the model block is the double for loop that I want to improve.

I would be very grateful for your help!

data {
  int<lower=0> N;       // number of respondents
  int<lower=0> J;       // number of items
  int<lower=0> K;       // number of attributes
  matrix[K, J] Q;       // Q-matrix (attribute-by-item)
  matrix[N, 2] X;       // respondent-level covariates
  array[J, N] int Y;    // binary responses (item j, respondent n)
}

transformed data {
  vector[2] Zeros = rep_vector(0, 2);
  matrix[2,2] X_inv = inverse(X'*X);
}

parameters {
  matrix[K, 2] beta;

  vector<lower=0, upper=1>[J] s;        // slip parameters
  vector<lower=0, upper=1>[J] g_base;   // unconstrained guessing parameters

  real<lower=0> sigma;
}

transformed parameters {
  // cap the guessing parameter so that g[j] <= 1 - s[j] (monotonicity)
  vector<lower=0, upper=1>[J] g;
  for (j in 1:J) {
    if (g_base[j] > 1-s[j]) {
      g[j] = 1-s[j];
    } else {
      g[j] = g_base[j];
    }
  }

  matrix[N, K] delta;   // attribute mastery probabilities
  delta = inv_logit(X*beta');
}

model {
  target += -log(sigma);
  
  for (k in 1:K) {
    beta[k] ~ multi_normal(Zeros, N*(sigma^2)*X_inv);
  }

  s ~ beta(2, 5);
  g_base ~ beta(2, 5);

  matrix[N, J] eta;
  eta = exp(log(delta + 1e-8)*Q);   // probability of mastering all attributes required by each item

  for (j in 1:J) {
    for (n in 1:N) {
      target += log_sum_exp(
        log(eta[n, j]) + bernoulli_lpmf(Y[j,n] | 1-s[j]),
        log(1-eta[n, j]) + bernoulli_lpmf(Y[j,n] | g[j])
      );
    }
  }
}

I'm not sure how large a speed gain you'll get, but one thing to try here is to compute eta on the log scale. That is, instead of eta = exp(log(delta + 1e-8)*Q); just do eta = log(delta + 1e-8)*Q; and then in the log_sum_exp use eta and log1m_exp(eta).
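A minimal sketch of that change to the final double loop (log1m_exp(x) computes log(1 - exp(x)) stably):

  matrix[N, J] log_eta = log(delta + 1e-8)*Q;   // eta on the log scale
  for (j in 1:J) {
    for (n in 1:N) {
      target += log_sum_exp(
        log_eta[n, j] + bernoulli_lpmf(Y[j,n] | 1 - s[j]),
        log1m_exp(log_eta[n, j]) + bernoulli_lpmf(Y[j,n] | g[j])
      );
    }
  }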

The usual advice is to use the Cholesky-factored multivariate normal whenever possible:

transformed data {
  ...
  matrix[2,2] X_inv_chol = cholesky_decompose(N*inverse(X'*X));
}
model {
  ...
    beta[k] ~ multi_normal_cholesky(Zeros, sigma*X_inv_chol);
  ...
}

Though in this case the matrix is only 2-by-2 so it won’t make much difference.

More efficiency might be gained by noticing that when g_base[j] > 1 - s[j] you have g[j] equal to 1 - s[j], so both mixture components use the same success probability and you don't really need log_sum_exp().
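In that branch both terms add the same lpmf value lp, and the mixture weights sum to one, so the log_sum_exp collapses:

  log_sum_exp(log(eta) + lp, log(1 - eta) + lp) = lp + log(eta + (1 - eta)) = lp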

  for (j in 1:J) {
    if (g_base[j] > 1 - s[j]) {
      target += bernoulli_lpmf(Y[j,:]| 1 - s[j]);
    } else {
      for (n in 1:N) {
        target += log_sum_exp(
          log(eta[n, j]) + bernoulli_lpmf(Y[j,n] | 1 - s[j]),
          log(1-eta[n, j]) + bernoulli_lpmf(Y[j,n] | g_base[j])
        );
      }
    }
  }

or, combining with @jsocolar's suggestion:

  matrix[N, K] log_delta = log(delta + 1e-8);
  for (j in 1:J) {
    if (g_base[j] > 1 - s[j]) {
      target += bernoulli_lpmf(Y[j,:]| 1 - s[j]);
    } else {
      vector[N] log_eta = log_delta*Q[:,j];
      for (n in 1:N) {
        target += log_sum_exp(
          log_eta[n] + bernoulli_lpmf(Y[j,n] | 1 - s[j]),
          log1m_exp(log_eta[n]) + bernoulli_lpmf(Y[j,n] | g_base[j])
        );
      }
    }
  }

By the way, that scale-free prior on sigma formally assigns unbounded probability mass to sigma being almost zero, and that's a problem, at least in theory. Even if your data provide strong evidence against sigma = 0, merely "strong" evidence is not unbounded evidence, and the prior dominates. In practice it's hard to say whether HMC will discover the sigma = 0 mode, but if it does, sampling will either slow down immensely or fail completely.
Anyway, there's little point in allowing extremely large or extremely small values of sigma. I'd recommend something like sigma ~ exponential(1.0) instead.
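In the model block that is a one-line change:

  // instead of the scale-free prior: target += -log(sigma);
  sigma ~ exponential(1.0);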

@jsocolar @nhuurre
Thank you very much!
I followed the code as you both suggested and it quadrupled the sampling speed! The key is to keep the calculation on the log scale throughout; it is very efficient.
This has made the research for my master's thesis much easier.
