Hi!
This should probably be a blog post (which I don't have), but I wanted to share something I discovered recently, because it looks like it will fundamentally change my workflow with posteriors.
The trick I realised is that we can make rvars part of data frames. This is mentioned in one of the vignettes almost as a side note, but to me it makes handling posteriors in the context of our data structures far more convenient. Below is a reprex demonstrating this. I do think this approach should be better known. I know that there is tidybayes, but I never got into it because it meant learning yet another tool. By making rvars part of data.frames I can keep using my usual tidyr/dplyr moves and get the handling of the posterior for free. Cool!
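For anyone who wants to try the idea without fitting a model first, here is a minimal sketch with invented draws. The names draws, toy, by_arm and effect are hypothetical and the numbers are made up for illustration; it only assumes that posterior and dplyr are installed.

```r
library(posterior)
library(dplyr)

# Toy "posterior": 4 draws (rows) for each of 4 data rows (columns).
# matrix() fills column by column, so each line below is one data row.
draws <- matrix(
  c(1, 2, 3, 4,       # draws for data row 1
    3, 4, 5, 6,       # data row 2
    10, 20, 30, 40,   # data row 3
    30, 40, 50, 60),  # data row 4
  nrow = 4
)

# The rvar has one element per data row, so it fits in as a column.
toy <- tibble(arm = c("a", "a", "b", "b")) %>%
  mutate(theta = rvar(draws))

# rvar_mean() averages over the rows of each group within every draw,
# so the result is still an rvar carrying all 4 draws.
by_arm <- toy %>%
  group_by(arm) %>%
  summarise(theta_mean = rvar_mean(theta), .groups = "drop")

# Differences between rvars are likewise computed draw by draw.
effect <- by_arm$theta_mean[2] - by_arm$theta_mean[1]
draws_of(effect)
```

The point of the sketch is that nothing here is specific to brms: any matrix with draws in rows and one column per data row can be turned into such a column.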
There is one caveat though: not everything works with these special columns. For example, an if_else statement won't work; one instead needs to index into the rvar column.
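A minimal sketch of that indexing workaround, again with invented draws rather than model output. The names a, b, use_b and out are hypothetical, and this is just one way to do it: start from one rvar and overwrite the selected elements by index instead of calling if_else().

```r
library(posterior)

# Toy rvars with 4 draws (rows) and 3 elements (columns); values are invented.
a <- rvar(matrix(as.numeric(1:12), nrow = 4))
b <- rvar(matrix(as.numeric(101:112), nrow = 4))

# Condition per element: take b where TRUE, a otherwise.
use_b <- c(FALSE, TRUE, FALSE)

# Instead of if_else(use_b, b, a): copy a, then overwrite by logical index.
out <- a
out[use_b] <- b[use_b]

# Posterior mean of each element; element 2 now comes from b.
mean(out)
```

The same pattern should carry over to a data frame column by building the new rvar first and then assigning it as a column.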
Tagging @paul.buerkner and Matthew Kay to thank them and to encourage more prominent documentation of this way of handling things.
Sebastian
Here is the reprex:
library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.19.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#>
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#>
#> ar
library(dplyr)
#>
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#>
#> filter, lag
#> The following objects are masked from 'package:base':
#>
#> intersect, setdiff, setequal, union
library(tidyr)
library(here)
library(posterior)
#> This is posterior version 1.4.1
#>
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#>
#> mad, sd, var
#> The following objects are masked from 'package:base':
#>
#> %in%, match
library(ggplot2)
theme_set(theme_bw())
library(ggdist)
#>
#> Attaching package: 'ggdist'
#> The following objects are masked from 'package:brms':
#>
#> dstudent_t, pstudent_t, qstudent_t, rstudent_t
# instruct brms to use cmdstanr as backend and cache all Stan binaries
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=here("brms-cache"))
# create cache directory if not yet available
dir.create(here("brms-cache"), showWarnings=FALSE)
model <- bf(count ~ Trt + Base + visit,
            autocor = ~unstr(time=visit, gr=patient))
model_prior <- prior(normal(0, 5), class=Intercept) +
  prior(normal(0, 1), class=b)
fit <- brm(model, prior=model_prior, data = epilepsy, seed=346465, refresh=0, cores=4)
#> Start sampling
#> Running MCMC with 4 parallel chains...
#>
#> Chain 1 finished in 1.0 seconds.
#> Chain 2 finished in 1.1 seconds.
#> Chain 3 finished in 1.0 seconds.
#> Chain 4 finished in 1.1 seconds.
#>
#> All 4 chains finished successfully.
#> Mean chain execution time: 1.1 seconds.
#> Total execution time: 1.2 seconds.
## some nifty utility functions:
## converts a matrix of draws (draws in rows), as returned by brms,
## into an rvar vector with one element per column
to_rvar_vector <- function(draw_matrix) {
  lab <- "rv"
  mdraws <- as_draws_matrix(draw_matrix)
  ## name the columns rv[1], rv[2], ... so they are read as one vector variable
  variables(mdraws) <- paste0(lab, "[", seq_len(nvariables(mdraws)), "]")
  as_draws_rvars(mdraws)[[lab]]
}
## simplifying functions to work with models in the context of data.frames
rv_posterior_epred <- function(newdata, model, ...) {
  to_rvar_vector(posterior_epred(model, newdata=newdata, ...))
}
## standard way of using posteriors goes like
pp_count <- posterior_predict(fit)
## now I get a matrix with as many rows as draws and columns as we
## have rows in the original data.frame
dim(pp_count)
#> [1] 4000 236
## getting by-visit summaries is now cumbersome, since I need to
## wrangle indices... let alone contrasts between treated and
## untreated patients.
## so rvar to the rescue, since rvar can be made part of the original
## data.frame just like:
post_epilepsy <- epilepsy %>% mutate(pm_count=rv_posterior_epred(., fit))
## now the full posterior is part of my original data.frame and I can
## use the tidyverse utilities to post-process my data in a very
## natural way
head(post_epilepsy)
#> Age Base Trt patient visit count obs zAge zBase pm_count
#> 1 31 11 0 1 1 5 1 0.4249950 -0.7571728 0.94 ± 1.3
#> 2 30 11 0 2 1 3 2 0.2652835 -0.7571728 0.94 ± 1.3
#> 3 25 6 0 3 1 2 3 -0.5332740 -0.9444033 -0.90 ± 1.4
#> 4 36 8 0 4 1 4 4 1.2235525 -0.8695111 -0.16 ± 1.3
#> 5 22 66 0 5 1 7 5 -1.0124085 1.3023626 21.20 ± 1.5
#> 6 29 27 0 6 1 5 6 0.1055720 -0.1580352 6.83 ± 1.1
## let's summarize across patients by visit and treatment group.. easy:
post_by_visit <- post_epilepsy %>% group_by(visit, Trt) %>%
  summarise(mean_count=rvar_mean(pm_count))
#> `summarise()` has grouped output by 'visit'. You can override using the
#> `.groups` argument.
## note that this is done with the correct by-draw handling!
post_by_visit
#> # A tibble: 8 × 3
#> # Groups: visit [4]
#> visit Trt mean_count
#> <fct> <fct> <rvar[1d]>
#> 1 1 0 8.2 ± 1.1
#> 2 1 1 8.3 ± 1.0
#> 3 2 0 8.3 ± 1.1
#> 4 2 1 8.4 ± 1.1
#> 5 3 0 8.2 ± 1.1
#> 6 3 1 8.3 ± 1.0
#> 7 4 0 7.4 ± 1.1
#> 8 4 1 7.5 ± 1.1
## need summaries of the posterior as we are used to them... easy:
post_by_visit %>% mutate(summarise_draws(mean_count)) %>% knitr::kable(digits=2)
visit | Trt | mean_count | variable | mean | median | sd | mad | q5 | q95 | rhat | ess_bulk | ess_tail |
---|---|---|---|---|---|---|---|---|---|---|---|---|
1 | 0 | 8.2 ± 1.1 | mean_count[1] | 8.23 | 8.23 | 1.06 | 1.05 | 6.45 | 9.98 | 1 | 4736.19 | 2915.18 |
1 | 1 | 8.3 ± 1.0 | mean_count[2] | 8.31 | 8.32 | 1.01 | 1.03 | 6.64 | 9.98 | 1 | 5022.53 | 2726.15 |
2 | 0 | 8.3 ± 1.1 | mean_count[1] | 8.29 | 8.31 | 1.11 | 1.11 | 6.41 | 10.11 | 1 | 4498.82 | 2923.49 |
2 | 1 | 8.4 ± 1.1 | mean_count[2] | 8.38 | 8.39 | 1.07 | 1.04 | 6.57 | 10.10 | 1 | 4935.25 | 2914.84 |
3 | 0 | 8.2 ± 1.1 | mean_count[1] | 8.23 | 8.22 | 1.08 | 1.10 | 6.46 | 9.98 | 1 | 4317.58 | 2602.35 |
3 | 1 | 8.3 ± 1.0 | mean_count[2] | 8.31 | 8.33 | 1.04 | 1.02 | 6.56 | 10.00 | 1 | 4681.08 | 2831.58 |
4 | 0 | 7.4 ± 1.1 | mean_count[1] | 7.43 | 7.44 | 1.13 | 1.16 | 5.54 | 9.30 | 1 | 4664.16 | 2767.32 |
4 | 1 | 7.5 ± 1.1 | mean_count[2] | 7.51 | 7.51 | 1.07 | 1.09 | 5.73 | 9.27 | 1 | 5019.83 | 2734.47 |
## now that we have a result...we can even plot these rvars in this context with ggdist
post_by_visit %>% ggplot(aes(x=factor(visit), colour=factor(Trt))) +
  stat_pointinterval(aes(ydist=mean_count), position=position_dodge(width=0.2))
## getting a treatment contrast...ok let's go wide
post_by_visit %>% pivot_wider(id_cols="visit", names_from="Trt", values_from="mean_count") %>%
  mutate(effect=`1` - `0`)
#> # A tibble: 4 × 4
#> # Groups: visit [4]
#> visit `0` `1` effect
#> <fct> <rvar[1d]> <rvar[1d]> <rvar[1d]>
#> 1 1 8.2 ± 1.1 8.3 ± 1.0 0.084 ± 0.87
#> 2 2 8.3 ± 1.1 8.4 ± 1.1 0.084 ± 0.87
#> 3 3 8.2 ± 1.1 8.3 ± 1.0 0.084 ± 0.87
#> 4 4 7.4 ± 1.1 7.5 ± 1.1 0.084 ± 0.87
sessionInfo()
#> R version 4.1.0 (2021-05-18)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS 13.4
#>
#> Matrix products: default
#> BLAS: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRlapack.dylib
#>
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#>
#> attached base packages:
#> [1] stats graphics grDevices utils datasets methods base
#>
#> other attached packages:
#> [1] ggdist_3.3.0 ggplot2_3.4.2 posterior_1.4.1 here_1.0.1
#> [5] tidyr_1.2.1 dplyr_1.0.10 brms_2.19.0 Rcpp_1.0.9
#>
#> loaded via a namespace (and not attached):
#> [1] TH.data_1.1-1 minqa_1.2.4 colorspace_2.0-3
#> [4] ellipsis_0.3.2 ggridges_0.5.4 rprojroot_2.0.3
#> [7] estimability_1.3 markdown_1.2 base64enc_0.1-3
#> [10] fs_1.5.2 farver_2.1.1 rstan_2.21.7
#> [13] DT_0.26 fansi_1.0.3 mvtnorm_1.1-3
#> [16] bridgesampling_1.1-2 codetools_0.2-18 splines_4.1.0
#> [19] knitr_1.40 shinythemes_1.2.0 bayesplot_1.9.0
#> [22] projpred_2.0.2 jsonlite_1.8.3 nloptr_2.0.3
#> [25] shiny_1.7.2 compiler_4.1.0 emmeans_1.6.3
#> [28] backports_1.4.1 assertthat_0.2.1 Matrix_1.3-3
#> [31] fastmap_1.1.0 cli_3.4.1 later_1.3.0
#> [34] htmltools_0.5.3 prettyunits_1.1.1 tools_4.1.0
#> [37] igraph_1.3.5 coda_0.19-4 gtable_0.3.1
#> [40] glue_1.6.2 reshape2_1.4.4 styler_1.5.1
#> [43] vctrs_0.5.0 nlme_3.1-152 crosstalk_1.2.0
#> [46] tensorA_0.36.2 xfun_0.37 stringr_1.4.1
#> [49] ps_1.7.1 lme4_1.1-30 mime_0.12
#> [52] miniUI_0.1.1.1 lifecycle_1.0.3 gtools_3.9.3
#> [55] MASS_7.3-58.1 zoo_1.8-11 scales_1.2.1
#> [58] colourpicker_1.1.1 promises_1.2.0.1 Brobdingnag_1.2-9
#> [61] parallel_4.1.0 sandwich_3.0-2 inline_0.3.19
#> [64] shinystan_2.6.0 gamm4_0.2-6 yaml_2.3.6
#> [67] gridExtra_2.3 loo_2.5.1 StanHeaders_2.21.0-7
#> [70] stringi_1.7.8 highr_0.9 dygraphs_1.1.1.6
#> [73] checkmate_2.1.0 boot_1.3-28 pkgbuild_1.3.1
#> [76] cmdstanr_0.5.2 rlang_1.1.1 pkgconfig_2.0.3
#> [79] matrixStats_0.62.0 distributional_0.3.2 evaluate_0.17
#> [82] lattice_0.20-45 purrr_0.3.5 labeling_0.4.2
#> [85] rstantools_2.2.0 htmlwidgets_1.5.4 processx_3.7.0
#> [88] tidyselect_1.2.0 plyr_1.8.7 magrittr_2.0.3
#> [91] R6_2.5.1 generics_0.1.3 multcomp_1.4-20
#> [94] DBI_1.1.3 pillar_1.8.1 withr_2.5.0
#> [97] mgcv_1.8-35 xts_0.12.2 survival_3.3-1
#> [100] abind_1.4-5 tibble_3.1.8 crayon_1.5.2
#> [103] utf8_1.2.2 rmarkdown_2.20 grid_4.1.0
#> [106] data.table_1.14.2 callr_3.7.3 threejs_0.3.3
#> [109] reprex_2.0.2 digest_0.6.30 xtable_1.8-4
#> [112] httpuv_1.6.6 RcppParallel_5.1.5 stats4_4.1.0
#> [115] munsell_0.5.0 shinyjs_2.1.0
Created on 2023-05-24 with reprex v2.0.2