I have to fit the same model many times to different outcomes. I want to specify the model in brms and avoid recompiling it every time I fit it. I hacked together a solution that seems to work fine sequentially (i.e. using `map`) and also works when I fit models in parallel, but only with a maximum of two workers. When I go up to 8 workers, for example, I get the error “Supplied CSV file is corrupt!” after some models have finished. I suspected that this had something to do with multiple workers trying to access the same file, but from my reading of the documentation each model and file gets its own unique name, so I think that’s not the issue.
Does anyone have any suggestions as to what is going wrong? See my minimal reproducible example below:
#' Compile the Stan program described by an empty brmsfit
#'
#' Generates Stan code and Stan data from the formula, data, family and priors
#' stored on `empty_brms_obj`, writes the code to a Stan file and compiles it
#' with cmdstanr.
#'
#' @param empty_brms_obj A brmsfit-like object providing `formula`, `data`,
#'   `family` and `prior` elements (e.g. from `brms::brm(..., empty = TRUE)`).
#' @return A list with elements `compiled_mod` (the value of
#'   `cmdstanr::cmdstan_model()`) and `standata` (the value of
#'   `brms::make_standata()`).
empty_brms_to_stan_file <- function(empty_brms_obj) {
  formula <- empty_brms_obj$formula
  data <- empty_brms_obj$data
  family <- empty_brms_obj$family
  # Forward the stored priors: when `prior` is omitted, brms falls back to its
  # default priors and the compiled model no longer matches the specification.
  prior <- empty_brms_obj$prior

  stancode <- brms::make_stancode(
    formula = formula,
    data = data,
    family = family,
    prior = prior
  )
  standata <- brms::make_standata(
    formula = formula,
    data = data,
    family = family,
    prior = prior
  )

  stan_file <- cmdstanr::write_stan_file(stancode)
  compiled_mod <- cmdstanr::cmdstan_model(stan_file)

  list("compiled_mod" = compiled_mod, "standata" = standata)
}
#' Sample from a compiled model and store the draws on a brmsfit
#'
#' Calls `$sample()` on the compiled model, reads the resulting output files
#' with `rstan::read_stan_csv()`, stores the result in the `fit` element of
#' `empty_brms_obj` and passes the object through `brms::rename_pars()`.
#'
#' @param compiled_cmdstan_obj Object exposing a `$sample()` method.
#' @param standata Data passed as the first argument of `$sample()`.
#' @param empty_brms_obj Object whose `fit` element is to be filled.
#' @param ... Further arguments forwarded to `$sample()`.
#' @return The value of `brms::rename_pars()` applied to the updated object.
compiled_cmdstan_to_fitted_brms <- function(compiled_cmdstan_obj, standata, empty_brms_obj, ...) {
  cmdstan_fit <- compiled_cmdstan_obj$sample(standata, ...)
  csv_paths <- cmdstan_fit$output_files()
  empty_brms_obj$fit <- rstan::read_stan_csv(csv_paths)
  brms::rename_pars(empty_brms_obj)
}
#' Fit the precompiled model to one outcome vector
#'
#' Binds `outcome` (as column `y`) to `other_vars_df`, builds the Stan data for
#' that data frame and samples from the already compiled model.
#'
#' @param outcome Vector used as the response column `y`.
#' @param other_vars_df Data frame with the remaining model variables.
#' @param compiled_mod Compiled model passed on to
#'   `compiled_cmdstan_to_fitted_brms()`.
#' @param empty_brms_obj Object providing `formula`, `family` and `prior`; it
#'   is also the object that receives the fit.
#' @param ... Further arguments forwarded to
#'   `compiled_cmdstan_to_fitted_brms()`.
#' @return The value of `compiled_cmdstan_to_fitted_brms()`.
wrapper <- function(outcome, other_vars_df, compiled_mod, empty_brms_obj, ...) {
  # The response column name is fixed to `y`, so the stored formula must use
  # `y` as its response.
  df <- dplyr::bind_cols(y = outcome, other_vars_df)
  # Pass the stored priors so the Stan data is generated for the same model
  # specification that `empty_brms_obj` describes.
  standata <- brms::make_standata(
    formula = empty_brms_obj$formula,
    data = df,
    family = empty_brms_obj$family,
    prior = empty_brms_obj$prior
  )
  # NOTE(review): `empty_brms_obj$data` is not replaced by `df` here, so the
  # returned object presumably still carries the template data - confirm
  # before relying on post-processing that reads the stored data.
  compiled_cmdstan_to_fitted_brms(compiled_mod, standata, empty_brms_obj, ...)
}
# Simulate a random-intercept data set: `nr_subjects` subjects with
# `samples_p_subject` observations each.
seed <- 204
# Seed R's RNG so the simulated data below is reproducible; previously `seed`
# was only handed to brms/cmdstanr/furrr and never applied to `rnorm()`.
set.seed(seed)
n <- 300
samples_p_subject <- 50
nr_subjects <- n / samples_p_subject
re_sd_subject <- 2
beta_x <- 0.3
# One random intercept per subject, repeated for each of its observations.
re_subjects <- rnorm(nr_subjects, mean = 0, sd = re_sd_subject)
subjectid <- rep(seq_len(nr_subjects), each = samples_p_subject)
re_subject <- rep(re_subjects, each = samples_p_subject)
x <- rnorm(n)
y <- rnorm(n, beta_x * x + re_subject)
df <- tibble::tibble("x" = x, "y" = y, "subjectid" = subjectid)
# Priors for the regression coefficients, intercept, group-level SD and
# residual SD.
priors <- c(
  brms::prior_string("normal(0,3)", class = "b"),
  brms::prior_string("student_t(3, 0, 3.5)", class = "Intercept"),
  brms::prior_string("student_t(3, 0, 2.5)", class = "sd"),
  brms::prior_string("student_t(3, 0, 2.5)", class = "sigma")
)

