Some bayesplot/ggplot2 questions about plotting priors and posteriors in the same panel

Below is an example of extracting the prior and posterior draws for each parameter and plotting them with ggplot2 (and ggdist). I fit the model with brms, because I don’t know Stan well enough to write the model directly in Stan.

# Set up and fit a sample model
library(tidyverse)
library(patchwork)
library(ggdist)
library(brms)
# Make the black-and-white theme the default for every plot built below
theme_set(new = theme_bw())

# Model formula: petal width regressed on the two sepal measurements and species
form <- Petal.Width ~ Sepal.Width + Sepal.Length + Species

# Fit the model with brms using the cmdstanr backend.
# - Explicit normal priors are set for each slope coefficient; the intercept
#   and sigma are left at brms's defaults.
# - sample_prior = TRUE also draws from the priors (returned as prior_*
#   variables), so priors can be plotted alongside the posteriors.
# - A fixed seed makes the reprex reproducible, including the individual draws
#   printed further down.
m1 <- brm(
  formula = form,
  data = iris,
  family = gaussian(),
  prior = c(
    prior(normal(0, 0.2), class = "b", coef = "Speciesversicolor"),
    prior(normal(0, 0.2), class = "b", coef = "Speciesvirginica"),
    prior(normal(0, 0.5), class = "b", coef = "Sepal.Width"),
    prior(normal(0, 0.3), class = "b", coef = "Sepal.Length")
  ),
  backend = "cmdstanr",
  sample_prior = TRUE,
  seed = 2022
)
# Get the posterior draws (the fit used sample_prior = TRUE, so prior_* draws
# are included too), then keep only the model parameters and their priors.
# brms names the intercept prior "prior_Intercept"; it is renamed to
# "prior_b_Intercept" so that it lines up with the posterior column
# "b_Intercept" once the "prior_" prefix is stripped later on.
# NOTE(review): by default brms puts the intercept prior on the intercept of
# the *centred* predictors, so prior_Intercept may not be directly comparable
# to b_Intercept unless the model uses `0 + Intercept` — confirm.
m1.draws <- as_draws_df(m1) %>%
  select(matches("^(prior|b_|sigma)")) %>%
  rename(prior_b_Intercept = prior_Intercept)

# Stack the prior and posterior draws for each parameter and reshape to long
# format (columns: type, name, value).
# Prior columns are identified by a leading "prior_" prefix only, and only that
# leading prefix is stripped (e.g. prior_sigma -> sigma). This keeps a
# coefficient whose name merely *contains* "prior" (e.g. b_prior_score) on the
# posterior side and stops its name from being mangled.
m1.draws <- bind_rows(
  m1.draws %>%
    select(starts_with("prior_")) %>%
    mutate(type = "prior") %>%
    rename_with(~ sub("^prior_", "", .x)),
  m1.draws %>%
    select(-starts_with("prior_")) %>%
    mutate(type = "posterior")
) %>%
  pivot_longer(-type)

# Show that we now have 4000 draws of prior and posterior for each parameter
count(m1.draws, type, name)

#>    type      name                    n
#>    <chr>     <chr>               <int>
#>  1 posterior b_Intercept          4000
#>  2 posterior b_Sepal.Length       4000
#>  3 posterior b_Sepal.Width        4000
#>  4 posterior b_Speciesversicolor  4000
#>  5 posterior b_Speciesvirginica   4000
#>  6 posterior sigma                4000
#>  7 prior     b_Intercept          4000
#>  8 prior     b_Sepal.Length       4000
#>  9 prior     b_Sepal.Width        4000
#> 10 prior     b_Speciesversicolor  4000
#> 11 prior     b_Speciesvirginica   4000
#> 12 prior     sigma                4000

# Show one draw for each parameter and prior: the first row of every
# (type, name) group, printing all rows
m1.draws %>%
  group_by(type, name) %>%
  slice_head(n = 1) %>%
  print(n = Inf)

#>    type      name                   value
#>    <chr>     <chr>                  <dbl>
#>  1 posterior b_Intercept         -1.04   
#>  2 posterior b_Sepal.Length       0.258  
#>  3 posterior b_Sepal.Width        0.00865
#>  4 posterior b_Speciesversicolor  0.833  
#>  5 posterior b_Speciesvirginica   1.33   
#>  6 posterior sigma                0.195  
#>  7 prior     b_Intercept          4.45   
#>  8 prior     b_Sepal.Length       0.301  
#>  9 prior     b_Sepal.Width       -0.178  
#> 10 prior     b_Speciesversicolor  0.0645 
#> 11 prior     b_Speciesvirginica  -0.207  
#> 12 prior     sigma                1.61
# Data frame of true parameter values (assuming we just have this).
# Each value is written next to its parameter name, so the pairing cannot
# silently shift if the set of parameters in m1.draws (or their sort order)
# changes.
true.params <- tribble(
  ~name,                 ~value,
  "b_Intercept",         -0.6,
  "b_Sepal.Length",       0.1,
  "b_Sepal.Width",        0.02,
  "b_Speciesversicolor",  0.7,
  "b_Speciesvirginica",   1.3,
  "sigma",                0.3
) %>%
  mutate(type = "true value")

# Fail loudly on a typo or renamed parameter instead of plotting an orphan line
stopifnot(all(true.params$name %in% m1.draws$name))
# Plot using ggdist (requires library(ggdist) for stat_slab() and the
# slab_colour aesthetic): one row per parameter, with the prior and posterior
# drawn as unfilled, outlined slabs and the true value marked by a short
# vertical segment.
p1 <- m1.draws %>%
  bind_rows(true.params) %>%
  ggplot(aes(value, name, colour = type)) +
  # normalize = "groups" gives each slab the same maximum height, and
  # scale = 0.7 caps that height at 70% of the spacing between rows
  stat_slab(
    data = . %>% filter(type != "true value"),
    aes(slab_colour = type),
    fill = NA, normalize = "groups", scale = 0.7, slab_size = 0.5
  ) +
  # after_stat(y) is the row's numeric position on the discrete y axis, so
  # the segment runs from the row baseline up to the slab height set above
  geom_segment(
    data = . %>% filter(type == "true value"),
    aes(x = value, xend = value, y = name, yend = after_stat(y) + 0.7),
    colour = "grey10", size = 0.3
  ) +
  coord_cartesian(xlim = c(-3, 3)) +
  labs(x = NULL, y = NULL, colour = NULL, fill = NULL, slab_colour = NULL) +
  theme(panel.grid.major.y = element_line(colour = "grey10", size = 0.2))
  
# Plot using plain ggplot: one free-scaled facet per parameter showing the
# prior and posterior densities (each scaled to a peak of 1), with the true
# value drawn as a vertical line. The y axis is unlabelled because the scaled
# heights carry no meaning.
p2 <- ggplot(m1.draws, aes(x = value, colour = type)) +
  geom_vline(
    mapping = aes(xintercept = value),
    data = true.params,
    colour = "grey10",
    size = 0.3
  ) +
  geom_density(mapping = aes(y = after_stat(scaled))) +
  facet_wrap(facets = vars(name), scales = "free") +
  labs(x = NULL, y = NULL, colour = NULL, fill = NULL, slab_colour = NULL) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )
  
# Stack the ggdist plot (p1) above the faceted ggplot (p2) with patchwork's `/`
p1 / p2

Created on 2022-02-01 by the reprex package (v2.0.1)

3 Likes