brms: plotting the effects of a categorical variable

I would like to know whether it is possible to plot, for each of the two levels of a categorical variable, something like a violin plot, similar to the 'area' option in mcmc_plot. The conditional_effects function shows only the CI, and I have not found anything like the area option. But maybe I have not fully understood what the function does in the case of a categorical variable.
Thank you for any help.


For each predictor variable, conditional_effects builds a prediction data frame, gets posterior draws at those prediction values, and draws point-and-whisker plots of summaries of those draws. We can generate the posterior draws ourselves and then create violin plots from them. Here's an example:

library(tidyverse)
library(brms)
library(tidybayes)
theme_set(theme_bw())

# Create a model to work with
m = brm(Petal.Width ~ Species + Petal.Length, data=iris,
        prior=prior(normal(0,1), class="b"),
        cores=4, backend="cmdstanr")

ce=conditional_effects(m)

Look at the structure of the object returned by conditional_effects. It’s a list with two data frames, one for each of the predictor variables in the regression.

str(ce)
#> List of 2
#>  $ Species     :'data.frame':    3 obs. of  9 variables:
#>   ..$ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 2 3
#>   ..$ Petal.Width : num [1:3] 1.2 1.2 1.2
#>   ..$ Petal.Length: num [1:3] 3.76 3.76 3.76
#>   ..$ cond__      : Factor w/ 1 level "1": 1 1 1
#>   ..$ effect1__   : Factor w/ 3 levels "setosa","versicolor",..: 1 2 3
#>   ..$ estimate__  : num [1:3] 0.786 1.207 1.604
#>   ..$ se__        : num [1:3] 0.0813 0.0305 0.0667
#>   ..$ lower__     : num [1:3] 0.63 1.15 1.47
#>   ..$ upper__     : num [1:3] 0.951 1.269 1.729
#>   ..- attr(*, "effects")= chr "Species"
#>   ..- attr(*, "response")= chr "Petal.Width"
#>   ..- attr(*, "surface")= logi FALSE
#>   ..- attr(*, "categorical")= logi FALSE
#>   ..- attr(*, "ordinal")= logi FALSE
#>   ..- attr(*, "points")='data.frame':    150 obs. of  4 variables:
#>   .. ..$ Species  : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
#>   .. ..$ resp__   : num [1:150] 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#>   .. ..$ cond__   : Factor w/ 1 level "1": 1 1 1 1 1 1 1 1 1 1 ...
#>   .. ..$ effect1__: Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
#>  $ Petal.Length:'data.frame':    100 obs. of  9 variables:
#>   ..$ Petal.Length: num [1:100] 1 1.06 1.12 1.18 1.24 ...
#>   ..$ Petal.Width : num [1:100] 1.2 1.2 1.2 1.2 1.2 ...
#>   ..$ Species     : Factor w/ 3 levels "setosa","versicolor",..: 1 1 1 1 1 1 1 1 1 1 ...
#>   ..$ cond__      : Factor w/ 1 level "1": 1 1 1 1 1 1 1 1 1 1 ...
#>   ..$ effect1__   : num [1:100] 1 1.06 1.12 1.18 1.24 ...
#>   ..$ estimate__  : num [1:100] 0.138 0.152 0.166 0.18 0.194 ...
#>   ..$ se__        : num [1:100] 0.0291 0.0282 0.0274 0.0268 0.0261 ...
#>   ..$ lower__     : num [1:100] 0.0776 0.0942 0.1099 0.1251 0.1401 ...
#>   ..$ upper__     : num [1:100] 0.195 0.208 0.22 0.232 0.245 ...
#>   ..- attr(*, "effects")= chr "Petal.Length"
#>   ..- attr(*, "response")= chr "Petal.Width"
#>   ..- attr(*, "surface")= logi FALSE
#>   ..- attr(*, "categorical")= logi FALSE
#>   ..- attr(*, "ordinal")= logi FALSE
#>   ..- attr(*, "points")='data.frame':    50 obs. of  4 variables:
#>   .. ..$ Petal.Length: num [1:50] 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#>   .. ..$ resp__      : num [1:50] 0.2 0.2 0.2 0.2 0.2 0.4 0.3 0.2 0.2 0.1 ...
#>   .. ..$ cond__      : Factor w/ 1 level "1": 1 1 1 1 1 1 1 1 1 1 ...
#>   .. ..$ effect1__   : num [1:50] 1.4 1.4 1.3 1.5 1.4 1.7 1.4 1.5 1.4 1.5 ...
#>  - attr(*, "class")= chr "brms_conditional_effects"
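The estimate__, lower__ and upper__ columns are summaries of the posterior draws at each prediction row. With the defaults (robust = TRUE, prob = 0.95), brms uses the median as the point estimate and the 2.5% and 97.5% quantiles as the interval. Below is a minimal sketch of that summarising step. It assumes a draws matrix shaped like the one brms::posterior_epred() returns (one row per posterior draw, one column per row of newdata); here `draws` is a hypothetical, simulated stand-in so the snippet runs without the fitted model.

```r
# Hypothetical stand-in for posterior_epred(m, newdata = ce$Species):
# 4000 simulated draws (rows) for each of three prediction rows (columns)
set.seed(1)
draws = cbind(rnorm(4000, 0.79, 0.08),
              rnorm(4000, 1.21, 0.03),
              rnorm(4000, 1.60, 0.07))

# Summarise each column: median plus a 95% quantile interval,
# mirroring the default robust = TRUE, prob = 0.95 summaries
summ = data.frame(
  estimate__ = apply(draws, 2, median),
  lower__    = apply(draws, 2, quantile, probs = 0.025),
  upper__    = apply(draws, 2, quantile, probs = 0.975)
)
summ
```

A violin plot uses the same draws but shows their whole distribution instead of collapsing each column to three numbers.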

The data frame for Species sets Petal.Length to its mean value when making predictions for Species.

ce$Species
#>      Species Petal.Width Petal.Length cond__  effect1__ estimate__       se__
#> 1     setosa    1.199333        3.758      1     setosa  0.7857913 0.08132856
#> 2 versicolor    1.199333        3.758      1 versicolor  1.2072069 0.03054717
#> 3  virginica    1.199333        3.758      1  virginica  1.6043561 0.06666792
#>     lower__   upper__
#> 1 0.6296771 0.9507451
#> 2 1.1466082 1.2691234
#> 3 1.4705383 1.7289340
mean(iris$Petal.Length)  
#> [1] 3.758
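This follows the default that brms documents for the `conditions` argument: when it is left as NULL, numeric variables are held at their means and factors at their first level. The same rule explains why Species is constant at setosa in ce$Petal.Length in the str() output above. The sketch below mimics that rule in base R; `default_conditions()` is a hypothetical helper written for illustration, not a brms function. To condition on other values, conditional_effects() accepts a `conditions` data frame instead.

```r
# Hypothetical helper mimicking brms's documented default for `conditions`:
# numeric variables are held at their mean, factors at their first level
default_conditions = function(data, vars) {
  out = lapply(data[vars], function(x) {
    if (is.factor(x)) factor(levels(x)[1], levels = levels(x)) else mean(x)
  })
  as.data.frame(out)
}

cond = default_conditions(iris, c("Petal.Length", "Species"))
cond
```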

So, let’s set up a prediction data frame using the same values as conditional_effects:

pred.dat = tibble(Species = unique(iris$Species), 
                  Petal.Length=mean(iris$Petal.Length))

# Get posterior draws at the values in pred.dat
post.epred = epred_draws(m, newdata=pred.dat)

# Generate conditional effects plots. This is a list of two plots, one for Species and one for Petal.Length
pce = plot(ce, ask=FALSE, plot=FALSE)
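If you want violins at more than one value of Petal.Length, one option is to cross every Species level with several Petal.Length values before calling epred_draws(). Here is a minimal sketch using tidyr::expand_grid(); the name pred.grid and the choice of quartiles are illustrative only. Keep in mind that Petal.Length barely overlaps between species in iris, so some of these combinations (e.g. setosa at the upper quartile) are extrapolations.

```r
library(tidyr)

# Hypothetical alternative to pred.dat: every Species level crossed with
# the quartiles of Petal.Length instead of the single mean value
pred.grid = expand_grid(
  Species      = unique(iris$Species),
  Petal.Length = unname(quantile(iris$Petal.Length, c(0.25, 0.5, 0.75)))
)
pred.grid

# With the fitted model m from above, the draws would then come from:
# post.grid = epred_draws(m, newdata = pred.grid)
```

Because epred_draws() keeps the newdata columns, adding facet_wrap(~ Petal.Length) to the violin plot would give one panel per Petal.Length value.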

Let's start with the conditional_effects plot for Species and overlay a violin plot of the posterior draws in the post.epred data frame we just created. The red violins line up with the conditional_effects estimates and intervals, which shows we're replicating what conditional_effects is doing:

pce[["Species"]] + 
  geom_violin(data=post.epred, 
              inherit.aes=FALSE, fill=NA, colour="red",
              aes(Species, .epred))

Now let’s make our own stand-alone violin plot for each level of Species:

post.epred %>% 
  ggplot(aes(Species, .epred)) +
  geom_violin()
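Plain violins don't mark an interval. One way to add one is to layer stat_summary() on top, drawing the posterior median and a 95% quantile interval for each level. To keep the snippet self-contained it uses simulated draws (`sim.epred`, hypothetical) with the same Species and .epred columns that epred_draws() produces; with the real model, use post.epred in place of sim.epred.

```r
library(ggplot2)

# Hypothetical stand-in for post.epred: simulated draws with the
# same column names (Species, .epred) that epred_draws() returns
set.seed(2)
sim.epred = data.frame(
  Species = rep(c("setosa", "versicolor", "virginica"), each = 4000),
  .epred  = c(rnorm(4000, 0.79, 0.08),
              rnorm(4000, 1.21, 0.03),
              rnorm(4000, 1.60, 0.07))
)

# Violins of the full posterior, with median and 95% interval on top
p = ggplot(sim.epred, aes(Species, .epred)) +
  geom_violin(fill = "grey85") +
  stat_summary(fun = median,
               fun.min = function(x) quantile(x, 0.025),
               fun.max = function(x) quantile(x, 0.975),
               geom = "pointrange")
p
```

If you want something closer to the 'area' style of mcmc_plot, the ggdist package (which tidybayes builds on) provides stat_halfeye() and stat_eye(), which draw a density and interval bars in a single layer.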


Great! Thank you so much. Exactly what I wished for. You are my hero today :)