# `empty = TRUE`: only the model specification is needed here; the model is
# compiled and fitted separately below.
m <- brms::brm(
  formula = brms::brmsformula(y ~ 1 + x + (1 | subjectid)),
  data = df,
  family = gaussian(),
  backend = "cmdstanr",
  prior = priors,
  seed = seed,
  cores = 4,
  empty = TRUE
)
# Proof of concept, step 1: compile the model once from the empty brmsfit.
cmdstanr_list <- empty_brms_to_stan_file(empty_brms_obj = m)
#> Warning in '[..]/AppData/Local/Temp/Rtmpc5QohC/model-563868cb501f.stan', line 12, column 2: Declaration
#> of arrays by placing brackets after a variable name is deprecated and
#> will be removed in Stan 2.33.0. Instead use the array keyword before the
#> type. This can be changed automatically using the auto-format flag to
#> stanc
#> Warning in '[..]/AppData/Local/Temp/Rtmpc5QohC/model-563868cb501f.stan', line 31, column 2: Declaration
#> of arrays by placing brackets after a variable name is deprecated and
#> will be removed in Stan 2.33.0. Instead use the array keyword before the
#> type. This can be changed automatically using the auto-format flag to
#> stanc
#> In file included from stan/src/stan/model/model_header.hpp:11:
#> stan/src/stan/model/model_base_crtp.hpp:198: warning: 'void stan::model::model_base_crtp<M>::write_array(boost::random::ecuyer1988&, std::vector<double, std::allocator<double> >&, std::vector<int>&, std::vector<double, std::allocator<double> >&, bool, bool, std::ostream*) const [with M = model_471582a900d43b8f5bc21554bbd7d83f_model_namespace::model_471582a900d43b8f5bc21554bbd7d83f_model; boost::random::ecuyer1988 = boost::random::additive_combine_engine<boost::random::linear_congruential_engine<unsigned int, 40014, 0, 2147483563>, boost::random::linear_congruential_engine<unsigned int, 40692, 0, 2147483399> >; std::ostream = std::basic_ostream<char>]' was hidden [-Woverloaded-virtual=]
#> 198 | void write_array(boost::ecuyer1988& rng, std::vector<double>& theta,
#> |
#> [..]/AppData/Local/Temp/Rtmpc5QohC/model-563868cb501f.hpp:784: note: by 'model_471582a900d43b8f5bc21554bbd7d83f_model_namespace::model_471582a900d43b8f5bc21554bbd7d83f_model::write_array'
#> 784 | write_array(RNG& base_rng, std::vector<double>& params_r, std::vector<int>&
#> |
#> stan/src/stan/model/model_base_crtp.hpp:136: warning: 'void stan::model::model_base_crtp<M>::write_array(boost::random::ecuyer1988&, Eigen::VectorXd&, Eigen::VectorXd&, bool, bool, std::ostream*) const [with M = model_471582a900d43b8f5bc21554bbd7d83f_model_namespace::model_471582a900d43b8f5bc21554bbd7d83f_model; boost::random::ecuyer1988 = boost::random::additive_combine_engine<boost::random::linear_congruential_engine<unsigned int, 40014, 0, 2147483563>, boost::random::linear_congruential_engine<unsigned int, 40692, 0, 2147483399> >; Eigen::VectorXd = Eigen::Matrix<double, -1, 1>; std::ostream = std::basic_ostream<char>]' was hidden [-Woverloaded-virtual=]
#> 136 | void write_array(boost::ecuyer1988& rng, Eigen::VectorXd& theta,
#> |
#> [..]/AppData/Local/Temp/Rtmpc5QohC/model-563868cb501f.hpp:784: note: by 'model_471582a900d43b8f5bc21554bbd7d83f_model_namespace::model_471582a900d43b8f5bc21554bbd7d83f_model::write_array'
#> 784 | write_array(RNG& base_rng, std::vector<double>& params_r, std::vector<int>&
#> |
# Proof of concept, step 2: a single fit that reuses the compiled model.
brms_obj <- compiled_cmdstan_to_fitted_brms(
  compiled_cmdstan_obj = cmdstanr_list$compiled_mod,
  standata = cmdstanr_list$standata,
  empty_brms_obj = m,
  parallel_chains = 4,
  seed = seed
)
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.4 seconds.
#> Total execution time: 2.8 seconds.
# Parallel fits using map: eight copies of the outcome (Y1..Y8) next to the
# shared predictor columns.
df_many <- tibble::tibble(
  dplyr::select(df, "Y1" = y),
  dplyr::select(df, "Y2" = y),
  dplyr::select(df, "Y3" = y),
  dplyr::select(df, "Y4" = y),
  dplyr::select(df, "Y5" = y),
  dplyr::select(df, "Y6" = y),
  dplyr::select(df, "Y7" = y),
  dplyr::select(df, "Y8" = y),
  dplyr::select(df, subjectid, x)
)
# 2 workers x 4 parallel chains each: works fine
future::plan(future::multisession, workers = 2)
fitted_models <- furrr::future_map(
  dplyr::select(df_many, 1:8),
  \(y) {
    wrapper(
      outcome = y,
      other_vars_df = dplyr::select(df_many, subjectid, x),
      compiled_mod = cmdstanr_list$compiled_mod,
      empty_brms_obj = m,
      parallel_chains = 4,
      seed = seed
    )
  },
  .progress = TRUE,
  .options = furrr::furrr_options(seed = seed)
)
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.8 seconds.
#> Total execution time: 3.6 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.4 seconds.
#> Total execution time: 3.9 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.4 seconds.
#> Total execution time: 4.0 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.3 seconds.
#> Total execution time: 3.9 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.6 seconds.
#> Total execution time: 4.4 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.3 seconds.
#> Total execution time: 3.8 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 3.4 seconds.
#> Total execution time: 3.8 seconds.
#>
#> Running MCMC with 4 parallel chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.8 seconds.
#> Total execution time: 3.5 seconds.
# 8 workers x 1 chain at a time each: doesn't work!
future::plan(future::multisession, workers = 8)
fitted_models <- furrr::future_map(
  dplyr::select(df_many, 1:8),
  \(y) {
    wrapper(
      outcome = y,
      other_vars_df = dplyr::select(df_many, subjectid, x),
      compiled_mod = cmdstanr_list$compiled_mod,
      empty_brms_obj = m,
      parallel_chains = 1,
      seed = seed
    )
  },
  .progress = TRUE,
  .options = furrr::furrr_options(seed = seed)
)
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 1.9 seconds.
#> Total execution time: 8.0 seconds.
#>
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.2 seconds.
#> Total execution time: 9.5 seconds.
#>
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.4 seconds.
#> Total execution time: 10.5 seconds.
#>
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.4 seconds.
#> Total execution time: 10.2 seconds.
#>
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.4 seconds.
#> Total execution time: 10.2 seconds.
#>
#> Running MCMC with 4 sequential chains...
#> All 4 chains finished successfully.
#> Mean chain execution time: 2.2 seconds.
#> Total execution time: 9.6 seconds.
#>
#> Error:
#> ℹ In index: 1.
#> ℹ With name: Y6.
#> Caused by error:
#> ! Supplied CSV file is corrupt!
Created on 2024-04-20 with reprex v2.1.0
sessionInfo()
#> R version 4.3.1 (2023-06-16 ucrt)
#> Platform: x86_64-w64-mingw32/x64 (64-bit)
#> Running under: Windows 10 x64 (build 19045)
#>
#> Matrix products: default
#>
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] brms_2.19.0 Rcpp_1.0.10 rstan_2.32.3
#> [4] StanHeaders_2.26.28 cmdstanr_0.5.3
#>
#> loaded via a namespace (and not attached):
#> [1] tidyselect_1.2.0 dplyr_1.1.2 farver_2.1.1
#> [4] loo_2.6.0 fastmap_1.1.1 tensorA_0.36.2
#> [7] shinystan_2.6.0 shinyjs_2.1.0 promises_1.2.0.1
#> [10] reprex_2.1.0 digest_0.6.31 mime_0.12
#> [13] lifecycle_1.0.3 ellipsis_0.3.2 processx_3.8.1
#> [16] magrittr_2.0.3 posterior_1.4.1 compiler_4.3.1
#> [19] rlang_1.1.1 tools_4.3.1 igraph_1.4.3
#> [22] utf8_1.2.3 yaml_2.3.7 knitr_1.43
#> [25] prettyunits_1.1.1 bridgesampling_1.1-2 htmlwidgets_1.6.2
#> [28] pkgbuild_1.4.0 curl_5.0.0 plyr_1.8.8
#> [31] dygraphs_1.1.1.6 abind_1.4-5 miniUI_0.1.1.1
#> [34] withr_2.5.0 grid_4.3.1 stats4_4.3.1
#> [37] fansi_1.0.4 xts_0.13.1 xtable_1.8-4
#> [40] colorspace_2.1-0 inline_0.3.19 ggplot2_3.4.4
#> [43] gtools_3.9.4 scales_1.2.1 cli_3.6.1
#> [46] mvtnorm_1.2-2 rmarkdown_2.22 crayon_1.5.2
#> [49] generics_0.1.3 RcppParallel_5.1.7 rstudioapi_0.14
#> [52] reshape2_1.4.4 stringr_1.5.0 shinythemes_1.2.0
#> [55] bayesplot_1.10.0 parallel_4.3.1 matrixStats_1.2.0
#> [58] base64enc_0.1-3 vctrs_0.6.2 V8_4.4.0
#> [61] Matrix_1.6-0 jsonlite_1.8.7 callr_3.7.3
#> [64] crosstalk_1.2.0 glue_1.6.2 codetools_0.2-19
#> [67] ps_1.7.5 DT_0.28 distributional_0.3.2
#> [70] stringi_1.7.12 gtable_0.3.3 later_1.3.1
#> [73] QuickJSR_1.0.7 munsell_0.5.0 tibble_3.2.1
#> [76] colourpicker_1.2.0 pillar_1.9.0 htmltools_0.5.5
#> [79] Brobdingnag_1.2-9 R6_2.5.1 evaluate_0.21
#> [82] shiny_1.7.4 lattice_0.21-8 markdown_1.7
#> [85] backports_1.4.1 threejs_0.3.3 httpuv_1.6.11
#> [88] rstantools_2.3.1 coda_0.19-4 gridExtra_2.3
#> [91] nlme_3.1-162 checkmate_2.2.0 xfun_0.42
#> [94] fs_1.6.2 zoo_1.8-12 pkgconfig_2.0.3