Hi, I’m having issues fitting a simple multivariate normal model. The model works well for many datasets, but not when input data is highly correlated, e.g., cor(x, y) = 0.9. Highly correlated data happens to be the scenario I’m developing a model for.
Below is a simplified example:
data {
  int<lower = 0> n;
  int<lower = 0> p;
  matrix[n, p] x;
  // hyperparameters
  real<lower = 0> priorEta;
  real<lower = 0> priorSdx;
}
transformed data {
  vector[p] zeros = rep_vector(0, p);
}
parameters {
  cholesky_factor_corr[p] Lcorr;
  vector<lower = 0>[p] sdx;
}
model {
  sdx ~ cauchy(0, priorSdx);
  Lcorr ~ lkj_corr_cholesky(priorEta);
  for (i in 1:n) {
    x[i, ] ~ multi_normal_cholesky(zeros, diag_pre_multiply(sdx, Lcorr));
  }
}
And the R code I use to run the model:
rm(list = ls())
library(rstan)
n <- 100
p <- 3
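r <- matrix(c(1.0, 0.5, 0.9,
              0.5, 1.0, 0.0,
              0.9, 0.0, 1.0), 3, 3)
# r as written is not positive definite (its determinant is -0.06),
# so nearPD() projects it to the nearest valid correlation matrix
r <- as.matrix(Matrix::nearPD(r, corr = TRUE)$mat)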
sdx <- rep(1, p)
s <- diag(sdx) %*% r %*% diag(sdx)
x <- mvtnorm::rmvnorm(n, sigma = s)
dataList <- list(
  x = x,
  n = n,
  p = p,
  priorEta = 1,
  priorSdx = 1
)
res <- stan(file = "stan/question/model.stan", data = dataList, chains = 2)
ss <- summary(res)
# Rhat is meh
ss$summary[!is.na(ss$summary[ ,"Rhat"]), "Rhat"]
# output
Lcorr[2,1] Lcorr[2,2] Lcorr[3,1] Lcorr[3,2] Lcorr[3,3]     sdx[1]
  1.121353   1.123117   1.794683   1.794350   1.578169   1.703874
    sdx[2]     sdx[3]       lp__
  1.239409   1.881612   1.085317
Warning messages:
1: There were 1861 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceeded
2: Examine the pairs() plot to diagnose sampling problems
Some observations:
If the correlations among the variables are approximately 0, a chain of 2000 samples takes about 10 seconds to run. If the correlations are high, this can increase to about 500 seconds per chain! I’m aware I can rewrite the normal model in terms of its sufficient statistics, which gives a good boost in speed, but it doesn’t improve the quality of the sampling (i.e., R-hat stays high).
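For concreteness, this is roughly what I mean by the sufficient-statistics rewrite. It is only a sketch, assuming the zero mean used in the model above. S, D, and L are my shorthand here (S for the scatter matrix, D for diag(sdx), L for Lcorr), not names from the Stan code, and x_i is row i of x written as a column vector:

$$
\sum_{i=1}^{n} \log \mathcal{N}(x_i \mid 0, \Sigma)
  = -\frac{np}{2}\log(2\pi) - \frac{n}{2}\log|\Sigma| - \frac{1}{2}\operatorname{tr}\!\left(\Sigma^{-1} S\right),
\qquad S = \sum_{i=1}^{n} x_i x_i^{\top}
$$

$$
\Sigma = D\,L\,L^{\top} D,
\qquad
\log|\Sigma| = 2\sum_{j=1}^{p}\left(\log \mathrm{sdx}_j + \log L_{jj}\right)
$$

Since S is computed once from the data, each likelihood evaluation no longer loops over the n rows, which is where the speedup comes from. The likelihood as a function of sdx and Lcorr is unchanged, though, so the posterior is the same, which would be consistent with R-hat not improving.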
Is there any reparametrization trick, or another method to improve the efficiency of the sampler, that I’m missing?
Any comments, suggestions, or advice are welcome!