I like using `cmdstanr` (as opposed to `rstan`) to get access to the latest Stan features in R. However, for models with many parameters and large datasets I’ve been running into memory issues. I can’t save the object or read the draws into memory (from the CSV files that `cmdstanr` creates) to summarize the model output, for example using `fit$draws(variables = 'alpha')`. Many generated quantities or ‘raw’ parameters (e.g., from non-centering) only add to the memory issue, since all outputs are tracked. I know that monitoring only a subset of parameters has been discussed elsewhere (here, here, and here) and is being actively addressed in `cmdstan` development (here). In the meantime, I thought the solution below might be of interest to folks experiencing similar issues (either running out of memory or just wanting to save a subset of parameters).
This uses R to call `awk` and `cut` (Unix command-line utilities) to select only the columns (parameters) of interest in the temporary cmdstan `.csv` files. It tacks the header onto the filtered data and then combines the filtered `.csv` files to make a new `cmdstanr` object. `awk` and `cut` must be installed for this to work.
I’m sure this could be improved (improvements welcome!) and it takes some time to run for very large objects, but importantly it does the job. I haven’t tested this very extensively (mostly on macOS Sonoma with R 4.4.0), so your mileage may vary. I know the new model object will not read into `shinystan`, but neither does an object created directly from the original temp files.
Function to subset cmdstanr object
#' Subset the CSV output of a cmdstanr fit to selected parameters
#'
#' Filters the CmdStan output CSV files behind `model_fit` with the external
#' tools `awk`, `cut`, `tail` and `cat` (which must be installed), keeping
#' only `lp__`, the six sampler diagnostic columns and the requested
#' parameters. The filtered files are written next to the originals with the
#' suffix `-subset-comb.csv` and read back with `cmdstanr::as_cmdstan_fit()`.
#'
#' @param model_fit The fit object from `cmdstanr`; it must provide
#'   `$output_files()` and `$metadata()$model_params`.
#' @param target_params Character vector of parameter names. Square brackets
#'   are ignored: `'alpha'` selects `'alpha[1]'`, `'alpha[2]'`, etc.
#' @param thin Positive whole number; every `thin`-th draw is kept in the
#'   written csv (`1` keeps all draws).
#' @return The object returned by `cmdstanr::as_cmdstan_fit()` for the
#'   subset files, one file per original output file, in the same order.
subset_cmdstanr <- function(model_fit, target_params, thin = 1) {
  stopifnot(
    is.character(target_params), length(target_params) > 0,
    is.numeric(thin), length(thin) == 1, !is.na(thin),
    thin >= 1, thin == round(thin)
  )
  # Never let a large `thin` be rendered as e.g. "1e+05" in the awk program.
  thin_chr <- format(thin, scientific = FALSE, trim = TRUE)

  # Temporary CmdStan CSV files, one per chain.
  cmdstan_files <- model_fit$output_files()
  if (length(cmdstan_files) == 0) {
    stop("`model_fit` has no output files to subset.", call. = FALSE)
  }

  # Parameter names with any "[...]" index stripped.
  all_params <- model_fit$metadata()$model_params
  base_names <- sub("\\[.*$", "", all_params)

  matched <- which(base_names %in% target_params)
  unmatched <- setdiff(target_params, base_names)
  if (length(unmatched) > 0) {
    warning(
      "No parameters found for: ", paste(unmatched, collapse = ", "),
      call. = FALSE
    )
  }

  # The csv holds lp__ (column 1) followed by six columns that are not in
  # `all_params`: accept_stat__, stepsize__, treedepth__, n_leapfrog__,
  # divergent__, energy__. Columns 1-7 are always kept; every other
  # parameter index is shifted by 6. Integer arithmetic keeps the indices
  # integer so they are never formatted in scientific notation.
  n_diag <- 6L
  param_cols <- setdiff(matched, 1L)
  cols <- c(seq_len(1L + n_diag), param_cols + n_diag)

  # Collapse consecutive indices into "a-b" ranges: passing hundreds of
  # thousands of individual fields to `cut` causes problems.
  gaps <- which(diff(cols) > 1L)
  run_start <- cols[c(1L, gaps + 1L)]
  run_end <- cols[c(gaps, length(cols))]
  runs <- ifelse(
    run_start == run_end,
    sprintf("%d", run_start),
    sprintf("%d-%d", run_start, run_end)
  )
  field_spec <- paste0("-f", paste(runs, collapse = ","))

  # Run a shell command and fail loudly on a non-zero exit status.
  run_shell <- function(cmd) {
    status <- system(cmd)
    if (status != 0) {
      stop(
        "Command failed (exit status ", status, "): ", cmd,
        call. = FALSE
      )
    }
    invisible(status)
  }

  # Write the subset version of one CmdStan csv and return its path.
  subset_one <- function(file) {
    # Strip only a literal trailing ".csv" (not a regex match inside the path).
    stem <- sub("\\.csv$", "", file)
    header_file <- paste0(stem, "-header.csv")
    draws_file <- paste0(stem, "-draws-subset.csv")
    time_file <- paste0(stem, "-time.csv")
    comb_file <- paste0(stem, "-subset-comb.csv")
    # The intermediate pieces are only needed to assemble `comb_file`.
    on.exit(unlink(c(header_file, draws_file, time_file)), add = TRUE)

    message("subsetting ", file)
    # Drop comment lines, then keep only the columns of interest.
    draws_cmd <- paste0(
      "awk '/^[^#]/ {print}' ", shQuote(file),
      " | cut -d \",\" ", field_spec
    )
    if (thin != 1) {
      # Line 1 is the column header, so NR % thin == 1 keeps it and then
      # every thin-th draw.
      # https://unix.stackexchange.com/questions/648113/how-to-skip-every-three-lines-using-awk
      draws_cmd <- paste0(draws_cmd, " | awk 'NR%", thin_chr, "==1'")
    }
    run_shell(paste0(draws_cmd, " > ", shQuote(draws_file)))

    # Leading comment header: everything before the first non-comment line,
    # however many lines this CmdStan version writes.
    run_shell(paste0(
      "awk '/^[^#]/ {exit} {print}' ", shQuote(file),
      " > ", shQuote(header_file)
    ))

    # Elapsed-time block at the end of the file.
    run_shell(paste0(
      "tail -n 5 ", shQuote(file), " > ", shQuote(time_file)
    ))

    message("combining and writing to file")
    run_shell(paste(
      "cat", shQuote(header_file), shQuote(draws_file), shQuote(time_file),
      ">", shQuote(comb_file)
    ))
    comb_file
  }

  # Output paths are built directly from the input paths, so chain order is
  # preserved and no directory listing is needed.
  new_files <- vapply(
    cmdstan_files, subset_one,
    FUN.VALUE = character(1), USE.NAMES = FALSE
  )

  message("creating new cmdstanr object")
  cmdstanr::as_cmdstan_fit(new_files)
}
Example model
library(cmdstanr)
# Fit the "logistic" example model that ships with cmdstanr, using 4 chains
fit <- cmdstanr::cmdstanr_example(
  "logistic",
  chains = 4
)
# Posterior summary table for the fit
fit$summary()
# A tibble: 105 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -66.0 -65.6 1.45 1.20 -68.9 -64.3 1.00 2013. 2533.
2 alpha 0.381 0.385 0.218 0.216 0.0231 0.741 1.00 3703. 2476.
3 beta[1] -0.667 -0.663 0.252 0.248 -1.09 -0.253 1.00 4782. 2822.
4 beta[2] -0.281 -0.277 0.225 0.225 -0.661 0.0811 1.00 4100. 2867.
5 beta[3] 0.675 0.671 0.266 0.266 0.250 1.13 1.00 3986. 3010.
6 log_lik[1] -0.515 -0.507 0.0984 0.0964 -0.689 -0.368 1.00 3884. 2746.
7 log_lik[2] -0.406 -0.387 0.149 0.139 -0.679 -0.200 1.00 4255. 3059.
8 log_lik[3] -0.503 -0.472 0.219 0.212 -0.903 -0.208 1.00 4278. 2783.
9 log_lik[4] -0.449 -0.431 0.152 0.146 -0.726 -0.234 1.00 3585. 2480.
10 log_lik[5] -1.18 -1.16 0.279 0.275 -1.67 -0.755 1.00 4254. 3139.
# ℹ 95 more rows
Run subset function (retain only ‘lp__’, ‘alpha’, and ‘beta’)
# Build a new fit object that retains only 'lp__', 'alpha' and 'beta'
fit2 <- subset_cmdstanr(
  model_fit = fit,
  target_params = c('lp__', 'alpha', 'beta')
)
# Summary table for the reduced fit
fit2$summary()
# A tibble: 5 × 10
variable mean median sd mad q5 q95 rhat ess_bulk ess_tail
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 lp__ -66.0 -65.6 1.45 1.20 -68.9 -64.3 1.00 2013. 2533.
2 alpha 0.381 0.385 0.218 0.216 0.0231 0.741 1.00 3703. 2476.
3 beta[1] -0.667 -0.663 0.252 0.248 -1.09 -0.253 1.00 4782. 2822.
4 beta[2] -0.281 -0.277 0.225 0.225 -0.661 0.0811 1.00 4100. 2867.
5 beta[3] 0.675 0.671 0.266 0.266 0.250 1.13 1.00 3986. 3010.