Code to read + reshape draws to match rstan::extract

I’ve recently been using data.table’s fread function to read in the Stan CSV files for my model because it’s much faster than the built-in readers. However, fread returns all of the draws flattened into two dimensions. I’d prefer the format of the extract function from the rstan package, which preserves the dimensions of each parameter. Others, like @wds15, seem to have had this problem too.

So I wrote a function to read in Stan CSVs with fread and then reshape the parameters into their original dimensions. I’ve only tested it on a few examples, so there may be a few kinks to work out, but hopefully it helps other users.

I’m using the following “model”, which samples parameters with up to 4 dimensions. The length of each dimension is passed in as data:

```stan
data {
  int l;
}
parameters {
  // array syntax for Stan 2.26+; the older `real eta1[l];` form was removed in Stan 2.33
  real eta0;
  array[l] real eta1;
  array[l, l] real eta2;
  array[l, l, l] real eta3;
  array[l, l, l, l] real eta4;
}
model {
  eta0 ~ normal(0, 1);
  eta1 ~ normal(0, 1);
  to_array_1d(eta2) ~ normal(0, 1);
  to_array_1d(eta3) ~ normal(0, 1);
  to_array_1d(eta4) ~ normal(0, 1);
}
```
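For reference, CmdStan writes one CSV column per array element, named like eta2.1.2, and lists the elements in column-major order (first index varying fastest). R’s array() also fills in column-major order, and that match is what lets the reshaping step below work directly on each parameter’s columns. This sketch uses a hypothetical helper, stan_col_names (not part of rstan or cmdstanr), to generate those names for a small non-square case, then checks one element after reshaping:

```r
# Hypothetical helper (not part of rstan or cmdstanr): build CmdStan-style
# column names for an array parameter, first index varying fastest.
stan_col_names <- function(name, dims) {
  idx <- expand.grid(lapply(dims, seq_len))  # first column varies fastest
  paste(name, do.call(paste, c(idx, sep = ".")), sep = ".")
}

cols <- stan_col_names("eta2", c(2, 3))
cols
# [1] "eta2.1.1" "eta2.2.1" "eta2.1.2" "eta2.2.2" "eta2.1.3" "eta2.2.3"

# one toy draw whose values are simply 1..6 in column order
draws <- matrix(1:6, nrow = 1, dimnames = list(NULL, cols))
arr <- array(draws, c(1, 2, 3))  # draws x dim1 x dim2
arr[1, 2, 3] == draws[1, "eta2.2.3"]  # TRUE
```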

For l = 2 there are no efficiency gains, but for l = 6 and above this should be much faster. The parameter count is 1 + l + l^2 + l^3 + l^4, so it grows quickly: 1,555 for l = 6 and over 20,000 for l = 12. Here’s the code:

```r
library(tidyverse)
library(cmdstanr)

make_posterior <- function(files, num_warmup = -1){
  
  # read the CSV header line to get parameter (column) names; in these files
  # CmdStan wrote 38 "#" comment lines before it, but that count can vary
  # with the CmdStan version and sampler settings
  params <- names(data.table::fread(files[1], skip = 38, nrows = 0))
  
  # skip warmup draws if they were saved
  skip <- if (num_warmup > 0) num_warmup + 40 else 38
  
  # bind together all the chains, skipping warmup if necessary
  chains <- lapply(files, function(file){
    chain <- suppressWarnings(data.table::fread(file, skip = skip))
    names(chain) <- params
    chain
  }) %>%
    bind_rows()
  
  # total number of draws across all chains
  iter <- nrow(chains)
  
  # figure out highest dimension in parameters
  max_dim <- max(str_count(params, "\\."))
  
  # extract parameter names and the size of each dimension
  # (the largest index seen in each position)
  param_dims <- data.frame(param = params) %>%
    separate(param, into = c('param', paste0('dim', 1:max_dim)), 
             sep = "\\.", convert = TRUE, fill = "right") %>%
    group_by(param) %>%
    summarize_all(max)
  
  posterior <- list()
  
  # loop over parameters
  for(i in 1:nrow(param_dims)){
    
    param <- param_dims %>%
      slice(i) %>%
      pull(param)
    
    dims <- param_dims %>%
      slice(i) %>%
      select(-param) %>%
      as.numeric()
    
    # reshape draws into array defined by dims; columns are matched on the
    # exact name before the first "." (str_detect() would treat param as a
    # regex and also match, e.g., "mu_raw.1" when param is "mu")
    draws <- chains %>%
      select(which(sub("\\..*$", "", names(.)) == param)) %>%
      as.matrix() %>%
      array(c(iter, dims[!is.na(dims)]))
    
    posterior[[param]] <- draws
  }
  
  posterior
}
```
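make_posterior hard-codes skip = 38, the number of "#" comment lines CmdStan wrote before the CSV header in these runs. Because that count can change with the CmdStan version and sampler settings, a sturdier option is to locate the header. data.table::fread also accepts a string for skip (for example skip = "lp__") and starts at the first line containing it. The base-R sketch below takes a different route: find_header_line is a hypothetical helper that returns the line number of the first line not starting with "#":

```r
# Hypothetical helper: return the line number of the CSV header in a Stan CSV
# file, i.e. the first line that does not start with "#".
find_header_line <- function(file, max_lines = 1000) {
  lines <- readLines(file, n = max_lines)
  which(!startsWith(lines, "#"))[1]
}

# usage with fread: skip everything before the header, read names only
# header <- find_header_line(files[1])
# params <- names(data.table::fread(files[1], skip = header - 1, nrows = 0))

# small self-contained check on a fake Stan-style CSV
tmp <- tempfile(fileext = ".csv")
writeLines(c("# model = arrays", "# num_samples = 1000",
             "lp__,eta0", "-1.2,0.3"), tmp)
find_header_line(tmp)  # 3
```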

```r
mod <- cmdstan_model("...arrays.stan")

fit <- mod$sample(
  data = list(l = 6),
  chains = 4,
  parallel_chains = 4
)
```
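For a sense of scale, the toy model has 1 + l + l^2 + l^3 + l^4 parameters (eta0 through eta4), not counting the sampler diagnostics such as lp__ that CmdStan also writes as columns. A quick check in R, with n_params as an illustrative name:

```r
# number of model parameters in the toy model: eta0 (scalar) plus
# eta1..eta4, which have l, l^2, l^3 and l^4 elements respectively
n_params <- function(l) sum(l^(0:4))

n_params(2)   # 31
n_params(6)   # 1555
n_params(12)  # 22621
```

With 4 chains of cmdstanr’s default 1,000 post-warmup iterations, l = 6 gives 4,000 × 1,555, or about 6.2 million values to read and reshape.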

On my machine, the rstan route (read_stan_csv followed by extract) takes about 10 seconds, while make_posterior takes 2. By default rstan::extract permutes the draws, so we can’t check for element-level equality, but all the marginal means match up:

```r
start1 <- Sys.time()
stanfit <- rstan::read_stan_csv(fit$output_files())
posterior1 <- rstan::extract(stanfit)
print(Sys.time() - start1)

start2 <- Sys.time()
posterior2 <- make_posterior(fit$output_files())
print(Sys.time() - start2)

all.equal(
  as.numeric(apply(posterior1$eta4, 2:5, mean)),
  as.numeric(apply(posterior2$eta4, 2:5, mean))
)
```
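Because rstan::extract with its default permuted = TRUE only reorders the post-warmup draws, a stricter check than comparing means is to compare each element’s sorted draws. A minimal sketch, where same_draws is a hypothetical helper name:

```r
# Compare two sets of draws for the same quantity while ignoring draw order:
# a permutation only reorders values, so the sorted vectors should agree
# up to all.equal()'s floating-point tolerance.
same_draws <- function(a, b) {
  isTRUE(all.equal(sort(as.numeric(a)), sort(as.numeric(b))))
}

set.seed(1)
x <- rnorm(1000)
same_draws(x, x[sample(length(x))])  # TRUE: same values, different order
same_draws(x, x + 0.1)               # FALSE: values actually differ
```

Applied to the fit here, same_draws(posterior1$eta4[, 1, 2, 3, 4], posterior2$eta4[, 1, 2, 3, 4]) should return TRUE if both readers parsed the same values from the CSVs.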

If you try this, let me know whether it works for you. It could probably be made much faster still with parallelization, or by replacing the tidyverse functions, which aren’t known for their speed.
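As one illustration of dropping the tidyverse, here is a base-R sketch of the reshaping step. reshape_draws is a hypothetical name. It expects a draws matrix with CmdStan column names (for example as.matrix(chains) inside make_posterior). It groups columns by the exact name before the first ".", infers each dimension’s size from the largest index in that position, and, like make_posterior, assumes the columns come in CmdStan’s column-major order:

```r
# Sketch (hypothetical helper, not from rstan or cmdstanr): reshape a draws
# matrix with CmdStan-style column names ("eta2.1.2", ...) into a named list
# of arrays, one per parameter, with draws in the first dimension.
reshape_draws <- function(draws) {
  cols <- colnames(draws)
  base <- sub("\\..*$", "", cols)  # parameter name before the first "."
  out <- list()
  for (p in unique(base)) {
    idx <- which(base == p)
    if (length(idx) == 1 && !grepl(".", cols[idx], fixed = TRUE)) {
      out[[p]] <- draws[, idx]  # scalar parameter: plain vector of draws
      next
    }
    # split "eta2.2.3" into c(2, 3); the largest index per position is the size
    ind <- do.call(rbind, lapply(strsplit(cols[idx], ".", fixed = TRUE),
                                 function(s) as.integer(s[-1])))
    dims <- apply(ind, 2, max)
    # assumes columns are already in column-major order, as CmdStan writes them
    out[[p]] <- array(draws[, idx], c(nrow(draws), dims))
  }
  out
}

# toy input: 2 draws of a scalar lp__ and a 2 x 3 array eta2
m <- matrix(1:14, nrow = 2,
            dimnames = list(NULL, c("lp__", "eta2.1.1", "eta2.2.1", "eta2.1.2",
                                    "eta2.2.2", "eta2.1.3", "eta2.2.3")))
post <- reshape_draws(m)
dim(post$eta2)        # 2 2 3
post$eta2[2, 2, 3]    # 14, i.e. draw 2 of column "eta2.2.3"
```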

Nice!

You can set permuted = FALSE as an argument to rstan::extract. Does that make stuff comparable?

real eta4 [l, l, l, l];

Does this code work with non-matching dimensions?

You mean something other than a high-dimensional cube?
I tried real eta4 [l, l+1, l+2, l+3] and it still appears to work.

permuted = FALSE returns a 3-D array of iterations x chains x parameters, with every array element flattened into its own slot in the last dimension. I never understood that choice; it seems like it should be a separate option.
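If you do start from permuted = FALSE, its iterations x chains x parameters array can be collapsed to one row per draw in a single call, because R stores arrays in column-major order. A sketch with a hypothetical helper, flatten_chains, and a toy array standing in for the extract() result:

```r
# Collapse an iterations x chains x parameters array (the layout of
# rstan::extract(fit, permuted = FALSE)) into a draws matrix, one row per draw.
# Rows come out chain by chain: all of chain 1's iterations, then chain 2's, ...
flatten_chains <- function(arr) {
  d <- dim(arr)
  matrix(arr, nrow = d[1] * d[2], ncol = d[3],
         dimnames = list(NULL, dimnames(arr)[[3]]))
}

# toy stand-in: 3 iterations, 2 chains, 2 parameters
arr <- array(1:12, c(3, 2, 2),
             dimnames = list(NULL, NULL, c("eta1.1", "eta1.2")))
m <- flatten_chains(arr)
m[, "eta1.2"]  # 7 8 9 10 11 12
```

The result has the same layout as the bound chains table in make_posterior (chains stacked in order), so the same per-parameter reshaping applies afterwards.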

Groovy!
