Some Bayesplot/ggplot2 questions relating to plotting priors and posteriors in same panel

Hello Stanizens,

I’m looking for some help this time with regard to some plot engineering.

I’m trying to make a plot consisting of panels of prior and posterior density pairs, with vertical lines at the true θ values (the data-generating process is known for this inference).

Following various introductory guides, I’ve been able to plot just the posterior densities extracted from a CmdStanR fit object, with code along the lines of the following:

library(cmdstanr)
library(bayesplot)

fit <- readRDS(...)
params <- c("u_M", "a_SD", "a_DS", "a_M", "a_MSC", "k_S_ref", "k_D_ref", "k_M_ref", "Ea_S", "Ea_D", "Ea_M")
posterior <- fit$draws(params)
post_plots <- mcmc_dens(posterior)

This produces the following quick and unpolished plot:

[Figure: mcmc_dens panels of the posterior density for each parameter]

After this step, I’m a little at a loss for how to do the following across all of the parameter panels in a straightforward fashion:

  1. plot the priors on top of the posteriors, and
  2. plot vertical lines at the true θ values.

I do have the prior predictive samples in the generated quantities section of the code:

generated quantities {
  ...
  // Obtain prior predictive samples. 
  real u_M_prior_pred = normal_lb_ub_rng(u_M_prior_dist_params[1], u_M_prior_dist_params[1] * prior_scale_factor, u_M_prior_dist_params[2], u_M_prior_dist_params[3]);
  real a_SD_prior_pred = normal_lb_ub_rng(a_SD_prior_dist_params[1], a_SD_prior_dist_params[1] * prior_scale_factor, a_SD_prior_dist_params[2], a_SD_prior_dist_params[3]);
  real a_DS_prior_pred = normal_lb_ub_rng(a_DS_prior_dist_params[1], a_DS_prior_dist_params[1] * prior_scale_factor, a_DS_prior_dist_params[2], a_DS_prior_dist_params[3]);
  real a_M_prior_pred = normal_lb_ub_rng(a_M_prior_dist_params[1], a_M_prior_dist_params[1] * prior_scale_factor, a_M_prior_dist_params[2], a_M_prior_dist_params[3]);
  real a_MSC_prior_pred = normal_lb_ub_rng(a_MSC_prior_dist_params[1], a_MSC_prior_dist_params[1] * prior_scale_factor, a_MSC_prior_dist_params[2], a_MSC_prior_dist_params[3]);
  real k_S_ref_prior_pred = normal_lb_ub_rng(k_S_ref_prior_dist_params[1], k_S_ref_prior_dist_params[1] * prior_scale_factor, k_S_ref_prior_dist_params[2], k_S_ref_prior_dist_params[3]);
  real k_D_ref_prior_pred = normal_lb_ub_rng(k_D_ref_prior_dist_params[1], k_D_ref_prior_dist_params[1] * prior_scale_factor, k_D_ref_prior_dist_params[2], k_D_ref_prior_dist_params[3]);
  real k_M_ref_prior_pred = normal_lb_ub_rng(k_M_ref_prior_dist_params[1], k_M_ref_prior_dist_params[1] * prior_scale_factor, k_M_ref_prior_dist_params[2], k_M_ref_prior_dist_params[3]);
  real Ea_S_prior_pred = normal_lb_ub_rng(Ea_S_prior_dist_params[1], Ea_S_prior_dist_params[1] * prior_scale_factor, Ea_S_prior_dist_params[2], Ea_S_prior_dist_params[3]);
  real Ea_D_prior_pred = normal_lb_ub_rng(Ea_D_prior_dist_params[1], Ea_D_prior_dist_params[1] * prior_scale_factor, Ea_D_prior_dist_params[2], Ea_D_prior_dist_params[3]);
  real Ea_M_prior_pred = normal_lb_ub_rng(Ea_M_prior_dist_params[1], Ea_M_prior_dist_params[1] * prior_scale_factor, Ea_M_prior_dist_params[2], Ea_M_prior_dist_params[3]);
}

Ideally, I’d like to do this in vectorised fashion, along the lines of post_plot + priors_plot + vertical_lines_plot, but perhaps I need to do it iteratively, as in the following pseudocode:

total_param_count <- 11
param_names_list <- c("u_M", ...)
plot_list <- vector(mode = "list", length = total_param_count)
names(plot_list) <- param_names_list

for (i in seq_len(total_param_count)) {
    # custom_function_for_single_prior_posterior_plot() is a placeholder still to be written
    plot_list[[i]] <- custom_function_for_single_prior_posterior_plot(prior_draws[param_names_list[i]], posterior_draws[param_names_list[i]], ...)
}
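The "vectorised" route mostly comes down to data layout: if the prior and posterior draws sit in one long table (one row per draw, with `name` and `type` columns), a faceted plot can draw every panel in one call, and the true values become one more small table. Below is a minimal pandas sketch. It assumes the draws have been exported to a data frame with one column per quantity, and that each parameter `p` has a prior-draw column named `p_prior_pred`, as in the generated quantities block. The helper name and the synthetic numbers are made up for illustration.

```python
import numpy as np
import pandas as pd


def stack_prior_posterior(draws, params, prior_suffix="_prior_pred"):
    """Reshape wide draws (one column per quantity) into long format.

    Hypothetical helper: assumes each parameter `p` has a posterior column `p`
    and a prior-draw column `p + prior_suffix`.
    """
    posterior = draws[list(params)].melt(var_name="name", value_name="value")
    posterior["type"] = "posterior"
    # Map e.g. "u_M_prior_pred" -> "u_M" so both halves share parameter names.
    prior_cols = {p + prior_suffix: p for p in params}
    prior = (draws[list(prior_cols)]
             .rename(columns=prior_cols)
             .melt(var_name="name", value_name="value"))
    prior["type"] = "prior"
    return pd.concat([posterior, prior], ignore_index=True)


# Synthetic stand-in for exported draws (values are made up).
rng = np.random.default_rng(1)
params = ["u_M", "a_SD"]
draws = pd.DataFrame({
    "u_M": rng.normal(0.5, 0.05, 1000),
    "a_SD": rng.normal(0.3, 0.02, 1000),
    "u_M_prior_pred": rng.normal(0.5, 0.2, 1000),
    "a_SD_prior_pred": rng.normal(0.3, 0.1, 1000),
})
long_draws = stack_prior_posterior(draws, params)
print(long_draws.groupby(["type", "name"]).size())
```

Each parameter then contributes 1000 prior rows and 1000 posterior rows. A faceting call (one facet per `name`, colour by `type`) can draw all panels at once.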
Has anyone made a plot like this? If so, would you happen to have example code you could share, or could you point me in the right direction? Ultimately, I’m trying to get to a plot that resembles the following:

[Figure: example of the target plot]

On a less involved note, how does one tune the alpha of the fill in mcmc_dens? By default, mcmc_dens does not take an alpha argument.

Thank you all kindly for your patient help.


I wanted to follow up on the part of the question above about plotting posterior densities, estimated from Stan posterior samples, on top of the priors. After some weeks, I was able to hack together a (potentially inelegant and clunky) approach in Matplotlib using SciPy’s Gaussian KDE method. Because the prior and posterior densities are concentrated so differently ("ragged"), I found it easier to force things together in Matplotlib than in Seaborn or ggplot2, but I welcome ggplot2/Seaborn examples if anyone has them.

The code looks something like the following:

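import math
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

# Assumed to exist already (built as described further down):
#   keys       - list of parameter names
#   x          - (num_pts, num_params) grid of x values, one column per parameter
#   pdf_prior  - prior pdf evaluated on x, same shape
#   df_post    - Stan posterior draws, one column per parameter
#   true_theta - dict mapping parameter name -> true value
num_params = len(SCON_C_priors_details)
ncols = 4
nrows = math.ceil(num_params / ncols)  # no empty extra row when num_params divides evenly
fig, axes = plt.subplots(nrows, ncols, figsize=(ncols * 4, nrows * 4))
axes = np.atleast_2d(axes)
k = 0
for i, row in enumerate(axes):
    for j, ax in enumerate(row):
        if k < num_params:
            key = keys[k]
            # KDE of the posterior draws, evaluated on the same grid as the prior pdf
            post_kde_density = stats.gaussian_kde(df_post.loc[:, key])
            ax.plot(x[:, k], pdf_prior[:, k])
            ax.plot(x[:, k], post_kde_density(x[:, k]))
            ax.axvline(true_theta[key], color='gray')  # true parameter value
            ax.set_xlabel(key)
            ax.set_ylabel('Density')
            ax.ticklabel_format(style='sci', scilimits=(-2, 4), axis='both', useMathText=True)
        else:
            fig.delaxes(axes[i, j])  # remove unused panels in the last row
        k += 1
plt.tight_layout()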

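to produce this figure:

[Figure: prior and posterior density pairs for each parameter, with vertical lines at the true θ values]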

Anyhow, just following up for posterity in case this is helpful for someone in the future.


Do you know if this would work with plotnine and Matplotlib?

I hadn’t heard of Plotnine before, so I just looked it up; thanks for the tip. It seems like this could work if I get the prior pdf and Gaussian KDE density outputs into a data frame. Below I’ll describe more of the code I used, which leads to some pseudocode that could be expanded into a Plotnine implementation.

For context, in this project I’m conducting variational inference (VI) on a model to obtain approximate posteriors, and I wanted some plots of Stan NUTS results to serve as a baseline comparison for the VI implementation. The VI implementation is in PyTorch, and my priors were coded as PyTorch distribution objects. One of the distributions I used was a custom TruncatedNormal implementation, since PyTorch does not have a native truncated normal. In my example plot above, the priors follow this custom bounded normal distribution, so for simplicity I’ll show the truncated normal code. I set up a mean-field prior distribution object with something like

import torch
import torch.distributions as D

from TruncatedNormal import * #Need to download module into appropriate directory first.

p_theta = TruncatedNormal(loc = prior_means_tensor, scale = prior_sds_tensor, a = lower_bounds_tensor, b = upper_bounds_tensor)
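If the custom TruncatedNormal module isn’t at hand, or you want to cross-check it, SciPy’s `stats.truncnorm` can stand in for the bounded-normal prior. One gotcha is that its `a` and `b` arguments are standardised bounds, `(bound - loc) / scale`, not the raw bounds. The sketch also shows that the truncated mean and standard deviation differ from `loc` and `scale`. This matters for the ± k·sd window computed next. The helper name and the prior settings are hypothetical.

```python
import numpy as np
from scipy import stats


def bounded_normal(loc, scale, lower, upper):
    """Frozen SciPy truncated normal on [lower, upper] (hypothetical helper).

    scipy.stats.truncnorm expects its bounds in standard-deviation units
    relative to loc, so they are standardised here.
    """
    loc = np.asarray(loc, dtype=float)
    scale = np.asarray(scale, dtype=float)
    a = (np.asarray(lower, dtype=float) - loc) / scale
    b = (np.asarray(upper, dtype=float) - loc) / scale
    return stats.truncnorm(a, b, loc=loc, scale=scale)


# Made-up prior settings for three parameters (vectorised, like a mean-field prior).
prior_means = np.array([0.5, 1.0, 60.0])
prior_sds = np.array([0.25, 0.5, 30.0])
lower = np.zeros(3)
upper = np.array([1.0, np.inf, 100.0])

p_theta = bounded_normal(prior_means, prior_sds, lower, upper)
print(p_theta.mean(), p_theta.std())  # moments after truncation
print(p_theta.pdf(prior_means))       # densities at the untruncated means
```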

Creating the posterior distribution object is where things get a little hacky. A coarse way to find where most of a distribution’s density lies is to go a few standard deviations in both directions from the mean. For these posterior-prior pair plots, we want each panel to capture the bulk of both the prior and the posterior density. This is straightforward with PyTorch distribution objects, which expose the moments through their mean and stddev properties, but SciPy’s Gaussian KDE does not report any moments. Consequently, I use the summary statistics from Stan (posterior means and standard deviations) to build an "approximate posterior" PyTorch object that casts a net around the bulk of the posterior density.

q_theta = TruncatedNormal(loc = stan_summary_means, scale = stan_summary_sds, a = lower_bounds_tensor, b = upper_bounds_tensor)

With the prior and posterior distribution objects, we can guess at our range to cover on the x-axis.

# Window per parameter: 4 standard deviations either side of the mean, taking the
# union of the prior and approximate-posterior windows, clipped to the support.
# lower/upper hold the per-parameter support bounds.
x0 = torch.min(q_theta.mean - 4 * q_theta.stddev, p_theta.mean - 4 * p_theta.stddev)
x0 = torch.max(x0, lower).detach()

x1 = torch.max(q_theta.mean + 4 * q_theta.stddev, p_theta.mean + 4 * p_theta.stddev)
x1 = torch.min(x1, upper).detach()

num_pts = 10000
# (num_pts, num_params) grid: one column of x values per parameter / figure panel.
x = x0 + (x1 - x0) * torch.linspace(0, 1, num_pts).unsqueeze(1)
pdf_prior = torch.exp(p_theta.log_prob(x)).detach()  # prior pdf evaluated on the grid
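The same windowing logic can be written with NumPy/SciPy arrays. The sketch below uses SciPy truncated normals as stand-ins for `p_theta` and `q_theta`, with made-up parameter values. It adds one check that is not in the original: how much probability mass of each distribution actually falls inside each panel’s window, which is a quick way to confirm that 4 sd was wide enough.

```python
import numpy as np
from scipy import stats


def truncnorm_from_bounds(loc, scale, lower, upper):
    # SciPy wants standardised bounds: (bound - loc) / scale.
    return stats.truncnorm((lower - loc) / scale, (upper - loc) / scale, loc=loc, scale=scale)


def plot_window(prior, post, lower, upper, k=4.0):
    """Per-parameter [x0, x1]: union of mean +/- k*sd windows, clipped to the bounds."""
    x0 = np.minimum(post.mean() - k * post.std(), prior.mean() - k * prior.std())
    x1 = np.maximum(post.mean() + k * post.std(), prior.mean() + k * prior.std())
    return np.maximum(x0, lower), np.minimum(x1, upper)


lower, upper = np.array([0.0, 0.0]), np.array([1.0, np.inf])
prior = truncnorm_from_bounds(np.array([0.5, 1.0]), np.array([0.25, 0.5]), lower, upper)
# Hypothetical "approximate posterior" built from Stan summary means/sds.
post = truncnorm_from_bounds(np.array([0.62, 0.8]), np.array([0.03, 0.05]), lower, upper)

x0, x1 = plot_window(prior, post, lower, upper)
num_pts = 1000
x = np.linspace(x0, x1, num_pts)  # shape (num_pts, n_params): one column per panel
pdf_prior = prior.pdf(x)          # broadcasts across the columns

# Probability mass of each distribution inside its panel window.
print(prior.cdf(x1) - prior.cdf(x0), post.cdf(x1) - post.cdf(x0))
```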

Once we have the x span, we no longer need q_theta and can move on to a Gaussian KDE of the Stan samples. We can read in a data frame of Stan samples in which the chains have been collapsed, so that each parameter has a single column:

import pandas as pd

df = pd.read_csv("stan_samples.csv")

Then we call SciPy’s Gaussian KDE function on each column of the df. In my experience, this has to be done in a loop; otherwise, SciPy’s Gaussian KDE will try to fit a single multivariate KDE to the entire df. So you end up with something like

from scipy import stats

post_density_df = pd.DataFrame()
prior_density_df = pd.DataFrame()

# keys: list of parameter-name strings, in the same order as the columns of x
for i, key in enumerate(keys):
    post_density_func = stats.gaussian_kde(df.loc[:, key])
    # Evaluate the posterior KDE and look up the prior pdf on the same x grid.
    post_density_df[key] = post_density_func(x[:, i].numpy())
    prior_density_df[key] = pdf_prior[:, i].numpy()
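The reason for the loop can be checked directly: `gaussian_kde` reads a 2-D array as (dimensions, points), so handing it the whole table produces a single joint density rather than one marginal per parameter. The short sketch below uses synthetic draws; the column names are borrowed from the question and the values are made up. It shows the difference through the KDE’s `d` attribute (its number of dimensions). A dict comprehension is an equivalent, slightly more compact form of the per-column loop.

```python
import numpy as np
import pandas as pd
from scipy import stats

rng = np.random.default_rng(0)
# Synthetic posterior draws, one column per parameter.
df = pd.DataFrame({"u_M": rng.normal(0.6, 0.03, 4000),
                   "a_SD": rng.normal(0.8, 0.05, 4000)})

# Whole table (transposed to (n_dims, n_points)) -> ONE joint 2-D density.
joint = stats.gaussian_kde(df.to_numpy().T)
print(joint.d)

# One KDE per column -> the 1-D marginal densities needed for the panels.
marginals = {key: stats.gaussian_kde(df[key].to_numpy()) for key in df.columns}
print({key: kde.d for key, kde in marginals.items()})
```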

You could then merge post_density_df and prior_density_df, or put everything into one data frame in the first place. You can also assemble the prior density data frame outside the loop, since pdf_prior is already a full matrix: prior_density_df = pd.DataFrame(pdf_prior.numpy(), columns=keys). I’m having trouble picturing the exact format of the combined data frame right now, but once each x column is linked to its posterior-prior column pair, that should put you on the road to doing this in "grammar of graphics" style with Plotnine or Seaborn. You could also create one data frame per (x, prior, posterior KDE) trio and then loop through the trios, one per figure panel.

I ultimately did the above in Matplotlib because I already had to loop inelegantly for stats.gaussian_kde to target the individual data frame columns. Thinking about it more, though, you should be able to get everything into a set of data frames, one per parameter, or perhaps into one larger data frame, for use in Plotnine.
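One way to picture the "one larger data frame" is a long table with one row per (parameter, grid point, curve type), which is the shape a faceted grammar-of-graphics plot wants. The pandas sketch below assumes `x` and `pdf_prior` are NumPy arrays of shape (num_pts, n_params), for example after `.numpy()` on the torch tensors. It also assumes that `post_draws` has one column per parameter name. The helper name and the example data are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy import stats


def density_long_df(x, pdf_prior, post_draws, keys):
    """One row per (parameter, grid point, curve type). Hypothetical helper.

    x, pdf_prior: arrays of shape (num_pts, n_params), one column per parameter.
    post_draws:   DataFrame of posterior draws with one column per key.
    """
    frames = []
    for i, key in enumerate(keys):
        post_kde = stats.gaussian_kde(post_draws[key].to_numpy())
        for kind, dens in (("prior", pdf_prior[:, i]), ("posterior", post_kde(x[:, i]))):
            frames.append(pd.DataFrame({"param": key, "x": x[:, i],
                                        "density": dens, "type": kind}))
    return pd.concat(frames, ignore_index=True)


# Synthetic example: two parameters with made-up priors and posterior draws.
rng = np.random.default_rng(2)
keys = ["u_M", "a_SD"]
x = np.linspace([0.0, 0.0], [1.0, 2.0], 200)
pdf_prior = np.column_stack([stats.norm(0.5, 0.25).pdf(x[:, 0]),
                             stats.norm(1.0, 0.5).pdf(x[:, 1])])
post_draws = pd.DataFrame({"u_M": rng.normal(0.6, 0.03, 4000),
                           "a_SD": rng.normal(0.8, 0.05, 4000)})
long_df = density_long_df(x, pdf_prior, post_draws, keys)
print(long_df.head())
```

From here, an untested Plotnine call would look roughly like `ggplot(long_df, aes("x", "density", color="type")) + geom_line() + geom_vline(aes(xintercept="value"), data=truth_df) + facet_wrap("~param", scales="free")`. In that call, `truth_df` is a hypothetical frame with `param` and `value` columns holding the true θ values.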

Hope the above helps and offers some ideas!


Below is an example of extracting draws of the prior and posterior for each parameter and plotting with ggplot2. I’ve fit the model with brms, as I don’t know Stan well enough to program it directly in Stan.

# Set up and fit a sample model
library(tidyverse)
library(patchwork)
library(brms)
library(ggdist)  # provides stat_slab(), used in the first plot below
theme_set(theme_bw())

# Model formula
form = Petal.Width ~ Sepal.Width + Sepal.Length + Species

m1 = brm(formula=form,
         data=iris, 
         family=gaussian(),
         prior = c(prior(normal(0,0.2), class="b", coef="Speciesversicolor"),
                   prior(normal(0,0.2), class="b", coef="Speciesvirginica"),
                   prior(normal(0,0.5), class="b", coef="Sepal.Width"),
                   prior(normal(0,0.3), class="b", coef="Sepal.Length")),
         backend="cmdstanr",
         sample_prior=TRUE)
# Get posterior draws 
m1.draws = as_draws_df(m1)

# Keep just the draws for the parameters and corresponding priors
m1.draws = m1.draws %>% 
  select(matches("^(prior|b_|sigma)")) %>% 
  rename(prior_b_Intercept=prior_Intercept)

# Stack the prior and posterior draws for each parameter
#  and reshape to long format
m1.draws = bind_rows(
    m1.draws %>% 
      select(matches("prior")) %>% 
      mutate(type="prior") %>% 
      rename_all(~gsub("prior_", "", .)),
    m1.draws %>% 
      select(-matches("prior")) %>% 
      mutate(type="posterior")
  ) %>% 
  pivot_longer(-type)

# Show that we now have 4000 draws of prior and posterior for each parameter
m1.draws %>% 
  count(type, name)

#>    type      name                    n
#>    <chr>     <chr>               <int>
#>  1 posterior b_Intercept          4000
#>  2 posterior b_Sepal.Length       4000
#>  3 posterior b_Sepal.Width        4000
#>  4 posterior b_Speciesversicolor  4000
#>  5 posterior b_Speciesvirginica   4000
#>  6 posterior sigma                4000
#>  7 prior     b_Intercept          4000
#>  8 prior     b_Sepal.Length       4000
#>  9 prior     b_Sepal.Width        4000
#> 10 prior     b_Speciesversicolor  4000
#> 11 prior     b_Speciesvirginica   4000
#> 12 prior     sigma                4000

# Show one draw for each parameter and prior
m1.draws %>% 
  group_by(type, name) %>% 
  slice(1) %>% print(n=Inf)

#>    type      name                   value
#>    <chr>     <chr>                  <dbl>
#>  1 posterior b_Intercept         -1.04   
#>  2 posterior b_Sepal.Length       0.258  
#>  3 posterior b_Sepal.Width        0.00865
#>  4 posterior b_Speciesversicolor  0.833  
#>  5 posterior b_Speciesvirginica   1.33   
#>  6 posterior sigma                0.195  
#>  7 prior     b_Intercept          4.45   
#>  8 prior     b_Sepal.Length       0.301  
#>  9 prior     b_Sepal.Width       -0.178  
#> 10 prior     b_Speciesversicolor  0.0645 
#> 11 prior     b_Speciesvirginica  -0.207  
#> 12 prior     sigma                1.61

# Data frame of true parameter values (assumed to be known here)
true.params = tibble(
  name = sort(unique(m1.draws$name)),
  value = c(-0.6, 0.1, 0.02, 0.7, 1.3, 0.3),
  type="true value"
)
# Plot using ggdist
p1 = m1.draws %>% 
  bind_rows(true.params) %>% 
  ggplot(aes(value, name, colour=type)) +
    stat_slab(data = . %>% filter(type != "true value"),
                 aes(slab_colour=type), fill=NA, normalize="groups", 
                 scale=0.7, slab_size=0.5, fatten_point=1.5) +
    geom_segment(data=. %>% filter(type=="true value"), 
                 aes(x=value, xend=value, y=name, yend=..y.. + 0.7),
                 colour="grey10", size=0.3) +
    coord_cartesian(xlim=c(-3,3)) +
    labs(x=NULL, y=NULL, colour=NULL, fill=NULL, slab_colour=NULL) +
    theme(panel.grid.major.y=element_line(colour="grey10", size=0.2))
  
# Plot using plain ggplot
p2 = m1.draws %>% 
  ggplot(aes(value, colour=type)) +
    geom_vline(data=true.params, aes(xintercept=value), colour="grey10", size=0.3) +
    geom_density(aes(y=after_stat(scaled))) + 
    facet_wrap(vars(name), scales="free") +
    labs(x=NULL, y=NULL, colour=NULL, fill=NULL, slab_colour=NULL) +
    theme(axis.text.y=element_blank(),
          axis.ticks.y=element_blank())
  
p1 / p2

Created on 2022-02-01 by the reprex package (v2.0.1)


Thanks for sharing this, @joels! It will be a very useful reference for anyone whose inference workflow is entirely in Stan and who also draws prior predictive check samples for a more "proper" NUTS prior/posterior comparison. In my case, the prior density was parametric rather than derived from samples, since the inference method this project is testing is a mean-field VI method.