Multi_normal_cholesky is slower than hand-coded lpdf

This is true for a smallish L and an increasing number of observations N. I haven’t tested bigger L sizes yet, but that’s next on the list. The example below shows the stan-math code running in about 22-24 s and the user-defined function in about 17 s.

library(cmdstanr)
library(mvtnorm)

sigma <- matrix(c(1.0, 0.8, -0.1, 0.05, 
                      0.8, 1.0, 0.05, -0.01,
                      -0.1, 0.05, 1.0, -0.2,
                      0.05, -0.01, -0.2, 1.0), byrow = T, 4, 4)
t(chol(sigma))
mu <- c(0.2, 0.4, -0.2, 0.)
y <- rmvnorm(10000, mean = mu, sigma = sigma)

mvn_mod <- cmdstan_model("mvn_test.stan")

mine <- mvn_mod$sample(
  data = list(P = length(mu),
              N = nrow(y),
              y = y,
              flag = 1),
  parallel_chains = 4
  # seed = 1231
)

stans <- mvn_mod$sample(
  data = list(P = length(mu),
              N = nrow(y),
              y = y,
              flag = 0),
  parallel_chains = 4
  # seed = 1231
)

mine
stans

mvn_test.stan (918 Bytes)
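
I haven’t pasted the model inline, so here is a minimal sketch of what mvn_test.stan roughly looks like; the declarations, priors, and function name are reconstructed from the description in this thread and may differ from the attached file. With flag = 1 the model uses the hand-coded lpdf, with flag = 0 the built-in multi_normal_cholesky.

functions {
  // Hand-coded multivariate normal log density given a Cholesky factor,
  // built from the same pieces as the built-in (mdivide_left_tri_low,
  // columns_dot_self) but with no custom derivatives.
  real my_multi_normal_cholesky_lpdf(matrix y, row_vector mu, matrix L) {
    int N = rows(y);
    int K = cols(y);
    matrix[K, N] z = mdivide_left_tri_low(L, (y - rep_matrix(mu, N))');
    return -0.5 * N * K * log(2 * pi())
           - N * sum(log(diagonal(L)))
           - 0.5 * sum(columns_dot_self(z));
  }
}
data {
  int<lower=1> N;
  int<lower=1> P;
  matrix[N, P] y;
  int<lower=0, upper=1> flag;  // 1 = user-defined lpdf, 0 = built-in
}
parameters {
  row_vector[P] mu;
  cholesky_factor_corr[P] L;
}
model {
  mu ~ std_normal();
  L ~ lkj_corr_cholesky(2);
  if (flag == 1) {
    target += my_multi_normal_cholesky_lpdf(y | mu, L);
  } else {
    for (n in 1:N) {
      y[n] ~ multi_normal_cholesky(mu, L);
    }
  }
}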

@syclik helped me plumb through the code and confirm that it was the derivatives that were slowing us down. The C++ below is basically just a translation of the UDF into stan-math C++.

 * @tparam T_covar Type of scale.
 */
template <bool propto, typename T_y, typename T_loc, typename T_covar,
          require_any_not_vector_vt<is_stan_scalar, T_y, T_loc>* = nullptr,
          require_all_not_nonscalar_prim_or_rev_kernel_expression_t<
              T_y, T_loc, T_covar>* = nullptr>
return_type_t<T_y, T_loc, T_covar> multi_normal_cholesky_lpdf(
    const T_y& y, const T_loc& mu, const T_covar& L) {
  static const char* function = "multi_normal_cholesky_lpdf";
  using T_covar_elem = typename scalar_type<T_covar>::type;
  using T_return = return_type_t<T_y, T_loc, T_covar>;
  using T_partials_return = partials_return_t<T_y, T_loc, T_covar>;
  using matrix_partials_t
      = Eigen::Matrix<T_partials_return, Eigen::Dynamic, Eigen::Dynamic>;
  using T_y_ref = ref_type_t<T_y>;
  using T_mu_ref = ref_type_t<T_loc>;
  using T_L_ref = ref_type_t<T_covar>;

  // check_consistent_sizes_mvt(function, "y", y, "mu", mu);
  size_t number_of_y = size_mvt(y);
  size_t number_of_mu = size_mvt(mu);
  if (number_of_y == 0 || number_of_mu == 0) {
    return 0;
  }

  T_y_ref y_ref = y;
  T_mu_ref mu_ref = mu;
  T_L_ref L_ref = L;
  vector_seq_view<T_y_ref> y_vec(y_ref);
  vector_seq_view<T_mu_ref> mu_vec(mu_ref);
  const size_t size_vec = max_size_mvt(y, mu);

  const int size_y = y_vec[0].size();
  const int size_mu = mu_vec[0].size();

  // check size consistency of all random variables y
  for (size_t i = 1, size_mvt_y = size_mvt(y); i < size_mvt_y; i++) {
    check_size_match(function,
                     "Size of one of the vectors of "
                     "the random variable",
                     y_vec[i].size(),
                     "Size of the first vector of the "
                     "random variable",
                     size_y);
  }
  // check size consistency of all means mu
  for (size_t i = 1, size_mvt_mu = size_mvt(mu); i < size_mvt_mu; i++) {
    check_size_match(function,
                     "Size of one of the vectors of "
                     "the location variable",
                     mu_vec[i].size(),
                     "Size of the first vector of the "
                     "location variable",
                     size_mu);
  }

  check_size_match(function, "Size of random variable", size_y,
                   "size of location parameter", size_mu);
  check_size_match(function, "Size of random variable", size_y,
                   "rows of covariance parameter", L.rows());
  check_size_match(function, "Size of random variable", size_y,
                   "columns of covariance parameter", L.cols());

  for (size_t i = 0; i < size_vec; i++) {
    check_finite(function, "Location parameter", mu_vec[i]);
    check_not_nan(function, "Random variable", y_vec[i]);
  }

  if (unlikely(size_y == 0)) {
    return T_return(0);
  }

  const int K = size_y;
  const int N = size_vec;
  // N * log(1 / sqrt(det(Sigma))) = -N * sum(log(diag(L)))
  auto sqrt_det = -N * sum(log(diagonal(L_ref)));
  auto norm_const = -K * N * 0.5 * log(2 * pi());

  // Stack the centered observations y_n - mu row-wise.
  Eigen::Matrix<T_return, -1, -1> y_minus_mu(N, K);
  for (int i = 0; i < N; ++i) {
    y_minus_mu.row(i) = (y_vec[i] - mu_vec[i]).transpose();
  }
  // Sum of squared Mahalanobis terms via the triangular solve L^{-1} (y_n - mu).
  auto mahab = sum(
      columns_dot_self(mdivide_left_tri_low(L_ref, y_minus_mu.transpose())));

  return norm_const + sqrt_det - 0.5 * mahab;
}
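
In plain math, the snippet accumulates the joint log density of the N observations directly; this is just the standard multivariate normal lpdf written in terms of the Cholesky factor (Sigma = L L'):

$$
\log p(y \mid \mu, L) = -\frac{NK}{2}\log(2\pi) \;-\; N\sum_{k=1}^{K}\log L_{kk} \;-\; \frac{1}{2}\sum_{n=1}^{N}\left\lVert L^{-1}(y_n - \mu)\right\rVert^{2},
$$

where norm_const, sqrt_det, and -0.5 * mahab above correspond to the three terms in order.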

Shouldn’t it be the case that hand-coded derivatives are faster than AD? If so, then there’s a huge opportunity to speed up the code. If not, then we can just replace the current code with the no-derivative code. It also raises the question of whether other functions are slowed down by their hand-coded derivatives.

Yes, hand-coded derivatives should never be slower than autodiff. After altering the stan-math C++ function, it was down to ~17 s, but with the caveat that the numerics were slightly different. The posterior distributions were being estimated similarly, but the draws were no longer identical.

Btw, we knocked out the input checking code as a test. It helped a little, but nothing significant.

Is it possible that the Stan function you provided takes advantage of linear algebra shortcuts that are not used in the C++ function because it has to handle all sorts of data structures?

What we have not yet provided in Stan math C++ is a varmat-optimised version of multi_normal_cholesky, I think… but providing varmat-optimised versions of the function and its gradient should give nice speedups. Maybe @stevebronder knows more about this?

It’s basically the same functions (mdivide_left_tri_low and columns_dot_self), just without the derivatives.

mvn_test.R (521 Bytes)
mvn_test_cfa.stan (1.2 KB)

Attaching my basic CFA example here. In this example, there’s a consistent benefit to using the built-in multi_normal_cholesky (nearly 2x as fast).

What @syclik is saying would be true if we were all perfect coders. But we’re not. There are two things going on: memory and speed. Some of our functions are actually slower with hand-coded derivatives, but much tighter in memory. I think that may be true for some of the matrix functions. Writing tight autodiff code that’s just as fast would require blocking algorithms. So while we could write faster derivative code, none of us are matrix-operation specialists, so it’d be challenging.

Making the built-in derivatives faster everywhere would be a great improvement, especially in this critical function. But we don’t want to sacrifice memory for speed and memory is a huge bottleneck in big problems for autodiff.

+1. Thanks for the clarification, @Bob_Carpenter! I wasn’t thinking about the impact of the memory tradeoff when I wrote that, especially in the context of matrix functions.
