Rvars in data.frames: a super efficient way of handling your posteriors

Hi!

This should probably be a blog post (which I don’t have), but I wanted to share something I discovered recently, as it looks as if it will fundamentally change my workflow with posteriors.

The trick I realised is that we can make rvars part of data frames. This is mentioned in one of the vignettes almost as a side note, but to me it makes handling posteriors in the context of our data structures far more convenient. Below is a reprex demonstrating this. I do think that this approach should be better known. I know that there is tidybayes, but I never got into it, as I would have needed to learn yet another tool. By making rvars part of data.frames I can continue using my usual tidyr/dplyr moves and get the handling of the posterior for free. Cool!

There is one caveat, though: not everything works with these special columns. For example, an if_else() statement won’t work; one needs to index into the rvar instead.

Tagging @paul.buerkner and Matthew Kay to thank them and to encourage more prominent documentation of this way of handling things.

Sebastian

Here is the reprex:

library(brms)
#> Loading required package: Rcpp
#> Loading 'brms' package (version 2.19.0). Useful instructions
#> can be found by typing help('brms'). A more detailed introduction
#> to the package is available through vignette('brms_overview').
#> 
#> Attaching package: 'brms'
#> The following object is masked from 'package:stats':
#> 
#>     ar
library(dplyr)
#> 
#> Attaching package: 'dplyr'
#> The following objects are masked from 'package:stats':
#> 
#>     filter, lag
#> The following objects are masked from 'package:base':
#> 
#>     intersect, setdiff, setequal, union
library(tidyr)
library(here)
library(posterior)
#> This is posterior version 1.4.1
#> 
#> Attaching package: 'posterior'
#> The following objects are masked from 'package:stats':
#> 
#>     mad, sd, var
#> The following objects are masked from 'package:base':
#> 
#>     %in%, match
library(ggplot2)
theme_set(theme_bw())
library(ggdist)
#> 
#> Attaching package: 'ggdist'
#> The following objects are masked from 'package:brms':
#> 
#>     dstudent_t, pstudent_t, qstudent_t, rstudent_t

# instruct brms to use cmdstanr as backend and cache all Stan binaries
options(brms.backend="cmdstanr", cmdstanr_write_stan_file_dir=here("brms-cache"))
# create cache directory if not yet available
dir.create(here("brms-cache"), showWarnings = FALSE)


model <- bf(count ~ Trt + Base + visit,
            autocor = ~unstr(time=visit, gr=patient))

model_prior <- prior(normal(0, 5), class=Intercept) +
    prior(normal(0, 1), class=b)

fit <- brm(model, prior=model_prior, data = epilepsy, seed=346465, refresh=0, cores=4)
#> Start sampling
#> Running MCMC with 4 parallel chains...
#> 
#> Chain 1 finished in 1.0 seconds.
#> Chain 2 finished in 1.1 seconds.
#> Chain 3 finished in 1.0 seconds.
#> Chain 4 finished in 1.1 seconds.
#> 
#> All 4 chains finished successfully.
#> Mean chain execution time: 1.1 seconds.
#> Total execution time: 1.2 seconds.

## some nifty utility functions:

## converts a matrix of draws as returned by brms (draws in rows,
## observations in columns) into an rvar vector
to_rvar_vector <- function(draw_matrix) {
    lab <- "rv"
    mdraws <- as_draws_matrix(draw_matrix)
    variables(mdraws) <- paste0(lab, "[", seq_len(nvariables(mdraws)), "]")
    as_draws_rvars(mdraws)[[lab]]
}

## simplifying functions to work with models in the context of data.frames
rv_posterior_epred <- function(newdata, model, ...) {
    to_rvar_vector(posterior_epred(model, newdata=newdata, ...))
}

## the standard way of using posteriors goes like this
pp_count <- posterior_predict(fit)

## now I get a matrix with as many rows as draws and as many columns
## as there are rows in the original data.frame
dim(pp_count)
#> [1] 4000  236

## getting by-visit summaries is now cumbersome, since I need to
## wrangle indices... and forget about contrasts between treated and
## untreated patients.

## so rvar to the rescue, since an rvar can be made a column of the
## original data.frame like this:

post_epilepsy <- epilepsy %>% mutate(pm_count=rv_posterior_epred(., fit))

## now the full posterior is part of my original data.frame and I can
## use the tidyverse utilities to post-process my data in a very
## natural way

head(post_epilepsy)
#>   Age Base Trt patient visit count obs       zAge      zBase    pm_count
#> 1  31   11   0       1     1     5   1  0.4249950 -0.7571728  0.94 ± 1.3
#> 2  30   11   0       2     1     3   2  0.2652835 -0.7571728  0.94 ± 1.3
#> 3  25    6   0       3     1     2   3 -0.5332740 -0.9444033 -0.90 ± 1.4
#> 4  36    8   0       4     1     4   4  1.2235525 -0.8695111 -0.16 ± 1.3
#> 5  22   66   0       5     1     7   5 -1.0124085  1.3023626 21.20 ± 1.5
#> 6  29   27   0       6     1     5   6  0.1055720 -0.1580352  6.83 ± 1.1

## let's summarise across patients by visit and treatment group... easy:

post_by_visit <- post_epilepsy %>% group_by(visit, Trt) %>%
    summarise(mean_count=rvar_mean(pm_count))
#> `summarise()` has grouped output by 'visit'. You can override using the
#> `.groups` argument.

## note that this is done with the correct by-draw handling! 

post_by_visit
#> # A tibble: 8 × 3
#> # Groups:   visit [4]
#>   visit Trt   mean_count
#>   <fct> <fct> <rvar[1d]>
#> 1 1     0      8.2 ± 1.1
#> 2 1     1      8.3 ± 1.0
#> 3 2     0      8.3 ± 1.1
#> 4 2     1      8.4 ± 1.1
#> 5 3     0      8.2 ± 1.1
#> 6 3     1      8.3 ± 1.0
#> 7 4     0      7.4 ± 1.1
#> 8 4     1      7.5 ± 1.1

## need summaries of the posterior as we are used to them... easy:
post_by_visit %>% mutate(summarise_draws(mean_count)) %>% knitr::kable(digits=2)
|visit |Trt |mean_count |variable      | mean| median|   sd|  mad|   q5|   q95| rhat| ess_bulk| ess_tail|
|:-----|:---|:----------|:-------------|----:|------:|----:|----:|----:|-----:|----:|--------:|--------:|
|1     |0   |8.2 ± 1.1  |mean_count[1] | 8.23|   8.23| 1.06| 1.05| 6.45|  9.98|    1|  4736.19|  2915.18|
|1     |1   |8.3 ± 1.0  |mean_count[2] | 8.31|   8.32| 1.01| 1.03| 6.64|  9.98|    1|  5022.53|  2726.15|
|2     |0   |8.3 ± 1.1  |mean_count[1] | 8.29|   8.31| 1.11| 1.11| 6.41| 10.11|    1|  4498.82|  2923.49|
|2     |1   |8.4 ± 1.1  |mean_count[2] | 8.38|   8.39| 1.07| 1.04| 6.57| 10.10|    1|  4935.25|  2914.84|
|3     |0   |8.2 ± 1.1  |mean_count[1] | 8.23|   8.22| 1.08| 1.10| 6.46|  9.98|    1|  4317.58|  2602.35|
|3     |1   |8.3 ± 1.0  |mean_count[2] | 8.31|   8.33| 1.04| 1.02| 6.56| 10.00|    1|  4681.08|  2831.58|
|4     |0   |7.4 ± 1.1  |mean_count[1] | 7.43|   7.44| 1.13| 1.16| 5.54|  9.30|    1|  4664.16|  2767.32|
|4     |1   |7.5 ± 1.1  |mean_count[2] | 7.51|   7.51| 1.07| 1.09| 5.73|  9.27|    1|  5019.83|  2734.47|

## now that we have a result...we can even plot these rvars in this context with ggdist
post_by_visit %>% ggplot(aes(x=factor(visit), colour=factor(Trt))) +
    stat_pointinterval(aes(ydist=mean_count), position=position_dodge(width=0.2))


## getting a treatment contrast...ok let's go wide
post_by_visit %>% pivot_wider(id_cols="visit", names_from="Trt", values_from="mean_count") %>%
    mutate(effect=`1` - `0`)
#> # A tibble: 4 × 4
#> # Groups:   visit [4]
#>   visit        `0`        `1`        effect
#>   <fct> <rvar[1d]> <rvar[1d]>    <rvar[1d]>
#> 1 1      8.2 ± 1.1  8.3 ± 1.0  0.084 ± 0.87
#> 2 2      8.3 ± 1.1  8.4 ± 1.1  0.084 ± 0.87
#> 3 3      8.2 ± 1.1  8.3 ± 1.0  0.084 ± 0.87
#> 4 4      7.4 ± 1.1  7.5 ± 1.1  0.084 ± 0.87

sessionInfo()
#> R version 4.1.0 (2021-05-18)
#> Platform: aarch64-apple-darwin20 (64-bit)
#> Running under: macOS 13.4
#> 
#> Matrix products: default
#> BLAS:   /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRblas.dylib
#> LAPACK: /Library/Frameworks/R.framework/Versions/4.1-arm64/Resources/lib/libRlapack.dylib
#> 
#> locale:
#> [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
#> 
#> attached base packages:
#> [1] stats     graphics  grDevices utils     datasets  methods   base     
#> 
#> other attached packages:
#> [1] ggdist_3.3.0    ggplot2_3.4.2   posterior_1.4.1 here_1.0.1     
#> [5] tidyr_1.2.1     dplyr_1.0.10    brms_2.19.0     Rcpp_1.0.9     
#> 
#> loaded via a namespace (and not attached):
#>   [1] TH.data_1.1-1        minqa_1.2.4          colorspace_2.0-3    
#>   [4] ellipsis_0.3.2       ggridges_0.5.4       rprojroot_2.0.3     
#>   [7] estimability_1.3     markdown_1.2         base64enc_0.1-3     
#>  [10] fs_1.5.2             farver_2.1.1         rstan_2.21.7        
#>  [13] DT_0.26              fansi_1.0.3          mvtnorm_1.1-3       
#>  [16] bridgesampling_1.1-2 codetools_0.2-18     splines_4.1.0       
#>  [19] knitr_1.40           shinythemes_1.2.0    bayesplot_1.9.0     
#>  [22] projpred_2.0.2       jsonlite_1.8.3       nloptr_2.0.3        
#>  [25] shiny_1.7.2          compiler_4.1.0       emmeans_1.6.3       
#>  [28] backports_1.4.1      assertthat_0.2.1     Matrix_1.3-3        
#>  [31] fastmap_1.1.0        cli_3.4.1            later_1.3.0         
#>  [34] htmltools_0.5.3      prettyunits_1.1.1    tools_4.1.0         
#>  [37] igraph_1.3.5         coda_0.19-4          gtable_0.3.1        
#>  [40] glue_1.6.2           reshape2_1.4.4       styler_1.5.1        
#>  [43] vctrs_0.5.0          nlme_3.1-152         crosstalk_1.2.0     
#>  [46] tensorA_0.36.2       xfun_0.37            stringr_1.4.1       
#>  [49] ps_1.7.1             lme4_1.1-30          mime_0.12           
#>  [52] miniUI_0.1.1.1       lifecycle_1.0.3      gtools_3.9.3        
#>  [55] MASS_7.3-58.1        zoo_1.8-11           scales_1.2.1        
#>  [58] colourpicker_1.1.1   promises_1.2.0.1     Brobdingnag_1.2-9   
#>  [61] parallel_4.1.0       sandwich_3.0-2       inline_0.3.19       
#>  [64] shinystan_2.6.0      gamm4_0.2-6          yaml_2.3.6          
#>  [67] gridExtra_2.3        loo_2.5.1            StanHeaders_2.21.0-7
#>  [70] stringi_1.7.8        highr_0.9            dygraphs_1.1.1.6    
#>  [73] checkmate_2.1.0      boot_1.3-28          pkgbuild_1.3.1      
#>  [76] cmdstanr_0.5.2       rlang_1.1.1          pkgconfig_2.0.3     
#>  [79] matrixStats_0.62.0   distributional_0.3.2 evaluate_0.17       
#>  [82] lattice_0.20-45      purrr_0.3.5          labeling_0.4.2      
#>  [85] rstantools_2.2.0     htmlwidgets_1.5.4    processx_3.7.0      
#>  [88] tidyselect_1.2.0     plyr_1.8.7           magrittr_2.0.3      
#>  [91] R6_2.5.1             generics_0.1.3       multcomp_1.4-20     
#>  [94] DBI_1.1.3            pillar_1.8.1         withr_2.5.0         
#>  [97] mgcv_1.8-35          xts_0.12.2           survival_3.3-1      
#> [100] abind_1.4-5          tibble_3.1.8         crayon_1.5.2        
#> [103] utf8_1.2.2           rmarkdown_2.20       grid_4.1.0          
#> [106] data.table_1.14.2    callr_3.7.3          threejs_0.3.3       
#> [109] reprex_2.0.2         digest_0.6.30        xtable_1.8-4        
#> [112] httpuv_1.6.6         RcppParallel_5.1.5   stats4_4.1.0        
#> [115] munsell_0.5.0        shinyjs_2.1.0

Created on 2023-05-24 with reprex v2.0.2
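To make the "correct by-draw handling" noted in the reprex concrete without fitting a model, here is a minimal sketch. A simulated matrix `m` stands in for `posterior_epred()` output, and the `group` labels are made up for illustration. Because `rvar_mean()` averages across rows within each draw, every group mean is itself a full posterior whose draw d equals the mean of that group's draw-d values.

```r
library(posterior)
library(dplyr)

set.seed(42)
n_draws <- 1000
# simulated stand-in for posterior_epred() output: 1000 draws x 4 observations
m <- matrix(rnorm(n_draws * 4), nrow = n_draws)

df <- tibble(
  group = c("a", "a", "b", "b"),  # hypothetical grouping of the 4 observations
  pred  = rvar(m)                 # one rvar element per row of the data frame
)

by_group <- df %>%
  group_by(group) %>%
  summarise(mean_pred = rvar_mean(pred))

# draw d of group "a" is mean(m[d, 1:2]); draw d of group "b" is mean(m[d, 3:4])
draws_a <- draws_of(by_group$mean_pred)[, 1]
all.equal(unname(draws_a), rowMeans(m[, 1:2]))
```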


Thanks for this!! We probably should advertise that feature of rvars more :).

Also of note is that besides being easier to read and (usually) manipulate, data frames of rvars are generally more memory-efficient than long-format data frames of draws:

df_draws = posterior::example_draws() |> tidybayes::spread_draws(theta[i])
df_draws
#> # A tibble: 3,200 × 5
#> # Groups:   i [8]
#>        i  theta .chain .iteration .draw
#>    <int>  <dbl>  <int>      <int> <int>
#>  1     1  3.96       1          1     1
#>  2     1  0.124      1          2     2
#>  3     1 21.3        1          3     3
#>  4     1 14.7        1          4     4
#>  5     1  5.96       1          5     5
#>  6     1  5.76       1          6     6
#>  7     1  4.03       1          7     7
#>  8     1 -0.278      1          8     8
#>  9     1  1.81       1          9     9
#> 10     1  6.08       1         10    10
#> # ℹ 3,190 more rows
object.size(df_draws)
#> 93312 bytes

df_rvars = posterior::example_draws() |> tidybayes::spread_rvars(theta[i])
df_rvars
#> # A tibble: 8 × 2
#>       i      theta
#>   <int> <rvar[1d]>
#> 1     1  6.7 ± 6.3
#> 2     2  5.3 ± 4.6
#> 3     3  3.0 ± 6.8
#> 4     4  4.9 ± 4.9
#> 5     5  3.2 ± 5.1
#> 6     6  4.0 ± 5.2
#> 7     7  6.5 ± 5.3
#> 8     8  4.6 ± 5.3
object.size(df_rvars)
#> 53400 bytes

This is a small example, but you can imagine on a large model this can make a difference :).
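A rough sketch of the same comparison at a larger scale, using simulated draws and only posterior and tibble (no tidybayes). The long layout here is trimmed to `i`, `theta` and `.draw`, without the `.chain` and `.iteration` columns. That understates the long format's size, so the comparison favours it. The exact byte counts will depend on the R and package versions.

```r
library(posterior)

set.seed(123)
n_draws <- 4000
n_elem  <- 100
# simulated draws: rows are draws, columns are the elements theta[i]
m <- matrix(rnorm(n_draws * n_elem), nrow = n_draws)

# long format: one row per (element, draw) pair
df_long <- data.frame(
  i     = rep(seq_len(n_elem), each = n_draws),  # as.vector(m) runs column by column
  theta = as.vector(m),
  .draw = rep(seq_len(n_draws), times = n_elem)
)

# rvar format: one row per element, all draws held in a single rvar column
df_rvar <- tibble::tibble(i = seq_len(n_elem), theta = rvar(m))

size_long <- object.size(df_long)
size_rvar <- object.size(df_rvar)
print(size_long, units = "Kb")
print(size_rvar, units = "Kb")
```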

The other feature we probably don’t advertise well is that once those rvars are in a data frame, you can visualize them with any of the stats/geoms in {ggdist} very easily. e.g.:

df_rvars |> 
  ggplot(aes(y = factor(i), xdist = theta)) +
  ggdist::stat_slabinterval() +
  labs(y = "i", x = "theta[i]")

Re: ifelse, that is an excellent point — I drafted an implementation of ifelse for rvars for an issue about subsetting rvars with other rvars; need to finish that off soon. See here.
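Until a built-in version exists, a by-draw ifelse can be sketched in three steps. Use `draws_of()` to get the underlying draws arrays, apply base `ifelse()` to them, and wrap the result with `rvar()` again. The helper name `ifelse_by_draw()` is hypothetical. It is not the drafted implementation mentioned above, and it assumes all three inputs already have the same number of draws and the same shape.

```r
library(posterior)

# Hypothetical helper (not part of posterior): a by-draw ifelse.
# Assumes test, yes and no are rvars with the same number of draws and shape.
ifelse_by_draw <- function(test, yes, no) {
  # draws_of() exposes the arrays (draws in the first dimension),
  # ifelse() selects element-wise and keeps test's dims, rvar() re-wraps
  rvar(ifelse(draws_of(test), draws_of(yes), draws_of(no)))
}

# 3 draws x 2 elements (each line of c() is one element's draws)
m <- matrix(c(-1, 2, 3,
              -4, 5, -6), nrow = 3)
x <- rvar(m)

x_abs <- ifelse_by_draw(x > 0, x, -x)  # by-draw absolute value
draws_of(x_abs)
```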


Oh hah, your example already had a ggdist example in it too :)

It is really great work! The memory savings are nice and I’ll take them, but as Bayesian inference is quite often applied to small data sets around here, that’s not the big thing for me. Having the posterior as a clever nested object inside my data.frames is huge in comparison. This will strongly encourage me to keep the full posterior in my workflow rather than crudely summarising things to quantiles as I usually do. That’s great, because it makes propagating uncertainty through my models essentially free from a practical viewpoint.

Wrt. if_else: what I was doing is imputing values for observed data that were missing. So I cast the observed data to an rvar and then replace the NA entries with the model’s posterior predictive draws. Right now I need to get the indices of the missing data and then do the replacement operation; with if_else it would be a single statement inside a mutate. I wasn’t thinking of an rvar_ifelse (a by-draw ifelse)… but that’s cool too!
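The index-based imputation described above can be sketched roughly as follows. The observed vector `obs` is made up, and the simulated Poisson matrix `pred_draws` is a hypothetical stand-in for `posterior_predict()` output (draws in rows, one column per observation). Giving the observed rvar the same number of draws as the predictions sidesteps any question of how draw counts are conformed on assignment.

```r
library(posterior)

set.seed(1)
obs <- c(5, NA, 2, NA, 7)  # made-up observed counts with two missing entries
n_draws <- 1000

# hypothetical stand-in for posterior_predict(fit): draws x observations
pred_draws <- matrix(rpois(n_draws * length(obs), lambda = 4), nrow = n_draws)
pred <- rvar(pred_draws)

# cast the observed data to an rvar: every draw carries the observed value
obs_rv <- rvar(matrix(obs, nrow = n_draws, ncol = length(obs), byrow = TRUE))

# replace the missing entries by indexing (the if_else() workaround)
miss <- which(is.na(obs))
obs_rv[miss] <- pred[miss]
obs_rv
```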

Ahh… and I am actually surprised that I needed to define the two utility functions in my demo code above. They feel so natural to have that I would suggest adding them to posterior and/or brms.

Again many thanks for rvars!


Agreed, working in data.frames is fantastic. One day I’m going to try to fix this issue so that rvar works with data.table.


Ah… you don’t! Which is clearly a failure of documentation on my part :)

The rvar() constructor takes arrays where the first dimension is the draw, so its output is directly compatible with posterior_epred(), posterior_predict(), etc (this design choice was partly made to ease use with arrays of draws that are generated by functions like that). So you can use rvar() in place of your to_rvar_vector() function.
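Here is a minimal sketch of that equivalence. A simulated matrix `m` is used in place of real `posterior_epred()` output. `rvar()` reads the rows as draws and the columns as elements, which is the same layout that brms' prediction functions return.

```r
library(posterior)

set.seed(7)
# simulated stand-in for posterior_epred(fit):
# 100 draws (rows) x 5 observations (columns)
m <- matrix(rnorm(100 * 5), nrow = 100)

rv <- rvar(m)  # first dimension is interpreted as draws
length(rv)     # 5: one element per observation
ndraws(rv)     # 100
E(rv)          # posterior mean of each element, i.e. colMeans(m)
```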

rv_posterior_epred() doesn’t have a direct analog in {posterior}, though it is pretty close to just calling rvar(posterior_epred(...)). However, {tidybayes} does have several rvar + data frame functions, including add_epred_rvars(), which is closer to rv_posterior_epred() with some additional features (mostly for models with multivariate outcomes).

Using these functions, you can replace this call:

post_epilepsy <- epilepsy %>% mutate(pm_count = rv_posterior_epred(., fit))

With this (using just posterior):

post_epilepsy <- epilepsy %>% mutate(pm_count = rvar(posterior_epred(fit, newdata = .)))

Or this (using posterior + tidybayes):

post_epilepsy <- epilepsy %>% tidybayes::add_epred_rvars(fit, value = "pm_count")

The tidybayes + posterior vignette (which is mostly about rvars in data frames) has some more examples.


That would be lovely, though the last time I looked at similar issues raised on the data.table GitHub for other data types, my sense was that rvar is fundamentally incompatible with data.table and that this is unlikely to be solvable (at least not without completely changing rvar to a list-based format). Sooo I’m not crazy optimistic about it. The best I could imagine is allowing rvar to be backed either by an array or by a list of arrays (or maybe having something like an "rvar_list") and storing that in a data.table (or using something like the proxy format returned by posterior:::vec_proxy.rvar(), which is basically a vector of pointers into an rvar…).

Though at that point it would probably be easier to use distributional::dist_sample(), which essentially is a list-based format and therefore compatible with data table. posterior::rvar and {distributional} are compatible, and you can cast an rvar to a distributional object via something like:

library(posterior)
library(distributional)  # provides dist_missing() and dist_sample()

x = rvar_rng(rnorm, 10, 1:10)
x_dist = vctrs::vec_cast(x, dist_missing())

x_dist will be a vector of dist_sample() objects. It won’t be able to do math operations as efficiently as the array-backed rvar, but you should be able to add it to a data table.
