I have defined a model to which I have added order restrictions in the transformed parameters
block.
When sampling only from the prior (with brm(sample_prior = "only")),
I get a divergence warning that I don’t get when sampling from the posterior:
Warning: 2919 of 4000 (73.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.
The prior predictive check, the R-hat values and the trace plots of the chains all look good (although the ESS is low…).
Additionally, when sampling from the posterior, on the occasions when I do get a divergence warning, it's for only ~2 transitions (<0.1%).
(Code and example below)
According to the Stan warnings page linked above, one source of divergences can be:
- Check your priors. If the model is sampling heavily […] on the boundaries of parameter constraints, this is a bad sign.
However, given the priors and order restrictions, sampling from the prior near the boundary is unavoidable.
So… is it safe to ignore this warning? Is there something I should do?
Code
library(brms)
library(posterior)
library(ggplot2)
library(ggdist)
library(bayestestR)
# Prep data ---------------------------------------------------------------
# Load mtcars, recode the cylinder count as a factor and attach orthonormal
# contrasts (contr.orthonorm, from bayestestR) to it.
data("mtcars")
mtcars[["cyl"]] <- factor(mtcars[["cyl"]])
contrasts(mtcars[["cyl"]]) <- contr.orthonorm
# Set Priors --------------------------------------------------------------
# Inspect which priors brms expects for this formula.
get_prior(formula = mpg ~ cyl, data = mtcars)
priors <- set_prior(prior = "normal(20, 5)", class = "Intercept") +
  set_prior(prior = "normal(0, 10)", class = "sigma", lb = 0) +
  # Priors on the two orthonormal-contrast coefficients. With
  # contr.orthonorm these are NOT pairwise differences: e.g.
  # mpg[cyl=4] - mpg[cyl=6] is a weighted sum of both coefficients
  # (the weights are computed as diff_cyl_4vs6 below).
  set_prior(prior = "normal(0, 5)", class = "b", coef = c("cyl1", "cyl2"))
## Set additional constraints ------------------------------------
# We also have a prior that mpg[cyl=6] < mpg[cyl=4]
# We need to edit the Stan code!
# Get the weights for the contrasts. Each cell mean is
# Intercept + contrasts[level, ] %*% b, so the Intercept cancels in
# mpg[cyl=4] - mpg[cyl=6], leaving sum(diff_cyl_4vs6 * b) with
# diff_cyl_4vs6 = t(contrasts) %*% c(1, -1, 0) (levels ordered 4, 6, 8).
contr.cyl <- contrasts(mtcars$cyl)
wts <- c(1, -1, 0)
diff_cyl_4vs6 <- c(t(contr.cyl) %*% wts)
diff_cyl_4vs6
#> [1] 0.7071068 1.2247449
scode <-
  # Pass the weights to Stan as data (brms declares vector[2] diff_cyl_4vs6)
  stanvar(diff_cyl_4vs6, name = "diff_cyl_4vs6") +
  # Add some code to transformed parameters
  # 1. Declare cyl_4vs6 with a lower bound of 0.
  #    NOTE(review): on a *transformed* parameter Stan only checks this
  #    bound; it is not a change of variables as it is for a declared
  #    parameter. A proposal that violates it is rejected, and inside an
  #    HMC trajectory that is reported as a divergent transition. The
  #    normal(0, 5) priors on b[1] and b[2] are symmetric around 0, so half
  #    of their joint mass has cyl_4vs6 < 0. With sample_prior = "only" the
  #    sampler therefore keeps hitting this wall, which is the likely cause
  #    of the prior-only divergence warning. The data pull the posterior
  #    away from 0, so the posterior rarely hits it.
  #    TODO: to remove the divergences, reparameterize so the bound sits on
  #    a declared parameter (e.g. a hand-written Stan model, or a non-linear
  #    brms formula with lb = 0 on the 4-vs-6 nlpar).
  stanvar(scode = "real<lower=0> cyl_4vs6;", block = "tparameters") +
  # 2. Compute cyl_4vs6 from the weights. dot_product() covers any number
  #    of contrast columns, not just two.
  stanvar(
    scode = "cyl_4vs6 = dot_product(b, diff_cyl_4vs6);",
    block = "tparameters"
  )
# Validate
make_stancode(
  mpg ~ cyl,
  data = mtcars,
  prior = priors,
  stanvars = scode
)
#> // generated with brms 2.17.0
#> functions {
#> }
#> data {
#> int<lower=1> N; // total number of observations
#> vector[N] Y; // response variable
#> int<lower=1> K; // number of population-level effects
#> matrix[N, K] X; // population-level design matrix
#> int prior_only; // should the likelihood be ignored?
#> vector[2] diff_cyl_4vs6;
#> }
#> transformed data {
#> int Kc = K - 1;
#> matrix[N, Kc] Xc; // centered version of X without an intercept
#> vector[Kc] means_X; // column means of X before centering
#> for (i in 2:K) {
#> means_X[i - 1] = mean(X[, i]);
#> Xc[, i - 1] = X[, i] - means_X[i - 1];
#> }
#> }
#> parameters {
#> vector[Kc] b; // population-level effects
#> real Intercept; // temporary intercept for centered predictors
#> real<lower=0> sigma; // dispersion parameter
#> }
#> transformed parameters {
#> real lprior = 0; // prior contributions to the log posterior
#> real<lower=0> cyl_4vs6;
#> cyl_4vs6 = dot_product(b, diff_cyl_4vs6);
#> lprior += normal_lpdf(b[1] | 0, 5);
#> lprior += normal_lpdf(b[2] | 0, 5);
#> lprior += normal_lpdf(Intercept | 20, 5);
#> lprior += normal_lpdf(sigma | 0, 10)
#> - 1 * normal_lccdf(0 | 0, 10);
#> }
#> model {
#> // likelihood including constants
#> if (!prior_only) {
#> target += normal_id_glm_lpdf(Y | Xc, Intercept, b, sigma);
#> }
#> // priors including constants
#> target += lprior;
#> }
#> generated quantities {
#> // actual population-level intercept
#> real b_Intercept = Intercept - dot_product(means_X, b);
#> }
# Looks good!
## Predictive Prior Check ----------------------------------------
# Fit the same model with sample_prior = "only", so the draws reflect the
# priors (plus the cyl_4vs6 bound) without the likelihood.
# NOTE(review): the divergence warning below is likely the sampler running
# into the lower bound on the transformed parameter cyl_4vs6, which half of
# the unconstrained prior mass violates. Do not dismiss it without checking.
mod_prior <- brm(
  formula = mpg ~ cyl,
  data = mtcars,
  stanvars = scode,
  sample_prior = "only",
  prior = priors,
  backend = "cmdstanr"
)
#>
#> Warning: 2919 of 4000 (73.0%) transitions ended with a divergence.
#> See https://mc-stan.org/misc/warnings for details.
#>
mod_prior |>
  diagnostic_posterior(effects = "all", component = "all")
#> Parameter Rhat ESS MCSE
#> 1 b_cyl1 1.003784 727.6092 0.1739313
#> 2 b_cyl2 1.005846 471.5054 0.1699110
#> 3 b_Intercept 1.004845 761.9795 0.1831934
#> 4 sigma 1.001598 1113.7852 0.1822975
mod_prior |>
  pp_check()
# Plotting the transformed parameter
# Prior draws of the expected mpg at each cylinder level, collected into
# one rvar vector named cyl4, cyl6 and cyl8.
rvar_ests_prior <- setNames(
  rvar(posterior_epred(
    mod_prior,
    newdata = data.frame(cyl = factor(c(4, 6, 8)))
  )),
  nm = paste0("cyl", c(4, 6, 8))
)
# Distribution of mpg[cyl=4] - mpg[cyl=6], the quantity bounded below by 0.
ggplot() +
  stat_slab(mapping = aes(
    xdist = rvar_ests_prior["cyl4"] - rvar_ests_prior["cyl6"]
  ))
# Fit Model ---------------------------------------------------------------
# Refit the prior-only model with the likelihood switched back on.
mod_post <- mod_prior |>
  update(sample_prior = FALSE)
#>
#> Warning: 2 of 4000 (0.0%) transitions ended with a divergence.
#> See https://mc-stan.org/misc/warnings for details.
#>