Below is an example that extracts the prior and posterior draws for each parameter and plots them with ggplot2 (the first plot also uses ggdist). I fit the model with brms because I don’t know Stan well enough to write the model directly in Stan.
# Set up and fit a sample model
library(tidyverse)
library(patchwork)
# ggdist provides stat_slab(), used in the first plot below. It is attached
# before brms so that brms keeps priority for any function names they share.
library(ggdist)
library(brms)
theme_set(theme_bw())

# Regress petal width on both sepal measurements plus species.
form <- Petal.Width ~ Sepal.Width + Sepal.Length + Species

# Only the slope coefficients (class "b") get explicit Normal(0, sd) priors;
# every other parameter is left to brms's defaults. sample_prior = TRUE asks
# brms to also return draws from the priors, so they can be compared with the
# posterior later on.
m1 <- brm(
  formula = form,
  data = iris,
  family = gaussian(),
  prior = c(
    prior(normal(0, 0.2), class = "b", coef = "Speciesversicolor"),
    prior(normal(0, 0.2), class = "b", coef = "Speciesvirginica"),
    prior(normal(0, 0.5), class = "b", coef = "Sepal.Width"),
    prior(normal(0, 0.3), class = "b", coef = "Sepal.Length")
  ),
  backend = "cmdstanr",
  sample_prior = TRUE
)
# Pull every draw out of the fit, then keep only the regression coefficients
# (b_*), sigma, and the matching prior_* columns. brms names the intercept's
# prior draws "prior_Intercept"; rename that column to "prior_b_Intercept" so it
# pairs with the "b_Intercept" posterior column once the prefix is stripped.
# NOTE(review): brms places the Intercept prior on the intercept of the
# centred design matrix, so these prior draws may not be on the same scale as
# b_Intercept -- confirm before reading the two side by side.
m1.draws <- m1 %>%
  as_draws_df() %>%
  select(matches("^(prior|b_|sigma)")) %>%
  rename(prior_b_Intercept = prior_Intercept)
# Stack the prior and posterior draws for each parameter, then reshape to long
# format with one row per draw and columns `type` ("prior" / "posterior"),
# `name` (the parameter) and `value`.
#
# The prior half strips only a *leading* "prior_" so its column names match
# the posterior columns (b_*, sigma). The posterior half drops exactly those
# prefixed columns. Both use the same anchored pattern, so an unrelated name
# that merely contains "prior" is neither renamed nor dropped by mistake.
m1.draws <- bind_rows(
  m1.draws %>%
    select(starts_with("prior_")) %>%
    mutate(type = "prior") %>%
    rename_with(~ sub("^prior_", "", .x)),
  m1.draws %>%
    select(-starts_with("prior_")) %>%
    mutate(type = "posterior")
) %>%
  pivot_longer(-type)
# Sanity check on the stacking: every parameter should appear with both a
# "prior" and a "posterior" type, with the same number of draws in each
# (4000 in the output below).
count(m1.draws, type, name)
#> type name n
#> <chr> <chr> <int>
#> 1 posterior b_Intercept 4000
#> 2 posterior b_Sepal.Length 4000
#> 3 posterior b_Sepal.Width 4000
#> 4 posterior b_Speciesversicolor 4000
#> 5 posterior b_Speciesvirginica 4000
#> 6 posterior sigma 4000
#> 7 prior b_Intercept 4000
#> 8 prior b_Sepal.Length 4000
#> 9 prior b_Sepal.Width 4000
#> 10 prior b_Speciesversicolor 4000
#> 11 prior b_Speciesvirginica 4000
#> 12 prior sigma 4000
# Peek at the first draw of every type/parameter combination. All rows are
# printed (n = Inf) rather than the default truncated tibble view.
m1.draws %>%
  group_by(type, name) %>%
  slice_head(n = 1) %>%
  print(n = Inf)
#> type name value
#> <chr> <chr> <dbl>
#> 1 posterior b_Intercept -1.04
#> 2 posterior b_Sepal.Length 0.258
#> 3 posterior b_Sepal.Width 0.00865
#> 4 posterior b_Speciesversicolor 0.833
#> 5 posterior b_Speciesvirginica 1.33
#> 6 posterior sigma 0.195
#> 7 prior b_Intercept 4.45
#> 8 prior b_Sepal.Length 0.301
#> 9 prior b_Sepal.Width -0.178
#> 10 prior b_Speciesversicolor 0.0645
#> 11 prior b_Speciesvirginica -0.207
#> 12 prior sigma 1.61
# "True" parameter values to overlay on the plots. They are assumed to be
# known here, purely for illustration. Each value is paired with its parameter
# name explicitly rather than by position against the sorted names in
# m1.draws, which would silently mislabel values if the parameter set changed.
true.params <- tibble(
  name = c(
    "b_Intercept", "b_Sepal.Length", "b_Sepal.Width",
    "b_Speciesversicolor", "b_Speciesvirginica", "sigma"
  ),
  value = c(-0.6, 0.1, 0.02, 0.7, 1.3, 0.3),
  type = "true value"
)
# Fail loudly if these names drift out of sync with the model's parameters.
stopifnot(setequal(true.params$name, unique(m1.draws$name)))
# Plot 1 (ggdist): one row per parameter, showing the prior and posterior
# distributions as outlined, unfilled slabs coloured by type, plus a short
# vertical tick at each true value. Requires library(ggdist) for stat_slab().
p1 <- m1.draws %>%
  bind_rows(true.params) %>%
  ggplot(aes(value, name, colour = type)) +
  # Slabs are drawn from the prior/posterior draws only, not the true values.
  stat_slab(
    data = . %>% filter(type != "true value"),
    aes(slab_colour = type), fill = NA, normalize = "groups",
    scale = 0.7, slab_size = 0.5, fatten_point = 1.5
  ) +
  # after_stat(y) is the numeric position of the parameter's row on the
  # discrete y axis, so each segment rises 0.7 units from the baseline
  # (presumably chosen to match `scale = 0.7` above).
  geom_segment(
    data = . %>% filter(type == "true value"),
    aes(x = value, xend = value, y = name, yend = after_stat(y) + 0.7),
    colour = "grey10", size = 0.3
  ) +
  # Zoom without dropping draws, so density estimates still use all data.
  coord_cartesian(xlim = c(-3, 3)) +
  labs(x = NULL, y = NULL, colour = NULL, fill = NULL, slab_colour = NULL) +
  theme(panel.grid.major.y = element_line(colour = "grey10", size = 0.2))
# Plot 2 (plain ggplot2): one facet per parameter, each with free axes. Prior
# and posterior density curves are each rescaled to a peak of 1
# (after_stat(scaled)), so their shapes are comparable. A grey vertical line
# marks the true value. The y-axis text and ticks are hidden because the
# rescaled heights carry no meaning.
p2 <- ggplot(m1.draws, aes(value, colour = type)) +
  geom_vline(
    data = true.params, aes(xintercept = value),
    colour = "grey10", size = 0.3
  ) +
  geom_density(aes(y = after_stat(scaled))) +
  facet_wrap(vars(name), scales = "free") +
  labs(x = NULL, y = NULL, colour = NULL, fill = NULL, slab_colour = NULL) +
  theme(
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank()
  )

# Stack the two plots vertically with patchwork's `/` operator.
p1 / p2
Created on 2022-02-01 by the reprex package (v2.0.1)