I’ve recently been using data.table’s `fread` function to read in the Stan CSV files for my model because it’s much faster than the built-in reading methods. However, all of the draws are flattened into two dimensions, whereas I’d prefer the format of the `extract` function from the rstan package, which preserves the dimensions of each parameter. It seems like others, such as @wds15, have also had this problem.
So I wrote a function to read in Stan CSVs with `fread` and then reshape the parameters into their original dimensions. I’ve only tested it on a few examples, so it’s possible there are a few kinks to be worked out, but hopefully it helps some other users.
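To make the target format concrete, here is a toy sketch of the reshaping idea in base R. The parameter name `eta2` and the values are made up, and it assumes the CSV columns are ordered with the first index varying fastest, which is the ordering the reshaping relies on:

```r
# Toy stand-in for flattened draws: 3 draws of a hypothetical 2 x 2 parameter.
# Column names mimic the Stan CSV style "name.i.j", first index varying fastest.
flat <- matrix(
  1:12, nrow = 3,
  dimnames = list(NULL, c("eta2.1.1", "eta2.2.1", "eta2.1.2", "eta2.2.2"))
)

# R fills arrays column by column, so a single call restores the dimensions:
# the result is indexed as draws[iteration, i, j].
draws <- array(flat, dim = c(3, 2, 2))

# Draw 2 of element [2, 1] is the same number in both layouts.
draws[2, 2, 1] == flat[2, "eta2.2.1"]
```

The first dimension indexes the draws and the remaining ones are the parameter's own dimensions, which is the same layout that `rstan::extract` returns by default.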
I’m using the “model” below, which samples parameters in up to 4 dimensions; the length of each dimension, `l`, is passed in as data:
data {
  int l;
}
parameters {
  real eta0;
  real eta1 [l];
  real eta2 [l, l];
  real eta3 [l, l, l];
  real eta4 [l, l, l, l];
}
model {
  eta0 ~ normal(0, 1);
  eta1 ~ normal(0, 1);
  to_array_1d(eta2) ~ normal(0, 1);
  to_array_1d(eta3) ~ normal(0, 1);
  to_array_1d(eta4) ~ normal(0, 1);
}
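For a sense of scale, the model above samples 1 + l + l^2 + l^3 + l^4 scalar parameters. A quick sketch of that arithmetic (`n_params` is just a throwaway helper name, not part of any package):

```r
# Number of scalar parameters in the toy model for a given l:
# eta0 (1) + eta1 (l) + eta2 (l^2) + eta3 (l^3) + eta4 (l^4)
n_params <- function(l) sum(l^(0:4))

n_params(2)  # 31
n_params(6)  # 1555
```

Each of these becomes one column in every chain's CSV file, alongside `lp__` and the sampler diagnostics.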
For `l = 2` there are no efficiency gains, but for `l = 6` and above, where the number of parameters runs into the thousands (1,555 for `l = 6`) and beyond, this should be much faster:
library(tidyverse)
library(cmdstanr)
make_posterior <- function(files, num_warmup = -1){
  # read the header row to get parameter (column) names;
  # skip = 38 is hard-coded and assumes 38 lines precede the header row
  params <- names(data.table::fread(files[1], skip = 38, nrows = 0))
  # skip warmup draws if they were saved
  skip <- if (num_warmup > 0) num_warmup + 40 else 38
  # bind together all the chains, skipping warm-up if necessary
  chains <- lapply(files, function(file){
    chain <- suppressWarnings(data.table::fread(file, skip = skip))
    names(chain) <- params
    chain
  }) %>%
    bind_rows()
  iter <- nrow(chains)
  # figure out highest dimension in parameters
  max_dim <- max(str_count(params, "\\."))
  # extract parameter names and the size of each dimension (largest index seen)
  param_dims <- data.frame(param = params) %>%
    separate(param, into = c('param', paste0('dim', 1:max_dim)),
             sep = "\\.", convert = TRUE, fill = "right") %>%
    group_by(param) %>%
    summarize_all(max)
  posterior <- list()
  # loop over parameters
  for(i in seq_len(nrow(param_dims))){
    param <- param_dims %>%
      slice(i) %>%
      pull(param)
    dims <- param_dims %>%
      slice(i) %>%
      select(-param) %>%
      as.numeric()
    # reshape draws into array defined by dims; match the exact name or
    # "name.<index>" so that e.g. "eta1" does not also pick up "eta10"
    draws <- chains %>%
      select(which(str_detect(names(.), paste0("^", param, "(\\.|$)")))) %>%
      as.matrix() %>%
      array(c(iter, dims[!is.na(dims)]))
    posterior[[param]] <- draws
  }
  posterior
}
mod <- cmdstan_model("...arrays.stan")
fit <- mod$sample(
  data = list(l = 6),
  chains = 4,
  parallel_chains = 4
)
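To see what the name-parsing step inside `make_posterior` is after, here is a rough base-R sketch on a handful of made-up column names. It is not the code the function uses (that goes through `separate()` and `summarize_all()`); it only illustrates that each parameter's dimensions are recovered as the largest index seen in each position:

```r
# Hypothetical column names in the Stan CSV style
params <- c("lp__", "eta0", "eta1.1", "eta1.2",
            "eta2.1.1", "eta2.2.1", "eta2.1.2", "eta2.2.2")

# Base name is everything before the first dot
base <- sub("\\..*$", "", params)

# Indices are whatever follows the base name, split on dots
idx <- strsplit(sub("^[^.]*\\.?", "", params), ".", fixed = TRUE)

# For each parameter, take the largest index seen in each position
dims <- lapply(split(idx, base), function(x) {
  if (length(x[[1]]) == 0) return(integer(0))  # scalar: no indices
  apply(do.call(rbind, lapply(x, as.integer)), 2, max)
})

dims$eta1  # 2
dims$eta2  # 2 2
```

Scalars such as `lp__` and `eta0` come back with no dimensions, so their draws end up as a one-dimensional array with one entry per iteration.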
On my machine, the rstan functions take about 10 seconds while `make_posterior` takes 2. `rstan::extract` permutes the draws, so we can’t check for element-level equality, but all the marginal means match up:
start1 <- Sys.time()
stanfit <- rstan::read_stan_csv(fit$output_files())
posterior1 <- rstan::extract(stanfit)
print(Sys.time() - start1)
start2 <- Sys.time()
posterior2 <- make_posterior(fit$output_files())
print(Sys.time() - start2)
all.equal(
  as.numeric(apply(posterior1$eta4, 2:5, mean)),
  as.numeric(apply(posterior2$eta4, 2:5, mean))
)
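A somewhat stricter check than comparing means, sketched here on made-up arrays rather than the real fit, is to sort the draws of every element before comparing. If one array is just a row-permuted copy of the other, the sorted versions should agree. The helper name `sort_draws` is hypothetical:

```r
set.seed(1)

# Stand-ins for the two extractions: same draws, different row order
a <- array(rnorm(5 * 2 * 2), dim = c(5, 2, 2))
b <- a[sample(5), , , drop = FALSE]

# Sort the draws (first dimension) separately for every parameter element
sort_draws <- function(x) apply(x, seq_along(dim(x))[-1], sort)

all.equal(sort_draws(a), sort_draws(b))  # TRUE
```

The same idea should carry over to something like `all.equal(sort_draws(posterior1$eta4), sort_draws(posterior2$eta4))`, assuming both objects hold the same set of post-warmup draws.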
If anyone needs something like this, feel free to try it and let me know if it works. It could probably still be made much faster with parallelization or by replacing the tidyverse functions, which aren’t known for their speed.
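As one possible direction for dropping the tidyverse dependency, the per-parameter reshaping step could be written in base R roughly like this. It is an untuned sketch with a hypothetical helper name (`reshape_param`) and toy data, not a drop-in replacement:

```r
# Base-R stand-in for the select() + as.matrix() + array() step.
# `chains` is a data frame of draws with Stan-CSV-style column names;
# `dims` holds the parameter's dimensions (integer(0) for a scalar).
reshape_param <- function(chains, param, dims = integer(0)) {
  # match "param" exactly or "param.<index>", not other names sharing the prefix
  cols <- grepl(paste0("^", param, "(\\.|$)"), names(chains))
  array(as.matrix(chains[, cols, drop = FALSE]), dim = c(nrow(chains), dims))
}

# Toy draws: a length-2 vector eta1 and an unrelated scalar eta10
chains <- data.frame(eta1.1 = 1:3, eta1.2 = 4:6, eta10 = 7:9)

reshape_param(chains, "eta1", dims = 2)  # 3 x 2 array, eta10 not included
reshape_param(chains, "eta10")           # one-dimensional array of 3 draws
```

Whether this is actually faster than the dplyr pipeline would need to be measured; with many parameters, the repeated name matching may well be the part that dominates.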