Predictions from a multilevel model, conditional on a single new observation in a subject not in the original data

This is a more concrete follow-up to this question, where @martinmodrak was very helpful in identifying the core of my problem.

I have fitted a varying-effects model to some data with multiple measurements for each individual. I expect the outcome (y) to be proportional to time. Here is some similar, simulated data:

library(brms)
library(tidybayes)
library(ggplot2)
library(dplyr)

d <- expand.grid(time = 4:10, id = factor(1:10))


d$y <- lme4::simulate.formula(~0+time+(0+time|id), newdata = d,
                       newparams = list(beta = c(time = 1), 
                                        theta = c(id.time = 1), 
                                        sigma = 1),
                       family = gaussian,
                       seed = 1)[[1]]

head(d)
#>   time id         y
#> 1    4  1 3.0059659
#> 2    5  1 2.2575742
#> 3    6  1 1.6200366
#> 4    7  1 0.4001234
#> 5    8  1 4.1133004
#> 6    9  1 3.3169821
                       
m1 <- brm(y~0+time+(0+time|id), data = d)
add_epred_draws(d, m1) |>
  summarise(mean_e = mean(.epred)) |> 
  ggplot(aes(time, y, color = id)) + 
  geom_point() +
  geom_line(aes(y = mean_e))

I now want to make predictions for a large number of hypothetical observations. For example, if a new subject has y = 10 at time = 6, what can I expect from a measurement at time = 10?
I want to make these predictions for a grid of relevant combinations of y and time (in my real data I have two fixed effects, so the grid quickly becomes very large).

Here is the solution I have for now. It seems to work, but requires refitting the entire model for each hypothetical observation:

# Generate a grid of hypothetical observations to make predictions for (here only 3)
fake_data_for_predictions <- 
  expand.grid(time = 6, y = c(5,10,15)) %>% 
  mutate(id = row_number() + 100,
         id = factor(id))

# make a list of new datasets, each with the original data and one fake observation
d_list_w_one_fake <- lapply(split(fake_data_for_predictions, 
                             fake_data_for_predictions$id),
                 rbind, d)

head(d_list_w_one_fake[[1]])
#>   time         y  id
#> 1    6 5.0000000 101
#> 2    4 3.0059659   1
#> 3    5 2.2575742   1
#> 4    6 1.6200366   1
#> 5    7 0.4001234   1
#> 6    8 4.1133004   1

# Fit a model to each nearly identical dataset
m_mult <- brm_multiple(y~0+time+(0+time|id), data = d_list_w_one_fake,
                       combine = FALSE)

# Generate predictions at time = 10 for the fake subject
gen_prediction_at_10 <- function(mod) {
  fake_obs <- mod$data[1,] # the first row is the fake
  
  fake_obs %>% 
    mutate(
      time_obs = time,
      y_obs = y,
      time = 10) %>% 
    add_predicted_draws(mod, value = "y_predicted")
}

predictions_at_10 <- purrr::map_df(m_mult, gen_prediction_at_10)

ggplot(predictions_at_10, aes(y_obs, y_predicted)) +
  stat_eye() +
  labs(title = "Predictions of y at time = 10, conditional on an observation of y at time = 6")

Created on 2022-03-23 by the reprex package (v2.0.0)

Is there a faster way to do this? This vignette (Approximate leave-future-out cross-validation for Bayesian time series models • loo) seems to describe an approach for validating this type of prediction, but it does not seem to help with generating a large number of predictions.

In the end, I want to make a visualization showing, for each observation time, the level of y that corresponds to a given probability (e.g. 80%) of y > 20 at time = 10.


One alternative might be to fit a model for only the single new individual, with informative priors fixed to the hierarchical prior from the original model. This fit would then need to be repeated across many draws from the original posterior (to propagate the full uncertainty in the hierarchical prior), but each fit should be so cheap that this might be feasible. In general, repeated draw-wise fits like this are a way to implement “cuts” in Stan, so that information downstream of the cut (in this case the new observed value) doesn’t propagate back to inform estimates upstream of the cut (in this case the fitted hierarchical parameters).
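
For the Gaussian model in the question, the draw-wise fit described above may not even need a sampler: conditional on one posterior draw of the population slope, the group-level standard deviation and the residual standard deviation, the new subject's slope deviation has a closed-form normal posterior. The following is a minimal sketch under stated assumptions, not something proposed in the thread. `predict_new_subject` is a hypothetical helper, the column names `b_time`, `sd_id__time` and `sigma` are assumed to match what `as_draws_df(m1)` returns, and the `draws` data frame is made-up stand-in data so that the block runs without fitting a model. As with any cut, the new observation does not update the hierarchical parameters, so results can differ slightly from the full refit.

```r
# Hypothetical helper: one posterior predictive draw of y at `time_new` per
# posterior draw, for a new subject with a single observation (time_obs, y_obs).
# Assumes the model y ~ 0 + time + (0 + time | id) with gaussian family.
predict_new_subject <- function(draws, time_obs, y_obs, time_new) {
  beta  <- draws$b_time       # population slope
  tau   <- draws$sd_id__time  # sd of subject-level slope deviations
  sigma <- draws$sigma        # residual sd

  # Conjugate normal update for the new subject's slope deviation u:
  # prior u ~ N(0, tau^2), likelihood y_obs ~ N((beta + u) * time_obs, sigma^2)
  post_prec <- 1 / tau^2 + time_obs^2 / sigma^2
  post_mean <- (time_obs * (y_obs - beta * time_obs) / sigma^2) / post_prec
  u <- rnorm(length(beta), post_mean, sqrt(1 / post_prec))

  # Predictive draw at the new time, including residual noise
  rnorm(length(beta), (beta + u) * time_new, sigma)
}

# Stand-in posterior draws (assumption: invented values for illustration).
# With a fitted model, something like as_draws_df(m1) would be used instead.
set.seed(1)
n_draws <- 4000
draws <- data.frame(
  b_time      = rnorm(n_draws, 1, 0.3),
  sd_id__time = abs(rnorm(n_draws, 1, 0.2)),
  sigma       = abs(rnorm(n_draws, 1, 0.1))
)

# Probability of y > 20 at time = 10 over a grid of hypothetical observations
grid <- expand.grid(time_obs = 4:9, y_obs = seq(0, 30, by = 5))
grid$p_above_20 <- mapply(
  function(t, y) mean(predict_new_subject(draws, t, y, time_new = 10) > 20),
  grid$time_obs, grid$y_obs
)
head(grid)
```

Because each grid cell only costs a few vectorised operations, a fine grid is cheap, and the y level at which `p_above_20` crosses 0.8 can be read off for each `time_obs`. For non-Gaussian families or more than one varying effect the update is no longer closed-form, and the draw-wise refit described above would be needed instead.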