We have a large dataset with 1.6 million observations, and we fit three different models, each with four chains of 750 iterations. So for each model we have a 3000 x 1.6M log-likelihood matrix of posterior draws (loglik_posterior), which we saved from each model run. Reloading one of these matrices into R takes around 50 GB of RAM, but we have access to a computing cluster, so that is not a problem.
The problem comes when we pass this matrix to the loo package to compute LOOIC and, more importantly, the Pareto k diagnostic values. We keep getting out-of-memory (OOM) errors no matter how much RAM we request: we can go up to 96 GB, and it still fails.
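For scale, here is a quick back-of-the-envelope check of the size of one such matrix, assuming it is held as a dense matrix of 8-byte doubles (R's default numeric type):

```r
# Size of one dense log-likelihood matrix stored as 8-byte doubles
n_draws <- 4 * 750      # 4 chains x 750 iterations = 3000 draws
n_obs   <- 1663076      # number of observations (columns)
bytes   <- n_draws * n_obs * 8

bytes / 1e9    # about 39.9 GB
bytes / 2^30   # about 37.2 GiB
```

A single copy is therefore already about 40 GB. If `loo()` keeps the input plus one or two working arrays of the same size (for instance the negated log-likelihood and the smoothed log weights), peak memory would exceed 96 GB. That is an assumption about loo's internals, not something measured here.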
We computed WAIC with our own R code, but we would like to check the Pareto k values before we trust these numbers.
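As a cross-check for a home-grown implementation, here is a minimal sketch of the standard pointwise WAIC formulas (the same definitions `loo::waic()` uses): lppd_i is the log of the posterior-mean likelihood of observation i, and p_waic_i is the posterior variance of its log-likelihood. It assumes `ll` is an S x N log-likelihood matrix with draws in rows; `waic_pointwise` and `waic_summary` are hypothetical helper names, not part of loo.

```r
# Pointwise WAIC from an S x N log-likelihood matrix (draws in rows).
waic_pointwise <- function(ll) {
  # log of the posterior-mean likelihood per column, via a stable log-mean-exp
  col_max <- apply(ll, 2, max)
  lppd <- col_max + log(colMeans(exp(sweep(ll, 2, col_max))))
  # effective number of parameters: posterior variance of the log-likelihood
  p_waic <- apply(ll, 2, var)
  elpd_waic <- lppd - p_waic
  cbind(elpd_waic = elpd_waic, p_waic = p_waic, waic = -2 * elpd_waic)
}

# Totals and standard errors over all observations
waic_summary <- function(pw) {
  n <- nrow(pw)
  est <- colSums(pw)
  se <- sqrt(n * apply(pw, 2, var))
  cbind(Estimate = est, SE = se)
}
```

Each observation only needs its own column, so `waic_pointwise()` can be run on column chunks and the results stacked with `rbind()` before calling `waic_summary()`. On a small matrix, the totals should agree with `waic(ll)$estimates` from loo up to floating-point error.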
While exploring the loo package, I realized that it doesn't matter whether I pass the full matrix or only a subset of its columns: the k-value for each observation turns out to be the same.
See below, using the example matrix from the loo package:
> LLmat <- example_loglik_matrix()
> sub <- 10:15
>
> l1 <- loo(LLmat)
Warning message:
Relative effective sample sizes ('r_eff' argument) not specified.
For models fit with MCMC, the reported PSIS effective sample sizes and
MCSE estimates will be over-optimistic.
>
> l2 <- loo(LLmat[,sub])
Warning message:
Relative effective sample sizes ('r_eff' argument) not specified.
For models fit with MCMC, the reported PSIS effective sample sizes and
MCSE estimates will be over-optimistic.
>
> pareto_k_values(l1)[sub]
[1] -0.11439294 -0.04461812 -0.06883883 -0.08868121 0.07186615 0.34633282
>
> pareto_k_values(l2)
[1] -0.11439294 -0.04461812 -0.06883883 -0.08868121 0.07186615 0.34633282
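This is consistent with how PSIS works: the leave-one-out importance ratios for observation i are built only from column i of the log-likelihood matrix (the log ratios are the negative log-likelihood), and each column is smoothed and diagnosed separately. A minimal sketch that recovers the same k-values by calling `psis()` on one column at a time; `r_eff = 1` is passed explicitly here only to match the unspecified-r_eff runs above.

```r
library(loo)

LLmat <- example_loglik_matrix()

# Pareto k for each observation, computed from that column alone.
# Leave-one-out log importance ratios are -log p(y_i | theta^(s)).
k_by_column <- vapply(
  seq_len(ncol(LLmat)),
  function(j) pareto_k_values(psis(-LLmat[, j], r_eff = 1)),
  numeric(1)
)

# The same quantities from a single call on the whole matrix
k_all <- pareto_k_values(loo(LLmat, r_eff = rep(1, ncol(LLmat))))
all.equal(unname(k_by_column), unname(k_all))
```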
To get around the OOM problem, I plan to run a loop that passes smaller chunks of columns instead of all 1.6M at once, and stores the k-values from each iteration.
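# d is the loglik_posterior matrix: 3000 draws x 1663076 observations
k_values <- numeric(ncol(d))                        # one Pareto k per observation
loc <- c(seq(1, ncol(d), by = 1000), ncol(d) + 1)   # chunk start columns, plus an end marker
for (i in seq_len(length(loc) - 1)) {
  cols <- loc[i]:(loc[i + 1] - 1)                   # columns in chunk i
  loo_ <- loo(d[, cols, drop = FALSE])              # drop = FALSE keeps a one-column chunk a matrix
  k_values[cols] <- pareto_k_values(loo_)
  print(i)                                          # progress
}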
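A variant of the same loop could also supply chain-aware relative effective sample sizes, which removes the r_eff warning and can slightly change the k estimates (PSIS uses r_eff to choose how many draws go into the Pareto tail), and it could keep the pointwise elpd_loo values so the totals can be reassembled afterwards. This sketch assumes the rows of `d` are stored chain by chain (rows 1 to 750 from chain 1, and so on); `chunked_loo`, `n_chains` and `chunk_size` are hypothetical names.

```r
library(loo)

# Hypothetical helper: PSIS-LOO over column chunks of an S x N log-likelihood
# matrix `d`. Assumes rows are ordered by chain, in n_chains equal blocks.
chunked_loo <- function(d, n_chains = 4, chunk_size = 1000) {
  n_obs <- ncol(d)
  chain_id <- rep(seq_len(n_chains), each = nrow(d) / n_chains)
  starts <- seq(1, n_obs, by = chunk_size)
  pointwise <- vector("list", length(starts))
  for (i in seq_along(starts)) {
    cols <- starts[i]:min(starts[i] + chunk_size - 1, n_obs)
    ll <- d[, cols, drop = FALSE]
    # relative efficiency of the likelihood (not log-likelihood), per observation
    r_eff <- relative_eff(exp(ll), chain_id = chain_id)
    fit <- loo(ll, r_eff = r_eff)
    pointwise[[i]] <- cbind(
      elpd_loo = fit$pointwise[, "elpd_loo"],
      pareto_k = pareto_k_values(fit)
    )
  }
  pw <- do.call(rbind, pointwise)
  elpd <- sum(pw[, "elpd_loo"])
  list(
    pointwise = pw,
    elpd_loo = elpd,
    se_elpd_loo = sqrt(n_obs * var(pw[, "elpd_loo"])),
    looic = -2 * elpd
  )
}
```

Because elpd_loo is a sum of pointwise terms and each k depends only on its own column, the chunked results should match a single `loo()` call with the same r_eff, up to floating-point error. Note that `d` itself still has to fit in memory; only the per-chunk working copies become smaller.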
Is there any problem with this approach?