Some progress has been made!
I will show an example using the `bearingcage` dataset from the `fwb` package, which implements the fractional weighted bootstrap.
First we load some packages:
library(tidyverse)
library(here)
library(fwb)
library(cmdstanr)
library(brms)
library(posterior)
library(ggdist)
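Before fitting anything, a quick peek at the data (my own addition; the variable descriptions are my reading of the `fwb` documentation):
glimpse(bearingcage) # hours: time in service; failure: TRUE if the bearing cage failed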
We start by estimating the ordinary model, with default priors from `brms`:
> brm(formula = hours|cens(!failure)~1,
+ family = weibull,
+ data = bearingcage,
+ backend = "cmdstanr",
+ cores = 4,
+ iter = 3000,
+ warmup = 1000,
+ control= list(adapt_delta=.95)) ->
+ fit_brm
[sampling output]
> summary(fit_brm)
Family: weibull
Links: mu = log; shape = identity
Formula: hours | cens(!failure) ~ 1
Data: bearingcage (Number of observations: 1703)
Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
total post-warmup draws = 8000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 9.47 0.88 8.25 11.69 1.00 1253 1294
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
shape 2.08 0.61 1.07 3.47 1.00 1194 1189
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Then we start tinkering!
First we need a new C++ header. Some code in it comes from @ajnafa.
r"(#ifndef DIRICHLET_RNG_WRAPPER_HPP
#define DIRICHLET_RNG_WRAPPER_HPP
#include <stan/math.hpp>
#include <boost/random/mersenne_twister.hpp>
#include <chrono>
#include <Eigen/Dense>
#include <iostream>
// Declare an integer to keep track of the iteration count
static int itct = 0;
// Increment the counter
inline void add_iter(std::ostream* pstream__) {
  itct += 1;
}
// Retrieve the current count
inline int get_iter(std::ostream* pstream__) {
  return itct;
}
// Generate Dirichlet draws, with iteration checking
Eigen::VectorXd dirichlet_rng_wrapper(const Eigen::VectorXd& alpha, std::ostream* pstream__) {
  static Eigen::VectorXd last_draw = Eigen::VectorXd::Zero(alpha.size()); // Initialize with zeros
  static int last_itct = -1; // Start with -1 to ensure it differs from itct initially
  if (itct != last_itct) {
    // It's a new iteration, so generate new Dirichlet draws
    unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
    boost::random::mt19937 rng(seed);
    last_draw = stan::math::dirichlet_rng(alpha, rng);
    // Update the iteration counter
    last_itct = itct;
  }
  // Incrementing the iteration count is handled outside this function,
  // by the add_iter() call in the generated quantities block
  return last_draw;
}
#endif // DIRICHLET_RNG_WRAPPER_HPP)" ->
iterfuns
writeLines(iterfuns, "iterfuns.hpp")
We also need some code to modify the Stan code generated by `brms`:
transform_stan_code <- function(stan_code) {
  # Declare the external C++ functions in the 'functions' block, creating the block if absent
  if (grepl("functions \\{", stan_code)) {
    stan_code <- sub("functions \\{",
                     "functions {\n  void add_iter();\n  int get_iter();\n  vector dirichlet_rng_wrapper(vector alpha);",
                     stan_code)
  } else {
    stan_code <- paste("functions {\n  void add_iter();\n  int get_iter();\n  vector dirichlet_rng_wrapper(vector alpha);\n}",
                       stan_code, sep = "\n")
  }
  # Comment out the weights vector in the 'data' block; the weights are generated in-model instead
  stan_code <- gsub("vector<lower=0>\\[N\\] weights;",
                    "//vector[N] weights; // This line has been commented out",
                    stan_code)
  # Add the Dirichlet parameters in the 'transformed data' block, creating the block if absent
  if (grepl("transformed data \\{", stan_code)) {
    stan_code <- sub("transformed data \\{",
                     "transformed data {\n  vector[N] alpha = rep_vector(1.0, N); // Dirichlet parameters, all ones for uniform distribution",
                     stan_code)
  } else {
    # Insert the block just before 'parameters {' so Stan's required block order is kept
    stan_code <- sub("\nparameters \\{",
                     "\ntransformed data {\n  vector[N] alpha = rep_vector(1.0, N); // Dirichlet parameters, all ones for uniform distribution\n}\nparameters {",
                     stan_code)
  }
  # Draw the (scaled) Dirichlet weights at the top of the likelihood part of the 'model' block
  stan_code <- gsub("(if \\(!prior_only\\) \\{)",
                    "\\1\n    vector[N] weights = dirichlet_rng_wrapper(alpha)*N;",
                    stan_code)
  # Insert the 'add_iter()' call in the 'generated quantities' block
  stan_code <- gsub("generated quantities \\{",
                    "generated quantities {\n  add_iter(); // update the counter each iteration",
                    stan_code)
  return(stan_code)
}
We continue with the mischief:
make_stancode(formula = hours|cens(!failure)+weights(1)~1,
family = weibull,
data = bearingcage) %>%
transform_stan_code() ->
modified_stan_code
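Before compiling, a quick sanity check (my addition, not part of the original recipe) confirms that the transformation actually inserted the declarations and the in-model weights line:
# Both should be TRUE
grepl("dirichlet_rng_wrapper\\(vector alpha\\)", modified_stan_code)
grepl("weights = dirichlet_rng_wrapper\\(alpha\\)\\*N", modified_stan_code)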
modified_stan_code %>%
write_stan_file() %>%
cmdstan_model(user_header = here('iterfuns.hpp')) ->
modified_model # this will compile the model
make_standata(formula = hours|cens(!failure)~1, # note the absence of weights here
family = weibull,
data=bearingcage) ->
standata_bootstrap
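To be sure no weights sneak in through the data either, we can check the data list (again my own quick check):
# Should be FALSE: the weights now live inside the model, not in the data
"weights" %in% names(standata_bootstrap)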
Then we sample!
modified_model$sample(data = standata_bootstrap,
chains = 4,
iter_warmup = 1000,
iter_sampling = 2000,
adapt_delta = 0.99,
refresh = 50L,
parallel_chains = 3) ->
outcome_samples_cmdstan
> outcome_samples_cmdstan$diagnostic_summary()
$num_divergent
[1] 0 0 0 0
$num_max_treedepth
[1] 0 0 0 0
$ebfmi
[1] 2.003341 1.914465 2.018019 1.963792
No divergences, unbelievable!
Then we put it back together into a `brms` output object: we read the CmdStan CSV output with `rstan`, create an empty `brmsfit`, and slot the fit in:
outcome_samples_cmdstan$output_files() %>%
rstan::read_stan_csv() ->
rstan_fit
outcome_samples_cmdstan ->
attributes(rstan_fit)$CmdStanModel
modified_model_brms <- brm(formula = hours|cens(!failure)~1,
family = weibull,
data=bearingcage,
empty = TRUE)
rstan_fit -> modified_model_brms$fit
rename_pars(modified_model_brms) ->
modified_model_brms
> modified_model_brms
Family: weibull
Links: mu = log; shape = identity
Formula: hours | cens(!failure) ~ 1
Data: bearingcage (Number of observations: 1703)
Draws: 4 chains, each with iter = 3000; warmup = 1000; thin = 1;
total post-warmup draws = 8000
Regression Coefficients:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept 9.73 1.46 8.00 13.66 1.01 532 712
Further Distributional Parameters:
Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
shape 2.16 0.94 0.82 4.46 1.01 610 756
Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Looking at the standard deviations of the posteriors, it seems that we got slightly fuzzier posteriors, which is easier to see when plotted.
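The figure is not reproduced here, but a minimal sketch of how the comparison could be plotted with the already-loaded `posterior` and `ggdist` packages (the draws columns b_Intercept and shape follow `brms` naming conventions):
bind_rows(
  as_draws_df(fit_brm) %>% mutate(model = "ordinary"),
  as_draws_df(modified_model_brms) %>% mutate(model = "bootstrap")
) %>%
  pivot_longer(c(b_Intercept, shape), names_to = "parameter") %>%
  ggplot(aes(x = value, y = model)) +
  stat_halfeye() +
  facet_wrap(~parameter, scales = "free_x")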
We see heavier tails, which probably makes sense: we are likely to get bootstrap weight draws where the few uncensored observations carry little weight, so the estimates wander off for a bit.
It is unclear whether this is useful, but it is possible to do.
What is happening is essentially that we draw a new set of weights, distributed as n times a Dirichlet(1) vector (so they sum to n), within each iteration. We also keep track of whether we are still in the same iteration, since the model block is evaluated more than once per iteration.
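As a minimal illustration of the weight scheme itself (plain R, using the fact that a Dirichlet(1, ..., 1) draw is a vector of Exp(1) draws divided by their sum):
n <- 10
w <- rexp(n)        # Exp(1) draws
w <- w / sum(w) * n # Dirichlet(1) draw scaled to sum to n, so the weights average 1
summary(w)
sum(w) # equals n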