I’m asking if people could test out this parameterization of the Cholesky factor of correlation matrices. From my tests on an M1 Mac, it looks to be much faster than the built-in, and I can easily sample matrices of dimension 100+. For example, a 10 x 10 matrix samples in 0.3 s (1000 warmup / 2000 sampling iterations, 3000 total) vs. 1.5 s for the built-in, and I get no warnings with this parameterization.
Edit: per @jonah’s comment, I’ve separated out the two programs.
functions {
  // Maps the unconstrained vector y to the Cholesky factor of a
  // correlation matrix and increments target with the log-det-Jacobian.
  matrix cholesky_corr_constrain_proposal_lp(vector y, int K) {
    matrix[K, K] L = identity_matrix(K);
    int counter = 1;
    for (i in 2:K) {
      row_vector[i - 1] y_star = y[counter:counter + i - 2]';
      real dsy = dot_self(y_star);
      real alpha_r = 1 / (dsy + 1);
      real gamma = sqrt(dsy + 2) * alpha_r;
      // row i is (gamma * y_star, alpha_r), which has unit norm
      L[i, :i] = append_col(gamma * y_star, alpha_r);
      // log-abs-det Jacobian adjustment for this row
      target += 0.5 * (i - 2) * log(dsy + 2) - i * log1p(dsy);
      counter += i - 1;
    }
    return L;
  }
}
data {
  int<lower=0> N;
  real<lower=0> eta;
}
parameters {
  vector[choose(N, 2)] y;
}
transformed parameters {
  matrix[N, N] L = cholesky_corr_constrain_proposal_lp(y, N);
}
model {
  L ~ lkj_corr_cholesky(eta);
}
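As a quick sanity check on the construction (my own note, not from the write-up): write $d = \|y^*\|^2$. Row $i$ is $(\gamma y^*, \alpha_r)$ with $\alpha_r = 1/(d+1)$ and $\gamma = \sqrt{d+2}/(d+1)$, so its squared norm is

$$\gamma^2 d + \alpha_r^2 = \frac{(d+2)\,d + 1}{(d+1)^2} = \frac{(d+1)^2}{(d+1)^2} = 1.$$

Each row of $L$ has unit norm and a positive diagonal entry, so $LL^\top$ has unit diagonal and $L$ is a valid Cholesky factor of a correlation matrix.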
And compare to the built-in:
data {
  int<lower=0> N;
  real<lower=0> eta;
}
parameters {
  cholesky_factor_corr[N] L_stan;
}
model {
  L_stan ~ lkj_corr_cholesky(eta);
}
Jacobian calc
Also, I’m asking if someone could write the function in JAX and numerically estimate the Jacobian to check my math. I’ve attached a write-up of the math: correlation_matrix_parameterization_jacobian.pdf (144.7 KB)
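In case it helps, here’s a minimal JAX sketch of that check (my code, not the attached write-up). It mirrors the Stan function’s row map from y to the strictly-lower-triangular entries of L, computes the Jacobian numerically with jax.jacfwd, and prints its log-abs-determinant next to the analytic sum the Stan program adds to target. One assumption to flag: I take the change of variables with respect to the strictly-lower entries only, treating the diagonal as determined by the unit-norm constraint; if the write-up uses a different reference measure, the two numbers will differ by exactly that choice.

import jax
import jax.numpy as jnp

def lower_entries(y, K):
    # Map the unconstrained vector y to the strictly-lower-triangular
    # entries of L, row by row, mirroring the Stan function above.
    rows = []
    counter = 0
    for i in range(2, K + 1):
        y_star = y[counter:counter + i - 1]
        counter += i - 1
        dsy = jnp.dot(y_star, y_star)
        gamma = jnp.sqrt(dsy + 2.0) / (dsy + 1.0)
        rows.append(gamma * y_star)  # diagonal alpha_r is determined by these
    return jnp.concatenate(rows)

def analytic_log_det(y, K):
    # Sum of the per-row terms the Stan program adds to target.
    total = 0.0
    counter = 0
    for i in range(2, K + 1):
        y_star = y[counter:counter + i - 1]
        counter += i - 1
        dsy = jnp.dot(y_star, y_star)
        total += 0.5 * (i - 2) * jnp.log(dsy + 2.0) - i * jnp.log1p(dsy)
    return total

K = 5
y = jax.random.normal(jax.random.PRNGKey(0), (K * (K - 1) // 2,))
J = jax.jacfwd(lambda v: lower_entries(v, K))(y)  # full (dense) Jacobian
sign, numeric = jnp.linalg.slogdet(J)
print("numeric log|det J|:", float(numeric))
print("analytic target sum:", float(analytic_log_det(y, K)))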
I’m getting a very strange exception when running this with bool_stan = 1.
library(cmdstanr)
stan_file <- write_stan_file(
"
// original code was missing the 'functions' block, I added it here
functions {
  matrix cholesky_corr_constrain_proposal_lp(vector y, int K) {
    matrix[K, K] L = identity_matrix(K);
    int counter = 1;
    for (i in 2:K) {
      row_vector[i - 1] y_star = y[counter:counter + i - 2]';
      real dsy = dot_self(y_star);
      real alpha_r = 1 / (dsy + 1);
      real gamma = sqrt(dsy + 2) * alpha_r;
      L[i, :i] = append_col(gamma * y_star, alpha_r);
      target += 0.5 * (i - 2) * log(dsy + 2) - i * log1p(dsy);
      counter += i - 1;
    }
    return L;
  }
}
data {
  int<lower=0> N;
  int bool_stan;
  real<lower=0> eta;
}
parameters {
  vector[choose(N, 2)] y;
  cholesky_factor_corr[bool_stan == 1 ? N : 0] L_stan;
}
transformed parameters {
  matrix[bool_stan != 1 ? N : 0, bool_stan != 1 ? N : 0] L = cholesky_corr_constrain_proposal_lp(y, N);
}
model {
  if (bool_stan == 1) {
    L_stan ~ lkj_corr_cholesky(eta);
  } else {
    L ~ lkj_corr_cholesky(eta);
  }
}
"
)
mod <- cmdstan_model(stan_file)
dat <- list(N = 10, bool_stan = 1, eta = 2)
fit <- mod$sample(data = dat)
This results in a bunch of exceptions like this:
Chain 4 Exception: In serializer: Storage capacity [145] exceeded while writing value of size [100] from position [145]. This is an internal error, if you see it please report it as an issue on the Stan github repository. (in '/var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/Rtmp6yrWQx/model-691a42747fa9.stan', line 29, column 1 to column 103)
It still runs with N = 10. With N = 0 I don’t get the same exception I was getting before; instead I get:
Chain 1 Rejecting initial value:
Chain 1 Error evaluating the log probability at the initial value.
Chain 1 Exception: lkj_corr_cholesky_lpdf: columns of Cholesky factor is 0, but must be positive! (in '/var/folders/s0/zfzm55px2nd2v__zlw5xfj2h0000gn/T/Rtmp6yrWQx/model-691a372408a5.stan', line 10, column 4 to column 31)
Cool! When running with N = 10 and eta = 2 on my ThinkPad, I get similar times (around 0.3 s total for 1000 warmup / 2000 sampling iterations) and similar ESS with the two parameterizations, but the built-in gives some warnings and your new one does not. With N = 25, the new parameterization takes 1.0 s and the built-in takes 1.6 s. With N = 50, the new parameterization samples without any warnings, but the built-in keeps rejecting initial values until it fails unexpectedly.
Yes! I’m curious how the performance compares across different implementations. This is similar to the onion method, but I found it to be twice as fast for large (> 100 dimensions) correlation matrices. However, it’s possible that the log-det-Jacobian I have is wrong, or that the performance difference is due to something particular to Stan, like the autodiff. Ideally we’d have a comparison across stick breaking, the onion method, and this one.
We can check that easily in PyTensor by comparing it against the automatic log-det-Jacobian of the transform, assuming any differences aren’t just precision issues.
Edit: I see you asked someone to try it against JAX already. Same idea :)