Bayesian Bootstrap implementation for brms

Hi there,

I’m dropping this here just in case it’s useful to anyone: a function that applies the Bayesian bootstrap (weighted likelihood bootstrap, hence `wlb`) on top of MAP estimation of a brms model:

wlb_brms <- function(
        fit, data = fit$data, B = 1000,
        mc.cores = max(1L, parallel::detectCores() - 1L),
        seed = 123, algorithm = "lbfgs", jacobian = TRUE,
        refresh = 0, init = 2,
        control = list()
) {
    stopifnot(inherits(fit, "brmsfit"))
    if (.Platform$OS.type == "windows" && mc.cores > 1L) {
        warning("mclapply uses fork; on Windows it will fall back to 1 core.")
        mc.cores <- 1L
    }
    
    N <- nrow(data)
    data$wlb_w <- rep(1, N)  # placeholder weights; overwritten per replicate below
    
    # Compile *once* a weights-aware Stan program matching `fit`.
    # brms takes observation weights through the `weights()` addition term,
    # so inject it into the response part of the formula (this assumes the
    # original response carries no other addition terms).
    bform <- formula(fit)
    f_w <- update(bform$formula, . | weights(wlb_w) ~ .)
    attributes(f_w) <- attributes(bform$formula)  # keep brms-specific flags (e.g. nl)
    bform$formula <- f_w
    
    scode <- brms::make_stancode(
        formula = bform,
        data    = data,
        family  = fit$family,
        prior   = fit$prior
    )
    sdata_base <- brms::make_standata(
        formula = bform,
        data    = data,
        family  = fit$family,
        prior   = fit$prior
    )
    mod <- cmdstanr::cmdstan_model(cmdstanr::write_stan_file(scode))
    
    # To avoid recompilation in workers, capture the path to the executable
    exe_path <- mod$exe_file()
    
    one_wlb <- function(b) {
        # Recreate CmdStanModel from exe to be fork-safe and fast
        mod_b <- cmdstanr::cmdstan_model(exe_file = exe_path)
        
        # Dirichlet(1,...,1) via normalized Gamma(1)
        w <- rgamma(N, 1); w <- N * w / sum(w)
        sdata <- sdata_base; sdata$weights <- as.numeric(w)
        
        # Run optimizer
        mle <- mod_b$optimize(
            data = sdata,
            seed = seed + b,
            jacobian = jacobian,
            algorithm = algorithm,
            init = init,
            refresh = refresh,
            # optional tolerances (NULL falls back to the cmdstanr defaults)
            tol_grad = control$tol_grad,
            tol_obj  = control$tol_obj
        )
        
        # Return both params and lp__
        list(theta = mle$mle(), lp = as.numeric(mle$lp()))
    }
    
    set.seed(seed)
    # Parallel map with error handling; drop failed replicates
    out_list <- pbmcapply::pbmclapply(seq_len(B), function(b) {
        tryCatch(one_wlb(b), error = function(e) NULL)
    }, mc.cores = mc.cores)
    
    ok <- !vapply(out_list, is.null, logical(1))
    if (!any(ok)) stop("All WLB replicates failed.")
    out_list <- out_list[ok]
    
    # Harmonize parameter names across replicates
    pname <- names(out_list[[1]]$theta)
    par_mat <- do.call(rbind, lapply(out_list, function(x) {
        v <- x$theta[pname]; as.numeric(v)
    }))
    colnames(par_mat) <- pname
    lp_vec <- vapply(out_list, `[[`, numeric(1), "lp")
    
    # ---- Fabricate a stanfit with draws=MAPs + valid diagnostics shell ----
    stanfit_from_draws <- function(draws_matrix, lp, model_name = "wlb_optimize") {
        stopifnot(nrow(draws_matrix) == length(lp))
        sam <- as.data.frame(draws_matrix)
        sam$lp__ <- as.numeric(lp)
        rownames(sam) <- seq_len(nrow(sam))
        
        # Create a diagnostics frame with expected NUTS columns (all NA)
        rstan_diagn_order <- c("accept_stat__", "treedepth__", "stepsize__",
                               "divergent__", "n_leapfrog__", "energy__")
        diagn <- as.data.frame(matrix(NA_real_, nrow = nrow(sam), ncol = length(rstan_diagn_order)))
        names(diagn) <- rstan_diagn_order
        rownames(diagn) <- rownames(sam)
        
        attr(sam, "sampler_params")  <- diagn
        attr(sam, "adaptation_info") <- character(0)
        attr(sam, "args")            <- list(sampler_t = "Optimize-WLB", chain_id = 1L)
        attr(sam, "elapsed_time")    <- c(warmup = 0, sample = 0)
        
        par_dims <- stats::setNames(replicate(ncol(sam), integer(0), simplify = FALSE),
                                    colnames(sam))
        sim <- list(samples   = list(sam),
                    iter      = nrow(sam), thin = 1L, warmup = 0L, chains = 1L,
                    n_save    = nrow(sam), warmup2 = 0L, permutation = list(seq_len(nrow(sam))),
                    pars_oi   = colnames(sam),
                    dims_oi   = par_dims,
                    fnames_oi = colnames(sam),
                    n_flatnames = ncol(sam))
        
        cxxdso_class <- "cxxdso"; attr(cxxdso_class, "package") <- "rstan"
        null_dso <- methods::new(cxxdso_class, sig = list(character(0)), dso_saved = FALSE,
                                 dso_filename = character(0), modulename = character(0),
                                 system = R.version$system, cxxflags = character(0),
                                 .CXXDSOMISC = new.env(parent = emptyenv()))
        null_sm <- methods::new("stanmodel", model_name = model_name,
                                model_code = character(0), model_cpp = list(), dso = null_dso)
        
        sargs <- list(method = "optimize", algorithm = "lbfgs", engine = "cmdstanr",
                      iter = nrow(sam), warmup = 0L, thin = 1L, chain_id = 1L)
        sargs_rep <- list(sargs)
        
        out <- methods::new("stanfit",
                            model_name = model_name,
                            model_pars = colnames(sam),
                            par_dims   = par_dims,
                            mode       = 0L,
                            sim        = sim,
                            inits      = list(),
                            stan_args  = sargs_rep,
                            stanmodel  = null_sm,
                            date       = format(Sys.time(), "%a %b %d %X %Y"),
                            .MISC      = new.env(parent = emptyenv())
        )
        attr(out, "metadata") <- list(method = "optimize", algorithm = "lbfgs")
        out
    }
    
    sf <- stanfit_from_draws(par_mat, lp_vec, model_name = "wlb_optimize")
    
    # Build a brmsfit container and align names
    fit_empty <- brms::brm(formula(fit), data = data, family = fit$family,
                           empty = TRUE, backend = "cmdstanr")
    fit_empty$fit <- sf
    fit_wlb <- brms::rename_pars(fit_empty)
    
    list(
        draws = posterior::as_draws_matrix(par_mat),
        lp__  = lp_vec,
        brmsfit = fit_wlb,
        exe_file = exe_path
    )
}

The idea was to get an “optimise” algorithm for brms, since it’s not supported and would be a very nice addition. Note that parallelisation via pbmcapply works only on Unix, due to my being lazy.
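If anyone needs it on Windows, a portable sketch would be to swap the pbmclapply call inside wlb_brms for a PSOCK cluster, roughly like this (untested):

cl <- parallel::makeCluster(mc.cores)
on.exit(parallel::stopCluster(cl), add = TRUE)
parallel::clusterSetRNGStream(cl, seed)  # reproducible RNG streams per worker
out_list <- parallel::parLapply(cl, seq_len(B), function(b) {
    tryCatch(one_wlb(b), error = function(e) NULL)
})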

It should support all simple models, and it also works with a relatively complex nonlinear model of mine.

It’s not super fast due to the overhead of the many separate Stan calls; the best approach would be to implement the feature in Stan directly.
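For reference, usage would look roughly like this (toy model and settings, just to illustrate):

library(brms)

# any simple brmsfit should do as input
fit <- brm(mpg ~ wt + hp, data = mtcars,
           backend = "cmdstanr", chains = 1, refresh = 0)

wlb <- wlb_brms(fit, B = 500, mc.cores = 4)

# the fabricated brmsfit can be summarised like a sampled one
summary(wlb$brmsfit)
posterior::summarise_draws(wlb$draws)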


@WardBrian looks like someone dropped a stanfit-from-draws function here, if I see that right… isn’t that what we needed to bring brm objects to life from draws?!


Holy moly! I was looking for this. The output of optimize is so different from the other methods that it really causes me headaches. I currently use the Laplace approximation with a minimal number of posterior draws as the nearest replacement for optimize. However, the overhead of calculating the Hessian and the inverse Cholesky really slows things down.
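For context, the workaround I mean is roughly this with cmdstanr (model and data names are placeholders):

mod <- cmdstanr::cmdstan_model("model.stan")
# optimize to the mode, then take a handful of draws from the normal
# approximation around it (needs a recent cmdstanr/CmdStan)
fit_lap <- mod$laplace(data = stan_data, draws = 100, jacobian = TRUE)
fit_lap$summary()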

Could this by any chance make it into brms? @paul.buerkner

Please feel free to open an issue on GitHub. A PR is of course much appreciated as well. I won’t have time to work on it myself at the moment, though.


Are you sure? It’s really just a workaround. Something deeper would be needed, e.g. enabling “optimise” as an algorithm, or going even deeper at the Stan level…

I am not sure, tbh. But I know that rstanarm supports optimization and then computes samples around the optimal point via the Hessian. So it can be done. Whether it is worth implementing in brms, I don’t fully know.
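For example, something along these lines (toy data just to illustrate):

library(rstanarm)
# MAP optimization plus draws from the approximation at the mode
fit_opt <- stan_glm(mpg ~ wt + hp, data = mtcars, algorithm = "optimizing")
summary(fit_opt)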

I totally think it would be nice to have an “optimize” option in brms. I’m just not sure my implementation is the most correct approach. If, from eyeballing the code, you think it is, I’ll do it.

I was not thinking thoroughly when I wrote the previous reply (I was in the middle of a workshop), but I think it would make sense for brm_multiple, where estimates from different datasets are used for uncertainty instead of relying on just one dataset.

As for sampling from the Hessian, at the moment I don’t find Stan’s current implementation reliable enough for anything more complicated than a simple GLM. Even when the optimization manages to converge, the CrIs can be too wide or too narrow compared to sampling.