Getting rhats for posterior predictions, using tidybayes

I am trying to calculate rhats for posterior predictions from a model fit using brms. Based on this old thread, I am trying to do this using tidybayes.

As a MWE, we can use the data/model in the tidy-brms vignette:

library(tidyverse)  # tibble(), %>%, purrr, tidyr
library(brms)
library(tidybayes)

set.seed(5)
n = 10
n_condition = 5
ABC =
  tibble(
    condition = rep(c("A","B","C","D","E"), n),
    response = rnorm(n * 5, c(0,1,2,1,-1), 0.5)
  )

m = brm(
  response ~ (1|condition), 
  data = ABC, 
  prior = c(
    prior(normal(0, 1), class = Intercept),
    prior(student_t(3, 0, 1), class = sd),
    prior(student_t(3, 0, 1), class = sigma)
  ),
  control = list(adapt_delta = .99)
)

If I call spread_draws() on the resulting model, it returns a tibble with the chain information stored in the .chain column:

> m %>%
+   spread_draws(r_condition[condition, term]) %>%
+   head()
# A tibble: 6 × 6
# Groups:   condition, term [1]
  condition term      r_condition .chain .iteration .draw
  <chr>     <chr>           <dbl>  <int>      <int> <int>
1 A         Intercept       0.683      1          1     1
2 A         Intercept      -0.779      1          2     2
3 A         Intercept      -0.580      1          3     3
4 A         Intercept      -0.717      1          4     4
5 A         Intercept      -0.786      1          5     5
6 A         Intercept      -0.818      1          6     6

However, if I call add_epred_draws(), the .chain and .iteration columns are all NA:

> ABC %>%
+   modelr::data_grid(condition) %>%
+   add_epred_draws(m) %>%
+   head()
# A tibble: 6 × 6
# Groups:   condition, .row [1]
  condition  .row .chain .iteration .draw .epred
  <chr>     <int>  <int>      <int> <int>  <dbl>
1 A             1     NA         NA     1  0.406
2 A             1     NA         NA     2  0.114
3 A             1     NA         NA     3  0.361
4 A             1     NA         NA     4  0.282
5 A             1     NA         NA     5  0.253
6 A             1     NA         NA     6  0.596

Why is it missing? I need the .chain info to get the rhats.

@Solomon

I believe this is normal behavior for add_epred_draws(), but I don’t recall why. We might need to look to @mjskay for the answer.

There are open issues in the tidybayes and brms repos for this.

The discussions of those issues include workarounds.
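
As a rough illustration of what such a workaround can look like (this is not the fix_draws() function from the issue thread, and chain_info_from_draw() is a hypothetical helper name), the chain and iteration indices can be rebuilt from .draw. The sketch assumes the draws are stacked chain by chain and that every chain has the same number of post-warmup iterations; if that does not hold for your fit, the result will be wrong.

```r
# Hypothetical helper (not part of tidybayes or brms): recover chain and
# iteration indices from a running draw index.
# ASSUMPTION: draws are stacked chain by chain, with the same number of
# post-warmup iterations in every chain.
chain_info_from_draw <- function(draw, n_chains, n_iter) {
  stopifnot(all(draw >= 1), all(draw <= n_chains * n_iter))
  data.frame(
    .chain     = (draw - 1L) %/% n_iter + 1L,
    .iteration = (draw - 1L) %% n_iter + 1L,
    .draw      = draw
  )
}

# Toy check: 2 chains of 3 iterations each
res <- chain_info_from_draw(1:6, n_chains = 2, n_iter = 3)
res
```

For the model above, fit with the brm() defaults of 4 chains and 1000 post-warmup iterations per chain, the call would use n_chains = 4 and n_iter = 1000.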

Thanks, the fix_draws() function described in the tidybayes issue thread worked perfectly.

For those interested, here’s how I solved the issue to get rhats for posterior predictions (actually epreds). It uses fix_draws() from the thread linked above.

Note that I’m new to the posterior package, and I think there has to be a more straightforward way of creating the rvar object. If someone could weigh in on what that would be, I’d appreciate it.

preds <- fix_draws(m, modelr::data_grid(ABC, condition),
                   func = tidybayes::epred_draws) %>%
  group_by(condition, .row) %>%
  nest() %>%
  mutate(data = map(data, ~.x %>%
                      dplyr::select(.epred, .chain, .iteration) %>%
                      pivot_wider(names_from = .chain, values_from = .epred) %>%
                      dplyr::select(-.iteration) %>%
                      as.matrix() %>%
                      as.array())) %>%
  mutate(rv   = map(data, posterior::rvar, with_chains = TRUE)) %>%
  mutate(rhat = map_dbl(rv, posterior::rhat))
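
On the question of a more direct route: posterior::rhat() also accepts a plain iterations-by-chains matrix for a single quantity, so the rvar step may not be needed. The sketch below uses simulated draws as a stand-in for one prediction (the data are made up, not output from the model above) and assumes the long-format draws already carry valid .chain and .iteration values.

```r
library(posterior)

set.seed(1)
n_iter   <- 1000
n_chains <- 4

# Simulated stand-in for one prediction's draws in long format
# (hypothetical data, not output from the model above)
draws <- data.frame(
  .chain     = rep(seq_len(n_chains), each = n_iter),
  .iteration = rep(seq_len(n_iter), times = n_chains),
  .epred     = rnorm(n_iter * n_chains)
)

# Order by chain, then iteration, so each matrix column holds one chain
draws <- draws[order(draws$.chain, draws$.iteration), ]
epred_mat <- matrix(draws$.epred, nrow = n_iter, ncol = n_chains)

# rhat() on an iterations x chains matrix returns a single number
rhat_val <- posterior::rhat(epred_mat)
rhat_val
```

The same reshaping could be applied within each condition group of the real predictions, in place of the pivot_wider() and rvar() steps.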

I found nest_rvars(), which simplifies the code above.

preds <- fix_draws(m, modelr::data_grid(ABC, condition),
                   func = tidybayes::epred_draws) %>%
  group_by(condition, .row) %>%
  nest_rvars() %>%
  mutate(rhat = map_dbl(.epred, posterior::rhat))

Thanks for pointing to the relevant issues, @avehtari. For the sake of maintainability in {tidybayes}, my preference is to wait for the {brms} issue to be solved so that there is a canonical way to get that information. I believe this also relates to "New draws_tensor format?" (Issue #349, stan-dev/posterior on GitHub).
