Is it possible to regroup the outcome variable in conditional_effects for a clearer plot?

I have an ordinal model and wish to plot conditional effects for an important variable. The problem is that the outcome variable has 10 levels, and the resulting plot is too difficult to interpret.

To pick a particularly severe example:

plot(conditional_effects(brmsmodel, "outvar", categorical = TRUE))

My question is: is it possible to regroup the outcome variable after the model has been fit, so that the plot is clearer? That is, to have only one curve for each combined outcome range, for example 0 to 3, 4 to 7, and 8 to 10. (I know I could refactor the levels before fitting the model and get around the problem that way, but I'm trying not to do so.)

I think the way to do this is to work with the posterior expected values (the predicted probability of each outcome category), sum them into fewer categories within each posterior draw, and then summarize the results as medians and credible intervals, the same way conditional_effects does. Here's an example using fake data:

library(brms)
library(tidyverse)
library(tidybayes)
library(patchwork)
theme_set(theme_classic())

Generate fake data

Assume grade is a student’s grade in a course. gpa is the student’s grade-point average at the start of the course. We’ll predict course grade from gpa.

grade = c(paste0(rep(LETTERS[c(1:4)], each=3), c("+","","-")), "F")
d = data.frame(grade, num=seq(3.5,1.5,length=length(grade)))
d$grade = ordered(d$grade, levels=rev(unique(d$grade)))

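set.seed(3)
d = d %>%
  group_by(grade) %>%
  # reframe() (dplyr >= 1.1.0) returns 20 rows per grade; summarise() warns
  # that returning more than one row per group is deprecated
  reframe(gpa = rnorm(20, num, 0.8))
d %>% ggplot(aes(gpa, grade)) + 
  geom_point(position=position_jitter(height=0.1, width=0), 
             alpha=0.6, colour="red", size=1)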
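If you'd rather not depend on dplyr's grouped behaviour, the same kind of fake data can be built in base R. This is a sketch: it produces the same structure (13 ordered grade levels with F lowest, 20 simulated students per grade), but I haven't checked that the random values match the dplyr version draw for draw. The object name d_base is just for illustration.

```r
# Base-R sketch of the fake data: 13 grade levels, 20 simulated students each
grade <- c(paste0(rep(LETTERS[1:4], each = 3), c("+", "", "-")), "F")
num <- seq(3.5, 1.5, length.out = length(grade))  # mean gpa for each grade

set.seed(3)
n_per_grade <- 20
d_base <- data.frame(
  # Lowest grade (F) is the first level, highest (A+) the last
  grade = factor(rep(rev(grade), each = n_per_grade),
                 levels = rev(grade), ordered = TRUE),
  gpa   = rnorm(n_per_grade * length(grade),
                mean = rep(rev(num), each = n_per_grade), sd = 0.8)
)
table(d_base$grade)
```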

Ordinal model

m = brm(grade ~ gpa,
        data=d,
        family=cumulative(),
        prior=prior(normal(0,1), class="b"),
        cores=4, chains=4,
        backend="cmdstanr",
        file="grade-ordinal")

Get conditional effect of gpa

ce = conditional_effects(m, "gpa", categorical=TRUE)
# Run str(ce) to see the data frame that conditional_effects creates
plot(ce)

Now let’s recreate the plot above, but with fewer categories. We’ll collapse the categories down to four: A, B, C, or D/F.

First let’s make sure we understand what conditional_effects is doing so that we can implement the same method with our collapsed categories.

Reproduce conditional_effects plot manually

Get the values of gpa that conditional_effects uses for the plot. By default it uses 100 evenly spaced values spanning the observed range of the predictor, so we create the same grid. (Note: If you have multiple independent variables in your model, you'll need a data frame that sets the other variables to specific values. You can just use the data frame created by conditional_effects rather than create your own.)

pred.dat = tibble(gpa=seq(min(d$gpa), max(d$gpa), length=100))
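For the multiple-predictor case mentioned in the note, here is a base-R sketch of such a grid. The data frame d_multi and its extra predictors hours and section are hypothetical (they are not part of this example). The defaults mimicked here, numeric predictors at their mean and factors at their first level, are what I understand conditional_effects uses when conditions is not supplied; check ?conditional_effects for your brms version.

```r
# Hypothetical data: hours and section are illustrative predictors,
# not variables from the grade example
set.seed(1)
d_multi <- data.frame(
  gpa     = runif(50, 0, 4),
  hours   = rpois(50, 10),
  section = factor(sample(c("morning", "evening"), 50, replace = TRUE),
                   levels = c("morning", "evening"))
)

# Vary gpa over its observed range (100 points) and hold the other predictors
# fixed: numeric predictors at their mean, factors at their first level
pred.dat.multi <- data.frame(
  gpa     = seq(min(d_multi$gpa), max(d_multi$gpa), length.out = 100),
  hours   = mean(d_multi$hours),
  section = factor(levels(d_multi$section)[1], levels = levels(d_multi$section))
)
head(pred.dat.multi)
```

A grid like this would then be passed as newdata to epred_draws() in place of pred.dat.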

Get the expected values (the predicted probability of each grade) for every posterior draw

post.epred = epred_draws(m, newdata=pred.dat) %>% 
  select(-.chain, -.iteration)

Summarize the posterior draws (median and 95% credible interval) and plot them

ce2 = post.epred %>%
  group_by(gpa, .category) %>%
  # reframe(): quantile() returns 3 rows per group (summarise() warns in dplyr >= 1.1.0)
  reframe(enframe(quantile(.epred, probs=c(0.025,0.5,0.975)))) %>%
  pivot_wider(names_from=name, values_from=value) %>%
  ggplot(aes(gpa, colour=.category, fill=.category)) +
  # linewidth replaces size for ribbon outlines in ggplot2 >= 3.4.0
  geom_ribbon(aes(ymin=`2.5%`, ymax=`97.5%`), alpha=0.3, linewidth=0) +
  geom_line(aes(y=`50%`)) +
  labs(title="Reproduce conditional_effects",
       fill="grade", colour="grade", y=NULL) +
  guides(colour="none", fill="none") +
  scale_y_continuous(limits=c(0, 0.85))
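The backticked column names in the plotting code come from quantile(), which returns a named vector. enframe() turns those names into a name column, and pivot_wider() then turns them into columns. A quick base-R check with made-up numbers:

```r
# Made-up .epred values for one gpa/category combination
x <- c(0.42, 0.51, 0.47, 0.49, 0.45)

# quantile() returns a named vector; these names become the column names
q <- quantile(x, probs = c(0.025, 0.5, 0.975))
names(q)  # "2.5%" "50%" "97.5%"

# Base-R analogue of enframe(): one row per quantile, with name/value columns
data.frame(name = names(q), value = unname(q))
```

tidybayes's median_qi() produces a similar summary (with .lower and .upper columns) and could stand in for the quantile/pivot step, though I haven't compared its output with the code above.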

Show that we’ve reproduced the conditional_effects plot

plot(ce, plot=FALSE)[[1]] + 
  labs(title="conditional_effects") + scale_y_continuous(limits=c(0,0.85)) +
  ce2 +
  plot_layout(guides="collect")

Collapse outcome categories and summarize them

Here’s the post.epred data frame we generated above using epred_draws:

post.epred
# A tibble: 5,200,000 × 5
# Groups:   gpa, .row, .category [1,300]
      gpa  .row .draw .category .epred
    <dbl> <int> <int> <fct>      <dbl>
 1 0.0210     1     1 F          0.486
 2 0.0210     1     2 F          0.506
 3 0.0210     1     3 F          0.507
 4 0.0210     1     4 F          0.496
 5 0.0210     1     5 F          0.451
 6 0.0210     1     6 F          0.474
 7 0.0210     1     7 F          0.473
 8 0.0210     1     8 F          0.426
 9 0.0210     1     9 F          0.521
10 0.0210     1    10 F          0.504
# … with 5,199,990 more rows

There are 4,000 unique values of .draw (one for each of the 4,000 posterior samples generated when we fit the model), 100 unique values of .row (one for each value of gpa for which we extracted draws), and 13 unique values of .category (one for each possible grade), for a total of 4,000 × 100 × 13 = 5,200,000 rows.
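The row counts can be checked directly. The 4,000 draws follow from brms's defaults (4 chains of 2,000 iterations, the first 1,000 of which are warmup), and the group counts shown in the printed tibbles ([1,300] and, after collapsing, [400,000]) follow the same arithmetic.

```r
# Sizes of the long-format draws objects
n_draws <- 4 * (2000 - 1000)  # brms defaults: 4 chains, 2000 iterations, 1000 warmup
n_gpa   <- 100                # rows of pred.dat
n_cat   <- 13                 # grade levels
n_cat_collapsed <- 4          # A, B, C, D/F

sizes <- c(
  post.epred           = n_draws * n_gpa * n_cat,            # rows before collapsing
  groups_epred         = n_gpa * n_cat,                      # gpa/.row/.category groups
  post.epred.collapsed = n_draws * n_gpa * n_cat_collapsed,  # rows after collapsing
  groups_collapsed     = n_gpa * n_draws                     # gpa/.row/.draw groups
)
sizes
```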

Within a given .row and .draw, there are 13 levels of .category, and .epred is the predicted probability of each category. These probabilities sum to 1.0, because the student is certain to get some grade. You can see this below for .row 1 and .draws 1 and 2: the running total in .epred_cumulative reaches 1 at the last category.

post.epred %>% 
  arrange(.draw, desc(.category)) %>% 
  filter(.row==1, .draw %in% 1:2) %>%
  group_by(.draw) %>% 
  mutate(.epred_cumulative = cumsum(.epred)) %>% 
  print(n=Inf)
# A tibble: 26 × 6
# Groups:   .draw [2]
      gpa  .row .draw .category  .epred .epred_cumulative
    <dbl> <int> <int> <fct>       <dbl>             <dbl>
 1 0.0210     1     1 A+        0.00113           0.00113
 2 0.0210     1     1 A         0.00227           0.00340
 3 0.0210     1     1 A-        0.00149           0.00489
 4 0.0210     1     1 B+        0.00304           0.00793
 5 0.0210     1     1 B         0.00489           0.0128 
 6 0.0210     1     1 B-        0.0103            0.0231 
 7 0.0210     1     1 C+        0.00801           0.0311 
 8 0.0210     1     1 C         0.0199            0.0510 
 9 0.0210     1     1 C-        0.0319            0.0830 
10 0.0210     1     1 D+        0.0787            0.162  
11 0.0210     1     1 D         0.125             0.287  
12 0.0210     1     1 D-        0.227             0.514  
13 0.0210     1     1 F         0.486             1      
14 0.0210     1     2 A+        0.00366           0.00366
15 0.0210     1     2 A         0.00472           0.00838
16 0.0210     1     2 A-        0.00644           0.0148 
17 0.0210     1     2 B+        0.00962           0.0244 
18 0.0210     1     2 B         0.00774           0.0322 
19 0.0210     1     2 B-        0.0101            0.0423 
20 0.0210     1     2 C+        0.0221            0.0644 
21 0.0210     1     2 C         0.0314            0.0958 
22 0.0210     1     2 C-        0.0472            0.143  
23 0.0210     1     2 D+        0.0540            0.197  
24 0.0210     1     2 D         0.105             0.302  
25 0.0210     1     2 D-        0.192             0.494  
26 0.0210     1     2 F         0.506             1
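Why the probabilities sum to 1: as I understand brms's cumulative family (logit link by default), it models P(Y ≤ k) = logistic(τ_k − η), where τ_k are the thresholds and η is the linear predictor, and each category's probability is the difference between adjacent cumulative probabilities, so the differences telescope to 1. The sketch below uses made-up thresholds and slope (tau, beta) for a 5-category outcome. Note that its cumulative sum runs from the lowest category upward, whereas the table above accumulates from A+ down to F.

```r
# Made-up values for illustration; not estimates from the grade model
tau  <- c(-2, -1, 0.5, 2)  # K - 1 = 4 increasing thresholds -> K = 5 categories
beta <- 1.2                # slope on gpa
gpa  <- 0.021
eta  <- beta * gpa         # linear predictor

# Cumulative probabilities P(Y <= k), k = 1..K (the last one is always 1)
cum_prob <- c(plogis(tau - eta), 1)

# Category probabilities P(Y = k) are differences of adjacent cumulative probabilities
cat_prob <- diff(c(0, cum_prob))

cat_prob
sum(cat_prob)  # 1, because the differences telescope
```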

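To get self-consistent model outputs for 4 categories instead of 13, we need to work at the level of individual .row and .draw combinations and sum the probabilities for each group of categories within each draw. Adding up the medians or interval limits of the separate categories afterwards would not work, because quantiles are not additive. For example, let's collapse the 13 grades down to 4 categories: A, B, C, or D/F.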

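post.epred.collapsed = post.epred %>% 
  # Drop the +/- modifiers (A+, A, A- become A), then merge D and F into "D/F"
  mutate(category.collapsed = gsub("\\+|-", "", .category),
         category.collapsed = gsub("(D|F).*", "D/F", category.collapsed)) %>%  
  # Within each gpa value and draw, add up the probabilities of the merged grades
  group_by(gpa, .row, .draw, category.collapsed) %>% 
  summarise(.epred = sum(.epred))

post.epred.collapsed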
# A tibble: 1,600,000 × 5
# Groups:   gpa, .row, .draw [400,000]
      gpa  .row .draw category.collapsed  .epred
    <dbl> <int> <int> <chr>                <dbl>
 1 0.0210     1     1 A                  0.00489
 2 0.0210     1     1 B                  0.0182 
 3 0.0210     1     1 C                  0.0598 
 4 0.0210     1     1 D/F                0.917  
 5 0.0210     1     2 A                  0.0148 
 6 0.0210     1     2 B                  0.0275 
 7 0.0210     1     2 C                  0.101  
 8 0.0210     1     2 D/F                0.857  
 9 0.0210     1     3 A                  0.00877
10 0.0210     1     3 B                  0.0245 
# … with 1,599,990 more rows
# ℹ Use `print(n = ...)` to see more rows
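As a check on the collapsing step, here are the 13 .epred values for .row 1, .draw 1 copied from the earlier output, collapsed with the same two gsub() calls. The sums agree with the first four rows of the collapsed tibble, up to the rounding of the printed values.

```r
# .epred for .row 1, .draw 1, copied from the printed output above
epred <- c("A+" = 0.00113, "A" = 0.00227, "A-" = 0.00149,
           "B+" = 0.00304, "B" = 0.00489, "B-" = 0.0103,
           "C+" = 0.00801, "C" = 0.0199,  "C-" = 0.0319,
           "D+" = 0.0787,  "D" = 0.125,   "D-" = 0.227,  "F" = 0.486)

# Same mapping as the gsub() calls: drop +/-, then merge D and F
collapsed <- gsub("\\+|-", "", names(epred))
collapsed <- gsub("(D|F).*", "D/F", collapsed)

# Sum the probabilities within each collapsed category
vals <- tapply(epred, collapsed, sum)
round(vals, 4)
```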

Now we can do exactly what we did above to reproduce the conditional_effects plot, but here we calculate the median and 95% credible intervals by category.collapsed instead of .category. With four categories, the plot is much easier to read.

post.epred.collapsed %>%
  group_by(gpa, category.collapsed) %>%
  # reframe(): quantile() returns 3 rows per group
  reframe(enframe(quantile(.epred, probs=c(0.025,0.5,0.975)))) %>%
  pivot_wider(names_from=name, values_from=value) %>%
  ggplot(aes(gpa, colour=category.collapsed, fill=category.collapsed)) +
  geom_ribbon(aes(ymin=`2.5%`, ymax=`97.5%`), alpha=0.3, linewidth=0) +
  geom_line(aes(y=`50%`)) +
  labs(title="Conditional effects for collapsed grade categories",
       fill="grade", colour="grade", y=NULL)
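The same approach should carry over to the numeric 0 to 10 outcome in the question, where cut() is a convenient way to define the 0 to 3, 4 to 7, and 8 to 10 groups. The sketch below stands in for a single gpa value of epred_draws output using simulated probabilities (normalized gamma draws, i.e. a flat Dirichlet), not output from a fitted model; the names p, grp, and p_collapsed are illustrative. It also shows why we collapse within each draw before summarizing: adding up per-level medians does not give the median of the combined probability.

```r
set.seed(42)
n_draws <- 4000
outcome_levels <- 0:10

# Simulated category probabilities: each row is one posterior draw and sums to 1
# (normalized gamma draws, i.e. a flat Dirichlet); a stand-in for .epred values
g <- matrix(rgamma(n_draws * length(outcome_levels), shape = 1), nrow = n_draws)
p <- g / rowSums(g)
colnames(p) <- outcome_levels

# Map the 11 levels to the ranges from the question: 0-3, 4-7, 8-10
grp <- cut(outcome_levels, breaks = c(-Inf, 3, 7, Inf),
           labels = c("0-3", "4-7", "8-10"))

# Collapse within each draw: sum the columns that belong to each group
p_collapsed <- t(rowsum(t(p), grp))

# Then summarize across draws: median and 95% interval for each group
apply(p_collapsed, 2, quantile, probs = c(0.025, 0.5, 0.975))

# Summarizing first and adding afterwards gives a different (wrong) answer
sum_of_medians <- sum(apply(p[, 1:4], 2, median))  # levels 0-3 summarized, then added
median_of_sum  <- median(p_collapsed[, "0-3"])     # levels 0-3 added per draw, then summarized
c(sum_of_medians = sum_of_medians, median_of_sum = median_of_sum)
```

In the real model output, the same per-draw sum would be done within every gpa/.row combination, which is what the group_by(gpa, .row, .draw, category.collapsed) step does.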


Wow @joels, that's a really detailed answer. Thank you very much!