Adding per-ID points to a conditional_effects() plot of a hierarchical model?

I have a hierarchical model of the form

y ~ x + (1|ID)

Here x is categorical (3 levels), and every ID has ~50 measurements.

I make a conditional plot using

conditional_effects(my.model, effects='x')

And I get three dots with error bars, which is what I want.

Now I’d like to add jittered points for individual IDs. Is that possible? So far I’ve only managed to set `points=TRUE`, but that adds every individual measurement to the plot rather than one summary point per ID.

Many thanks!

The `points=TRUE` option overlays the raw observations, so to show one summary point per ID you'll probably need to work directly with ggplot().

The code below lays out a few options. Let me know if this is what you had in mind.

library(tidyverse)
theme_set(theme_bw())
library(tidybayes)
library(brms)

# Fit a model to work with: a Poisson regression of seizure count on
#  treatment, with a random intercept for each patient, using the `epilepsy`
#  example data. `file` caches the fitted model on disk, so re-running the
#  script reloads it instead of refitting.
m = brm(
  formula = count ~ Trt + (1 | patient),
  data = epilepsy,
  family = poisson(),
  file = "model",
  backend = "cmdstanr",
  cores = 4
)

# conditional_effects() returns a list of data frames, one per effect.
#  re_formula = NA (the default, spelled out here) conditions on the
#  population level only, ignoring the patient random effects;
#  re_formula = NULL includes them, which widens the intervals.
ce = conditional_effects(m, effects = "Trt", re_formula = NA)
ce.re = conditional_effects(m, effects = "Trt", re_formula = NULL)

# Add one jittered point per patient (that patient's mean observed count) to
#  the default conditional_effects plot. The plot already has its estimates
#  and error bars at each Trt level, so shift the points 0.25 to the right
#  to keep them from covering the error bars.
# plot = FALSE returns the list of ggplot objects without first drawing the
#  unmodified default plot.
plot(ce, plot=FALSE)[[1]] +
  geom_point(data=epilepsy %>%
               group_by(patient, Trt) %>%
               summarise(count=mean(count), .groups="drop"),
             aes(as.numeric(Trt) + 0.25, count, colour=patient),
             inherit.aes=FALSE, alpha=0.5,
             position=position_jitter(width=0.07, height=0)) +
  guides(colour="none")


# Inspect the Trt data frames returned by conditional_effects(). The
#  estimate__, lower__ and upper__ columns hold the estimate and credible
#  interval that get plotted; compare the interval widths without (ce) and
#  with (ce.re) the patient random effects.
ce$Trt
#>   Trt    count patient cond__ effect1__ estimate__      se__  lower__  upper__
#> 1   0 8.254237      NA      1         0   5.832204 1.0064498 4.013846 8.414122
#> 2   1 8.254237      NA      1         1   4.409573 0.8038234 3.092909 6.170921

ce.re$Trt
#>   Trt    count patient cond__ effect1__ estimate__     se__   lower__  upper__
#> 1   0 8.254237      NA      1         0   5.198796 4.113585 1.0277338 36.36602
#> 2   1 8.254237      NA      1         1   4.016772 3.354791 0.7854074 29.40605
# Create a plot from scratch using the Trt data frame from
#  conditional_effects(), so we control the order in which the geoms are
#  drawn: one jittered point per patient (mean observed count) first, then
#  the population-level error bars and estimates on top.
ggplot() +
  geom_point(data=epilepsy %>%
               group_by(patient, Trt) %>%
               summarise(count=mean(count), .groups="drop"),
             aes(Trt, count, colour=patient),
             size=1, alpha=0.5,
             position=position_jitter(width=0.05, height=0)) +
  geom_errorbar(data=ce[["Trt"]], aes(x=Trt, ymin=lower__, ymax=upper__),
                width=0.15) +
  geom_point(data=ce[["Trt"]], aes(x=Trt, y=estimate__), size=3) +
  scale_y_continuous(expand=expansion(c(0.01, 0.05))) +
  guides(colour="none") +
  # Zoom in on the y-axis just for illustration (coord_cartesian() zooms
  #  without dropping any data)
  coord_cartesian(ylim=c(0, 20))

# Alternative: build the plot by summarising posterior draws directly,
#  without conditional_effects(). epred_draws() returns draws of the
#  expected response (.epred) for each row of newdata:
#  - draws.no.re: population level only (re_formula=NA); newdata has one
#    row per Trt level
#  - draws.re: including the patient random effects (re_formula=NULL);
#    newdata pairs every patient with both Trt levels
draws.no.re = epred_draws(m,
                          newdata=tibble(Trt=factor(0:1)),
                          re_formula=NA)
draws.re = epred_draws(m,
                       newdata=crossing(Trt=factor(0:1),
                                        patient=unique(epilepsy$patient)),
                       re_formula=NULL)

# Summarise posterior draws of the expected response by treatment: the
#  posterior median (est) and an equal-tailed credible interval (lwr, upr).
#  Defined once and applied to both sets of draws below.
#
# draws: data frame of draws with Trt and .epred columns, as returned by
#        epred_draws() (any existing grouping is replaced by Trt)
# probs: lower and upper quantiles of the credible interval
# Returns an ungrouped data frame with columns Trt, est, lwr, upr.
summarise_epred = function(draws, probs=c(0.025, 0.975)) {
  draws %>%
    group_by(Trt) %>%
    summarise(est=median(.epred),
              lwr=quantile(.epred, probs[1], names=FALSE),
              upr=quantile(.epred, probs[2], names=FALSE),
              .groups="drop")
}

draws.no.re.est = summarise_epred(draws.no.re)
draws.re.est = summarise_epred(draws.re)
# Create the plot using the summarised draws: one jittered point per patient
#  (mean observed count) first, then the population-level error bars and
#  estimates on top.
ggplot() +
  geom_point(data=epilepsy %>%
               group_by(patient, Trt) %>%
               summarise(count=mean(count), .groups="drop"),
             aes(Trt, count, colour=patient),
             size=1, alpha=0.5,
             position=position_jitter(width=0.05, height=0)) +
  geom_errorbar(data=draws.no.re.est, aes(x=Trt, ymin=lwr, ymax=upr),
                width=0.1) +
  geom_point(data=draws.no.re.est, aes(x=Trt, y=est), size=2) +
  # Add the conditional_effects() estimates in red, shifted right, just to
  #  show that our hand-computed summary matches
  geom_pointrange(data=ce[["Trt"]],
                  aes(x=as.numeric(Trt) + 0.2, y=estimate__,
                      ymin=lower__, ymax=upper__),
                  colour="red") +
  scale_y_continuous(expand=expansion(c(0.01, 0.05))) +
  guides(colour="none") +
  # Zoom in on the y-axis just for illustration (coord_cartesian() zooms
  #  without dropping any data)
  coord_cartesian(ylim=c(0, 20))

Created on 2022-11-28 with reprex v2.0.2