Fully Bayesian Bootstrap

I came across an excellent blog post by @andrewheiss, which is based on the work by @ajnafa. You can find it here: https://www.andrewheiss.com/blog/2021/12/20/fully-bayesian-ate-iptw/.

Their approach to fully Bayesian inverse probability weighting (IPW) consists of two main steps. First, they fit a separate weight model to generate posterior weights. Then, during sampling of the outcome model, they incorporate these posterior weights by assigning each draw of the outcome model to a corresponding row of the weight matrix.
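
For concreteness, here is a minimal sketch (mine, not the blog post's code) of how the first step could look with brms, using a hypothetical data frame dat with a binary treatment and covariates x1 and x2:

library(brms)

# Step 1: fit the weight (propensity) model and turn its posterior into a
# draws-by-observations matrix of IPW weights.
weight_model <- brm(treatment ~ x1 + x2, family = bernoulli, data = dat)
p <- posterior_epred(weight_model)  # draws x N matrix of propensity scores
w <- t(apply(p, 1, function(ps) ifelse(dat$treatment == 1, 1 / ps, 1 / (1 - ps))))
# Step 2: each posterior draw of the outcome model is then paired with one row of w.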

Upon reading this, it reminded me of the Bayesian bootstrap, also known as the fractional-random-weight bootstrap. In contrast to the “ordinary” bootstrap, which uses integer resampling counts, the Bayesian bootstrap uses continuous weights that sum to N; typically these are sampled from a Dirichlet distribution with a concentration vector of ones and then rescaled by N. To test the idea, I fit a model to the Bearing Cage Field Failure Data, which is also used in Xu et al.'s (2020) paper, and obtained nearly identical precision to what is reported there.
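
For reference, a single fractional-weight draw can be generated like this (my own illustration, not code from the paper): a Dirichlet(1, …, 1) vector is just a set of normalized standard exponential draws, and rescaling by N gives weights that sum to N:

set.seed(1)
N <- 10
g <- rexp(N)          # standard exponential (Gamma(1, 1)) draws
w <- N * g / sum(g)   # one fractional bootstrap weight vector; sum(w) == N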

To be honest, I find this quite fascinating, but my main question is whether this result is considered novel or trivial?

There is a slightly different approach to bootstrapping documented in the repository for Paananen et al.'s (2021) work, in a case study titled “Importance Weighted Moment Matching for Fast Bootstrapping.” Note that Paananen only considers the “ordinary” bootstrap and does not explore the case with fractional weights; when I experimented with this method in the past, it also worked well with fractional weights. You can find the relevant information here: https://htmlpreview.github.io/?https://github.com/topipa/iter-mm-paper/blob/master/case_studies/IWMM_BS.html.
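
The underlying idea can be sketched as follows (my own illustration, not the case study's code), where fit is a hypothetical brms fit of the model to the full data:

library(brms)

ll <- log_lik(fit)                  # draws x N pointwise log-likelihoods
n  <- ncol(ll)
g  <- rexp(n)
w  <- n * g / sum(g)                # one fractional bootstrap weight vector
log_r <- as.vector(ll %*% (w - 1))  # log importance ratios for this replicate
r <- exp(log_r - max(log_r))        # unnormalized importance weights over draws
# Paananen et al.'s moment matching then adapts the draws whenever these raw
# importance weights are unreliable.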

References:

Paananen, T., Piironen, J., Bürkner, P.-C., and Vehtari, A. (2021). Implicitly adaptive importance sampling. Statistics and Computing, 31, 16.

Xu, L., Gotwalt, C., Hong, Y., King, C. B., and Meeker, W. Q. (2020). Applications of the fractional-random-weight bootstrap. The American Statistician, 74(4), 345–358.

2 Likes

The approach illustrated in that blog post was my first attempt at tackling Bayesian estimation of IPTW. Unfortunately, it doesn’t scale well and runs into sampling issues pretty easily. I developed a more scalable solution that samples the weights at the design stage and only requires passing their location and scale as input data. You can find the working paper, code, and simulations here: https://github.com/ajnafa/Latent-Bayesian-MSM.

1 Like

Thank you for sharing your insights and providing the link to your additional work!

I completely understand how the scaling issue can arise, especially given how quickly the weight matrix grows. In the case of bootstrapping, it seems reasonable to sample the weights within each draw, since they are independent of the other variables. However, Stan would complain loudly if the built-in Dirichlet RNG were called in the model block, so one would have to rely on the external C++ trick. It would be interesting to experiment further with the sampling issues and to explore different bootstrapping schemes, such as the clustered bootstrap.

I intend to further investigate this topic when I have the opportunity.

1 Like

The bootstrapping angle here is quite interesting, and I’m not aware of any existing approaches that attempt to handle the weights via a C++ routine similar to what we originally did to handle the posterior IPTW weights in @andrewheiss’s blog post.

I use the fractional-weights approach to the Bayesian bootstrap fairly often when working with the g-formula for estimating PATEs, as in this example, but I haven’t bothered experimenting with it further. I’d be interested to see how one could generalize it to time-series or hierarchical contexts; the last time I tried doing this in Stan, I found it to be less than straightforward.

Some progress has been made!

I will show an example using the bearingcage dataset from the fwb package, an R package for the fractional weighted bootstrap.

First we load some packages:

library(tidyverse)
library(fwb)
library(cmdstanr)
library(brms)
library(posterior)
library(ggdist)
library(here)  # used below when pointing cmdstan_model() to the C++ header

We start by estimating the ordinary model, with default priors from `brms`:

> brm(formula = hours|cens(!failure)~1, 
+     family = weibull, 
+     data = bearingcage, 
+     backend = "cmdstanr", 
+     cores = 4, 
+     iter = 3000, 
+     warmup = 1000, 
+     control= list(adapt_delta=.95)) -> 
+   fit_brm
[sampling output]
> summary(fit_brm)
 Family: weibull 
  Links: mu = log; shape = identity 
Formula: hours | cens(!failure) ~ 1 
   Data: bearingcage (Number of observations: 1703) 
  Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
         total post-warmup draws = 8000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     9.47      0.88     8.25    11.69 1.00     1253     1294

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
shape     2.08      0.61     1.07     3.47 1.00     1194     1189

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Then we start tinkering!

First we need a new C++ header. Some code in it comes from @ajnafa.

r"(#ifndef DIRICHLET_RNG_WRAPPER_HPP
#define DIRICHLET_RNG_WRAPPER_HPP

#include <stan/math.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <chrono>
#include <Eigen/Dense>
#include <iostream>

// Declare an integer to keep track of the iteration count
static int itct = 0;

// Increment the counter
inline void add_iter(std::ostream* pstream__) {
  itct += 1;
}

// Retrieve the current count
inline int get_iter(std::ostream* pstream__) {
  return itct;
}

// Generate Dirichlet draws, with iteration checking
Eigen::VectorXd dirichlet_rng_wrapper(const Eigen::VectorXd& alpha, std::ostream* pstream__) {
  static Eigen::VectorXd last_draw = Eigen::VectorXd::Zero(alpha.size()); // Initialize with zeros
  static int last_itct = -1;  // Start with -1 to ensure it differs from itct initially
  
  if (itct != last_itct) {
    // It's a new iteration, generate new Dirichlet draws
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    boost::random::mt19937 rng(seed);
    last_draw = stan::math::dirichlet_rng(alpha, rng);
    
    // Update the iteration counter
    last_itct = itct;
  }
  
  // Incrementing the iteration counter is handled outside this function (by add_iter() in generated quantities)
  
  return last_draw;
}

#endif // DIRICHLET_RNG_WRAPPER_HPP)" ->
  iterfuns

writeLines(iterfuns, "iterfuns.hpp")

We also need some code to modify the Stan code generated by brms.

transform_stan_code <- function(stan_code) {
  # Add new functions to the 'functions' block, checking if the block exists
  if (grepl("functions \\{", stan_code)) {
    stan_code <- sub("functions \\{", "functions {\n  void add_iter(); \n  int get_iter();\n  vector dirichlet_rng_wrapper(vector alpha);", stan_code)
  } else {
    stan_code <- paste("functions {\n  void add_iter(); \n  int get_iter();\n  vector dirichlet_rng_wrapper(vector alpha);\n}\n", stan_code, sep = "\n")
  }
  
  # Comment out the weights vector in the 'data' block
  stan_code <- gsub("vector<lower=0>\\[N\\] weights;", "//vector[N] weights;  // This line has been commented out", stan_code)
  
  # Add Dirichlet parameters in 'transformed data' block, ensuring the block exists
  if (grepl("transformed data \\{", stan_code)) {
    stan_code <- sub("transformed data \\{", "transformed data {\n  vector[N] alpha = rep_vector(1.0, N);  // Dirichlet parameters, all ones for uniform distribution ", stan_code)
  } else {
    stan_code <- paste(stan_code, "transformed data {\n  vector[N] alpha = rep_vector(1.0, N);  // Dirichlet parameters, all ones for uniform distribution  \n}", sep = "\n")
  }
  
  # Add the modified weights calculation in the 'model' block
  stan_code <- gsub("(if \\(!prior_only\\) \\{)", "\\1\n    vector[N] weights = dirichlet_rng_wrapper(alpha)*N;", stan_code)
  
  # Insert the 'add_iter()' call in the 'generated quantities' block
  stan_code <- gsub("generated quantities \\{", "generated quantities {\n  add_iter();  // update the counter each iteration ", stan_code)
  
  return(stan_code)
}

We continue with the mischief:

make_stancode(formula = hours|cens(!failure)+weights(1)~1, 
              family = weibull, 
              data = bearingcage) %>% 
  transform_stan_code() ->
  modified_stan_code

modified_stan_code %>% 
  write_stan_file() %>% 
  cmdstan_model(user_header = here('iterfuns.hpp')) -> 
  modified_model # this will compile the model

make_standata(formula = hours|cens(!failure)~1, # note the absence of weights here
              family = weibull, 
              data=bearingcage) -> 
  standata_bootstrap

Then we sample!

modified_model$sample(data = standata_bootstrap, 
                      chains = 4, 
                      iter_warmup = 1000, 
                      iter_sampling = 2000, 
                      adapt_delta = 0.99, 
                      refresh = 50L, 
                      parallel_chains = 3)  ->
  outcome_samples_cmdstan
> outcome_samples_cmdstan$diagnostic_summary()
$num_divergent
[1] 0 0 0 0

$num_max_treedepth
[1] 0 0 0 0

$ebfmi
[1] 2.003341 1.914465 2.018019 1.963792

No divergences, unbelievable!

Then we put it back together into a brms output object:

outcome_samples_cmdstan$output_files() %>% 
  rstan::read_stan_csv() -> 
  rstan_fit

outcome_samples_cmdstan -> 
  attributes(rstan_fit)$CmdStanModel

modified_model_brms <- brm(formula = hours|cens(!failure)~1, 
                           family = weibull, 
                           data=bearingcage,
                           empty = TRUE)

rstan_fit -> modified_model_brms$fit

rename_pars(modified_model_brms) -> 
  modified_model_brms
> modified_model_brms
 Family: weibull 
  Links: mu = log; shape = identity 
Formula: hours | cens(!failure) ~ 1 
   Data: bearingcage (Number of observations: 1703) 
  Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
         total post-warmup draws = 8000

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept     9.73      1.46     8.00    13.66 1.01      532      712

Further Distributional Parameters:
      Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
shape     2.16      0.94     0.82     4.46 1.01      610      756

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Looking at the standard deviations of the posteriors, it seems that we got slightly fuzzier posteriors, which is easier to see when plotted:
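
Such a comparison can be drawn with the already-loaded posterior and ggdist packages; this is a sketch, not the exact plot from the post:

# Overlay the intercept posteriors from the ordinary fit and the bootstrap fit
bind_rows(
  as_draws_df(fit_brm) %>% mutate(model = "ordinary"),
  as_draws_df(modified_model_brms) %>% mutate(model = "fractional-weight bootstrap")
) %>%
  ggplot(aes(x = b_Intercept, y = model)) +
  stat_halfeye() +   # posterior density with median and intervals, from ggdist
  labs(x = "Intercept (log scale)", y = NULL)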

We see heavier tails, which probably makes sense: we are likely to get bootstrap weight draws in which the few uncensored observations receive low weight, so the estimates wander off for a bit.
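
As a rough check of that intuition (my own back-of-the-envelope snippet, assuming failure is a logical column, as the cens(!failure) formula suggests), the total fractional weight landing on the uncensored observations varies quite a bit across bootstrap draws:

# How much total weight do the few failures get in one fractional bootstrap draw?
n <- nrow(bearingcage)
fail_weight <- replicate(1000, {
  g <- rexp(n)
  w <- n * g / sum(g)            # one Dirichlet(1) * n weight vector
  sum(w[bearingcage$failure])    # total weight on the uncensored observations
})
quantile(fail_weight, c(0.05, 0.5, 0.95))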

It is unclear whether this is useful, but it is possible to do.

What is happening is essentially that we sample a set of weights distributed as N times a Dirichlet(1, …, 1) draw within each iteration. We also keep track of whether we are still in the same iteration, since the model block is evaluated more than once per iteration (once per leapfrog step).

5 Likes

The proposed generated data block (design-docs/designs/0035-pre-model-gqs.md at 0035-pre-model-gqs · stan-dev/design-docs · GitHub) could also be used both for this and for the IPW approach. Let’s hope it will be implemented soon!

2 Likes

@Bob_Carpenter if we took this proposal one step further, we could have Stan programs be used as priors for other Stan programs. Suzerain Stan takes the fitted Stan fit object and generates a sampled draw, which is used as data in the new Stan program.

Perhaps a more computationally effective way to achieve a similar outcome would be to fit a generative normalizing flow over the draws of a Stan fit object and then have that as the generating object in the generated data block.

3 Likes

“Suzerain Stan”. Was that a joke? The link directs to a definition of the “word”.

I guess “computationally effective” is going to depend on your hardware, but at least as of now, normalizing flows are a bit challenging to control. Having said that, lots of people are going this way, including in preconditioning in NumPy!

Yep, and I’d just make it Susan Stan.

On your proposal, I love it and hope it gets built!

1 Like