Save out subset of cmdstanr parameters (post-run solution)

I like using cmdstanr (as opposed to rstan) to get access to the latest Stan features in R. However, for models with many parameters and large datasets I’ve been running into memory issues: I can’t save the fit object, or read the draws into memory from the .csv files that cmdstanr creates in order to summarize the model output (for example, with fit$draws(variables = 'alpha')). Generated quantities and ‘raw’ parameters (e.g., from non-centered parameterizations) only add to the problem, since every output is written to the .csv files. I know monitoring only a subset of parameters has been discussed elsewhere (here, here, here) and is being actively addressed in CmdStan development (here). In the meantime, I thought the solution below might interest others with similar issues, whether they are running out of memory or just want to save out a subset of parameters.

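This uses R to call awk, cut, and a few other standard Unix command-line tools (via system()) to keep only the columns (parameters) of interest in the temporary CmdStan .csv files. For each chain it prepends the original header comments to the filtered draws and appends the timing footer, then reads the filtered .csv files back in as a new cmdstanr fit object with cmdstanr::as_cmdstan_fit(). These tools must be installed and on your PATH for this to work (they are available by default on macOS and Linux).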
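As an alternative to offsetting metadata()$model_params by the six sampler-diagnostic columns, the column positions could be read straight from the .csv header. Below is a minimal sketch with a hypothetical helper, csv_column_indices() (not part of cmdstanr), that reads the first non-comment line of one chain’s .csv and returns the 1-based positions to pass to cut -f. It assumes CmdStan writes container elements as name.index in the header (e.g. beta.1, which cmdstanr displays as beta[1]); the toy file is made up for illustration.

```r
# Hypothetical helper (not part of cmdstanr): return the 1-based csv column
# positions of lp__, the sampler diagnostics, and the requested variables,
# read from the column-name row of one CmdStan output csv.
csv_column_indices <- function(csv_file, target_params,
                               always_keep = c("lp__", "accept_stat__",
                                               "stepsize__", "treedepth__",
                                               "n_leapfrog__", "divergent__",
                                               "energy__")) {
  con <- file(csv_file, open = "r")
  on.exit(close(con))
  # skip the leading '#' comment block; the first other line holds the names
  repeat {
    line <- readLines(con, n = 1)
    if (length(line) == 0) stop("no column-name row found in ", csv_file)
    if (nzchar(line) && !startsWith(line, "#")) break
  }
  cols <- strsplit(line, ",", fixed = TRUE)[[1]]
  # assumed element separators: 'beta.2' -> 'beta'
  # (Stan identifiers cannot contain '.' or ':')
  base_names <- sub("[.:].*$", "", cols)
  which(base_names %in% c(always_keep, target_params))
}

# Toy file with a CmdStan-like layout, for illustration only
toy <- tempfile(fileext = ".csv")
writeLines(c("# model = logistic_model",
             paste0("lp__,accept_stat__,stepsize__,treedepth__,n_leapfrog__,",
                    "divergent__,energy__,alpha,beta.1,beta.2,log_lik.1"),
             "-66,0.9,0.8,2,3,0,67,0.4,-0.7,-0.3,-0.5"), toy)
csv_column_indices(toy, "beta")
# [1]  1  2  3  4  5  6  7  9 10
```

Since all chains from one run share the same columns, the indices from the first output file could be reused for every chain.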

I’m sure this could be improved (improvements welcome!), and it takes some time to run for very large fits, but importantly it does the job. I haven’t tested it extensively (mostly on macOS Sonoma with R 4.4.0), so your mileage may vary. The new fit object will not load into shinystan, but neither does an object created directly from the original temp files.

Function to subset cmdstanr object

# model_fit is the fit object from `cmdstanr`
# target_params ignores square brackets (i.e., []):
# target_params = 'alpha' will return 'alpha[1]', 'alpha[2]', etc.
# lp__ and the sampler diagnostics (accept_stat__, ..., energy__) are always kept
# thin keeps every thin-th draw in the csv (thin = 1 keeps all draws)
subset_cmdstanr <- function(model_fit, target_params, thin = 1)
{
  #get list of temp cmdstan files
  cmdstan_files <- model_fit$output_files()
  
  #get param names
  all_params <- model_fit$metadata()$model_params
  
  #remove [] and anything inside, e.g. 'beta[2]' -> 'beta'
  all_params_ISB <- sub("\\[.*$", "", all_params)
  
  #get idx for target params
  f_ind <- which(all_params_ISB %in% target_params)
  
  #csv columns 2-7 hold the sampler diagnostics
  #accept_stat__,stepsize__,treedepth__,n_leapfrog__,divergent__,energy__
  #which are in the csv but not in all_params, so shift every index except
  #lp__ (index 1) by 6; always keep lp__ (prepend 1 if it was not requested)
  f_ind2 <- f_ind
  if (1 %in% f_ind2)
  {
    f_ind2[-1] <- f_ind2[-1] + 6
  } else {
    f_ind2 <- c(1, f_ind2 + 6)
  }
  
  #add 2-7 (sampler diagnostics) after lp__
  f_ind2 <- append(f_ind2, 2:7, after = 1)
  
  #with >0.5 million columns, passing every index to cut individually
  #causes problems, so collapse consecutive indices into ranges ('-')
  #get diff between indices
  del <- diff(f_ind2)
  
  #start with first var
  fv <- paste0('-f', f_ind2[1])
  #if all indices are consecutive, add '-' and the max idx
  if (!any(del > 1))
  {
    fv <- paste0(fv, '-', tail(f_ind2, 1))
  } else {
    while (length(del) > 0)
    {
      #find first diff > 1
      if (sum(del > 1) != 0)
      {
        md <- min(which(del > 1))
        
        # add to call
        if (md > 1)
        {
          fv <- paste0(fv, '-', f_ind2[md], ',', f_ind2[md+1])
        } else {
          fv <- paste0(fv, ',', f_ind2[md+1])
        }
        
        # reorg del and f_ind2
        del <- del[-c(1:md)]
        f_ind2 <- f_ind2[-c(1:md)]
      } else {
        #if series to end
        fv <- paste0(fv, '-', tail(f_ind2, 1))
        del <- del[-c(1:length(del))]
      }
    } 
  }
  
  #function to subset the draws in one chain's csv and write a new csv
  awk_fun <- function(file, IDX, thin)
  {
    message('subsetting ', file)
    #output prefix: strip the trailing '.csv' (literal match, not regex '.')
    base <- sub("\\.csv$", "", file)
    #drop comment lines (keeps the column-name row), select params of interest
    p1_awk_call <- paste0("awk '/^[^#]/ {print}' ", shQuote(file),
                          " | cut -d \",\" ", IDX)
    if (thin == 1)
    {
      awk_call <- paste0(p1_awk_call, " > ",
                         shQuote(paste0(base, '-draws-subset.csv')))
    } else {
      #thin after keeping first line (the column names)
      #https://unix.stackexchange.com/questions/648113/how-to-skip-every-three-lines-using-awk
      awk_call <- paste0(p1_awk_call, " | awk 'NR%", thin, "==1' > ",
                         shQuote(paste0(base, '-draws-subset.csv')))
    }
    system(awk_call)
    
    #get the leading comment block (config header) only and write to file;
    #stopping at the first non-comment line avoids hard-coding its length,
    #which varies with CmdStan version and run settings
    call2 <- paste0("awk '/^[^#]/ {exit} {print}' ", shQuote(file), " > ",
                    shQuote(paste0(base, '-header.csv')))
    system(call2)
    
    #get elapsed time (last 5 comment lines)
    call3 <- paste0("tail -n 5 ", shQuote(file), " > ",
                    shQuote(paste0(base, '-time.csv')))
    system(call3)
    
    message('combining and writing to file')
    #combine and write to file
    awk_call4 <- paste0('cat ',
                        shQuote(paste0(base, '-header.csv')), ' ',
                        shQuote(paste0(base, '-draws-subset.csv')), ' ',
                        shQuote(paste0(base, '-time.csv')), ' > ',
                        shQuote(paste0(base, '-subset-comb.csv')))
    system(awk_call4)
  }
  
  #run awk fun on every chain's csv
  invisible(lapply(cmdstan_files, FUN = function(x) awk_fun(file = x,
                                                            IDX = fv,
                                                            thin = thin)))
  
  #build the combined file names for exactly these chains (same naming as
  #in awk_fun), rather than searching the output directory, which may hold
  #files from other runs
  new_files2 <- sub("\\.csv$", "-subset-comb.csv", cmdstan_files)
  
  message('creating new cmdstanr object')
  model_fit2 <- cmdstanr::as_cmdstan_fit(new_files2)
  
  return(model_fit2)
}
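The while loop that builds fv collapses consecutive column indices into cut ranges. The same idea can be expressed with run-length grouping: start a new group wherever the gap between sorted indices is not 1, then print each group as a single index or as first-last. Here is a minimal sketch with a hypothetical helper, compress_ranges(); it expects a non-empty vector of positive indices, and converts them to integer so large positions are not printed in scientific notation (e.g. 5e+05).

```r
# Hypothetical helper: collapse column indices into a compact `cut -f` list,
# e.g. c(1:7, 10, 11, 20) -> "1-7,10-11,20"
compress_ranges <- function(idx) {
  idx <- sort(unique(as.integer(idx)))
  # a new run starts wherever the gap to the previous index is not 1
  run_id <- cumsum(c(1, diff(idx) != 1))
  runs <- split(idx, run_id)
  parts <- vapply(runs, function(r) {
    if (length(r) == 1) as.character(r) else paste0(r[1], "-", r[length(r)])
  }, character(1))
  paste(parts, collapse = ",")
}

compress_ranges(c(1:7, 10, 11, 20))
# [1] "1-7,10-11,20"
```

Inside subset_cmdstanr, fv <- paste0('-f', compress_ranges(f_ind2)) should yield the same -f string as the while loop, without mutating del and f_ind2 as it goes.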

Example model

library(cmdstanr)
fit <- cmdstanr::cmdstanr_example("logistic", chains = 4)
fit$summary()
# A tibble: 105 × 10
   variable      mean  median     sd    mad       q5      q95  rhat ess_bulk ess_tail
   <chr>        <dbl>   <dbl>  <dbl>  <dbl>    <dbl>    <dbl> <dbl>    <dbl>    <dbl>
 1 lp__       -66.0   -65.6   1.45   1.20   -68.9    -64.3     1.00    2013.    2533.
 2 alpha        0.381   0.385 0.218  0.216    0.0231   0.741   1.00    3703.    2476.
 3 beta[1]     -0.667  -0.663 0.252  0.248   -1.09    -0.253   1.00    4782.    2822.
 4 beta[2]     -0.281  -0.277 0.225  0.225   -0.661    0.0811  1.00    4100.    2867.
 5 beta[3]      0.675   0.671 0.266  0.266    0.250    1.13    1.00    3986.    3010.
 6 log_lik[1]  -0.515  -0.507 0.0984 0.0964  -0.689   -0.368   1.00    3884.    2746.
 7 log_lik[2]  -0.406  -0.387 0.149  0.139   -0.679   -0.200   1.00    4255.    3059.
 8 log_lik[3]  -0.503  -0.472 0.219  0.212   -0.903   -0.208   1.00    4278.    2783.
 9 log_lik[4]  -0.449  -0.431 0.152  0.146   -0.726   -0.234   1.00    3585.    2480.
10 log_lik[5]  -1.18   -1.16  0.279  0.275   -1.67    -0.755   1.00    4254.    3139.
# ℹ 95 more rows

Run subset function (retain only ‘lp__’, ‘alpha’, and ‘beta’)

fit2 <- subset_cmdstanr(model_fit = fit, 
                        target_params = c('lp__', 'alpha', 'beta'))
fit2$summary()
# A tibble: 5 × 10
  variable    mean  median    sd   mad       q5      q95  rhat ess_bulk ess_tail
  <chr>      <dbl>   <dbl> <dbl> <dbl>    <dbl>    <dbl> <dbl>    <dbl>    <dbl>
1 lp__     -66.0   -65.6   1.45  1.20  -68.9    -64.3     1.00    2013.    2533.
2 alpha      0.381   0.385 0.218 0.216   0.0231   0.741   1.00    3703.    2476.
3 beta[1]   -0.667  -0.663 0.252 0.248  -1.09    -0.253   1.00    4782.    2822.
4 beta[2]   -0.281  -0.277 0.225 0.225  -0.661    0.0811  1.00    4100.    2867.
5 beta[3]    0.675   0.671 0.266 0.266   0.250    1.13    1.00    3986.    3010.

Thank you for sharing this! I will point people here when we get asked about this until we have a solution in CmdStan.
