Updated Cholesky corr parameterization testing

Hi everyone,

Could people test out this parameterization of the Cholesky factor of correlation matrices? From my tests on an M1 Mac, it looks to be much faster, and I can easily sample matrices of dimension 100+. For example, a 10 x 10 matrix samples in 0.3 s (1000 warmup / 2000 sampling = 3000 total iterations) vs. 1.5 s for the built-in, and I get no warnings with this parameterization.

Edit: per @jonah’s comment, I’ve separated out the two programs.

functions {
  // Maps an unconstrained vector y of length choose(K, 2) to the Cholesky
  // factor of a K x K correlation matrix, incrementing target by the
  // log absolute determinant of the Jacobian of the transform.
  matrix cholesky_corr_constrain_proposal_lp(vector y, int K) {
    matrix[K, K] L = identity_matrix(K);
    int counter = 1;

    for (i in 2 : K) {
      row_vector[i - 1] y_star = y[counter : counter + i - 2]';
      real dsy = dot_self(y_star);
      real alpha_r = 1 / (dsy + 1);
      real gamma = sqrt(dsy + 2) * alpha_r;
      // Row i has unit norm by construction: gamma^2 * dsy + alpha_r^2 = 1.
      L[i, : i] = append_col(gamma * y_star, alpha_r);
      // Log-det-Jacobian increment for this row.
      target += 0.5 * (i - 2) * log(dsy + 2) - i * log1p(dsy);
      counter += i - 1;
    }
    return L;
  }
}
data {
  int<lower=0> N;
  real<lower=0> eta;
}
parameters {
  vector[choose(N, 2)] y;
}
transformed parameters {
  matrix[N, N] L = cholesky_corr_constrain_proposal_lp(y, N);
}
model {
  L ~ lkj_corr_cholesky(eta);
}
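
A quick note on why this construction is valid: writing $s$ for `dot_self(y_star)`, each new row of $L$ has unit squared norm,

$$
\gamma^2 s + \alpha_r^2 = \frac{(s + 2)\,s}{(s + 1)^2} + \frac{1}{(s + 1)^2} = \frac{(s + 1)^2}{(s + 1)^2} = 1,
$$

and the diagonal entry $\alpha_r = 1 / (s + 1)$ is always positive, so $L$ is the Cholesky factor of a correlation matrix.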

And compare to the built-in:

data {
  int<lower=0> N;
  real<lower=0> eta;
}
parameters {
  cholesky_factor_corr[N] L_stan;
}
model {
  L_stan ~ lkj_corr_cholesky(eta);
}

Jacobian calc

Also, could someone write the function in JAX and numerically estimate the Jacobian to check my math? I’ve attached a write-up of the math.
correlation_matrix_parameterization_jacobian.pdf (144.7 KB)
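
In case it helps anyone get started, here is a minimal sketch of that check in JAX (my own sketch, not from the write-up). It implements the map from y to the strictly lower-triangular elements of L, gets the exact Jacobian with jax.jacfwd, and compares its log |det| to the analytic increment from the Stan function. I’m assuming the Jacobian is meant with respect to the strict lower triangle of L; since an additive constant doesn’t affect sampling, the sketch checks whether the difference between the numeric and analytic values stays constant across draws of y.

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

def lower_tri_elements(y, K):
    # y has length K * (K - 1) / 2; return the strictly lower-triangular
    # elements of L row by row, mirroring the Stan loop.
    out = []
    counter = 0
    for i in range(2, K + 1):
        y_star = y[counter:counter + i - 1]
        dsy = jnp.dot(y_star, y_star)
        gamma = jnp.sqrt(dsy + 2) / (dsy + 1)
        out.append(gamma * y_star)  # diagonal entry 1 / (dsy + 1) is implied
        counter += i - 1
    return jnp.concatenate(out)

def analytic_lp(y, K):
    # The target increment from cholesky_corr_constrain_proposal_lp.
    lp = 0.0
    counter = 0
    for i in range(2, K + 1):
        y_star = y[counter:counter + i - 1]
        dsy = jnp.dot(y_star, y_star)
        lp += 0.5 * (i - 2) * jnp.log(dsy + 2) - i * jnp.log1p(dsy)
        counter += i - 1
    return lp

K = 5
for seed in range(3):
    y = jax.random.normal(jax.random.PRNGKey(seed), (K * (K - 1) // 2,))
    J = jax.jacfwd(lambda v: lower_tri_elements(v, K))(y)
    sign, logdet = jnp.linalg.slogdet(J)
    # If the math checks out, this difference should be the same for every y.
    print(float(logdet - analytic_lp(y, K)))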

Tagging some folks who have been interested in the past @Bob_Carpenter @stevebronder @Stephen_Martin @andrjohns @aseyboldt @adamgorm @Adam_Haber @WardBrian @Seth_Axen

9 Likes

I’m getting a very strange exception when running this with bool_stan = 1.

library(cmdstanr)
stan_file <- write_stan_file(
"
// original code was missing the 'functions' block, I added it here
functions {
  matrix cholesky_corr_constrain_proposal_lp (vector y, int K) {
    matrix[K, K] L = identity_matrix(K);
    int counter = 1;
    
    for (i in 2 : K) {
        row_vector[i - 1] y_star = y[counter:counter + i - 2]';
        real dsy = dot_self(y_star);
        real alpha_r = 1 / (dsy  + 1);
        real gamma = sqrt(dsy + 2) * alpha_r;
        L[i, : i] = append_col(gamma * y_star, alpha_r);
        target += 0.5 * (i - 2) * log(dsy + 2) - i * log1p(dsy);
        counter += i - 1;
      }
    return L;
  }
}
data {
  int<lower=0> N;
  int bool_stan;
  real<lower=0> eta;
}
parameters {
 vector[choose(N, 2)] y;
 cholesky_factor_corr[bool_stan == 1 ? N : 0] L_stan;
}
transformed parameters {
 matrix[bool_stan != 1 ? N : 0,  bool_stan != 1 ? N : 0] L = cholesky_corr_constrain_proposal_lp(y, N);
}
model {
  if (bool_stan == 1) {
    L_stan ~ lkj_corr_cholesky(eta);
  } else {
    L ~ lkj_corr_cholesky(eta);
  }
}
"  
)
mod <- cmdstan_model(stan_file)
dat <- list(N = 10, bool_stan = 1, eta = 2)
fit <- mod$sample(data = dat)

This results in a bunch of exceptions like this:

Chain 4 Exception: In serializer: Storage capacity [145] exceeded while writing value of size [100] from position [145]. This is an internal error, if you see it please report it as an issue on the Stan github repository. (in '/var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/Rtmp6yrWQx/model-691a42747fa9.stan', line 29, column 1 to column 103)

I don’t get any exceptions when bool_stan = 0.

1 Like

That’s weird, but it’s probably never come up because people never declare a 0-dimensional cholesky_factor_corr!

If you run this with the N you want and then with N = 0, does the exception happen only when N = 0?

data {
  int<lower=0> N;
  real<lower=0> eta;
}
parameters {
  cholesky_factor_corr[N] L_stan;
}
model {
  L_stan ~ lkj_corr_cholesky(eta);
}

Separately, do you see faster sampling and higher ESS with the proposed parameterization?

I just edited the original post so people can run it.

It runs fine with N = 10. With N = 0, I don’t get the same exception I was getting before; instead I get:

Chain 1 Rejecting initial value:
Chain 1   Error evaluating the log probability at the initial value.
Chain 1 Exception: lkj_corr_cholesky_lpdf: columns of Cholesky factor is 0, but must be positive! (in '/var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/Rtmp6yrWQx/model-691a372408a5.stan', line 10, column 4 to column 31)
1 Like

Yes! Very cool.

1 Like

Cool! When running with N = 10 and eta = 2 on my ThinkPad, I get similar times (around 0.3 s total for 1000 warmup / 2000 sampling) and similar ESS for the two parameterizations, but the built-in gives some warnings and the new one does not. With N = 25, the new parameterization takes 1.0 s and the built-in takes 1.6 s. With N = 50, the new parameterization samples without any warnings, but the built-in keeps rejecting initial values and then exits unexpectedly.

Wonderful! 👍

@spinkney this looks super nice. Would it be okay with you if we tried to implement this in PyMC and see how it fares (with proper attribution, of course)?

1 Like

Yes! I’m curious how the performance compares across implementations. This is similar to the onion method, but I found it to be twice as fast for large (> 100 dims) correlation matrices. However, it’s possible that the log-det-Jacobian I have is wrong, or that the performance difference is due to something particular to Stan, like the AD. Ideally we could have a comparison across stick breaking, the onion method, and this one, where we compare:

  • with and without AD
  • over increasing matrix dimensions
  • with different LKJ eta values
  • prior only
  • in a model (where we can increase the dimension)

And report ESS/grad, divergences, failures, etc.
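
For the ESS/grad piece, here’s a rough sketch of how that could be computed from a CmdStan fit in Python (cmdstanpy + arviz; the file name "proposal.stan" and the data are placeholders for whichever parameterization is being benchmarked):

import arviz as az
from cmdstanpy import CmdStanModel

model = CmdStanModel(stan_file="proposal.stan")  # placeholder file name
fit = model.sample(data={"N": 10, "eta": 2.0}, chains=4)

# Total gradient evaluations are roughly the total leapfrog steps.
n_grad = fit.method_variables()["n_leapfrog__"].sum()

# Minimum bulk ESS over the entries of L (use "L_stan" for the built-in).
# Constant entries (the structural zeros and fixed diagonal) give NaN ESS,
# which xarray's min() skips by default.
ess = az.ess(az.from_cmdstanpy(posterior=fit), var_names=["L"])
min_ess = float(ess["L"].min())

print(f"min bulk ESS per gradient evaluation: {min_ess / n_grad:.5f}")
# Divergences and other failures show up in fit.diagnose().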

1 Like

We can compare that easily in PyTensor by checking it against the automatic log-det Jacobian from the transform, assuming any differences aren’t just precision issues.

Edit: I see you asked someone to try it against JAX already. Same idea :)

2 Likes

Promising results from the PyMC folks:

Confirmation that the Jacobian determinant is correct, and early tests show improved sampling!

8 Likes

To be fair, we didn’t really have a working transform before 🙈🙈

2 Likes