Shrinkage of random intercepts - how much is too much?

I have a relatively simple model that is behaving in ways I can't quite understand. I have a treatment with ~100 levels, each with around 35 observations, plus a control category with many more observations - about 1500. I'm interested in the variability across these treatment levels. My initial idea was to fit a random-intercept-only model, but I've not been able to tame the shrinkage of the random intercepts - they're all pulled in close to the population mean. I have tried relaxing the prior, but it doesn't have much of an effect.

If I respecify the model so that all these treatment types are fixed effects, then I get estimates that align with the observed data. But I think it's worthwhile to have some partial pooling here.

Is there something I can do to get more sensible random effect estimates from the multilevel model? Some shrinkage is good I think, but I’d also like more variability than I have been able to get so far. Thanks!


library(tidyverse)
library(brms)
library(modelr)     # data_grid()
library(tidybayes)  # add_epred_draws(), median_qi()

set.seed(42)

# Simulate data
generate_data <- function(n, prob) {
  tibble(y = rbinom(n, 1, prob))
}

dat <- tibble(treat_id = as.character(1:100), # character so m2 below fits one coefficient per level
              n = floor(rnorm(100, 35, 5)),   # ~35 observations per treatment
              prop = rnorm(100, .4, .05)) %>% # true probabilities centred on 0.4
  mutate(y = map2(n, prop, generate_data)) %>%
  unnest(cols = c(y)) %>%
  select(treat_id, y) %>%
  bind_rows(tibble(treat_id = "0", y = rbinom(1500, 1, prob = .4))) # adding in the controls

# random intercepts
m1 <- brm(y ~ 1 + (1|treat_id),
          family = bernoulli,
          data = dat)

# with a wider prior on the random-intercept SD
m_weakprior <- brm(y ~ 1 + (1|treat_id),
                   family = bernoulli,
                   data = dat,
                   prior = set_prior('student_t(3, 0, 10)',
                                     class = 'sd',
                                     group = 'treat_id'))

# fixed effects: one dummy per treatment level
m2 <- brm(y ~ treat_id,
          family = bernoulli,
          data = dat)

#plots
treats <- as.character(sample(1:100, 12)) # a dozen treatments to plot

dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m1) -> post

dat %>% 
  filter(treat_id %in% treats) %>%
  group_by(treat_id) %>%
  summarize(est = mean(y)) -> obs_dat

ggplot(post, aes(x = .epred)) + # post is already filtered to the sampled treatments
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id) -> fig1

dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m_weakprior) -> post

ggplot(post, aes(x = .epred)) +
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id) -> fig2

dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m2) -> post

ggplot(post, aes(x = .epred)) +
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id) -> fig3



As far as I can tell, your model is behaving pretty much as one would expect. On the logit scale, the population mean and (approximate, since the rnorm call below re-draws fresh values) standard deviation of the random intercepts implied by your simulation are:

> qlogis(0.4)
[1] -0.4054651
> sd(qlogis(rnorm(100, .4, .05)))
[1] 0.230597
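
Also worth noting: the raw per-treatment proportions scatter quite a bit more than the true probabilities, because binomial noise at n ≈ 35 adds roughly sqrt(.4 * .6 / 35) ≈ 0.09 of spread on top of the true 0.05. A quick check against your simulated data (a rough sketch, assuming dat from your code above is in scope):

dat %>%
  filter(treat_id != "0") %>%     # drop the control group
  group_by(treat_id) %>%
  summarize(p_hat = mean(y)) %>%  # raw observed proportion per treatment
  summarize(sd_obs = sd(p_hat))
# expect roughly sqrt(.05^2 + .4 * .6 / 35) ≈ 0.097, nearly double the true 0.05

So much of the variability you see in the raw means is sampling noise that the model is right to shrink away.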

Here are the model estimates from running your code with the seed you provided:

> m1
 Family: bernoulli 
  Links: mu = logit 
Formula: y ~ 1 + (1 | treat_id) 
   Data: dat (Number of observations: 4954) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~treat_id (Number of levels: 101) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.23      0.06     0.09     0.35 1.00     1297     1137

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.47      0.04    -0.56    -0.39 1.00     4192     3150

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Those estimates look just fine to me.
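
If you want to see the per-level shrinkage directly, one way (a quick sketch, assuming the objects from your code are in scope) is to pull out the per-level intercepts with coef() and set them next to the raw proportions:

# posterior means of the per-level intercepts (population + group effects)
shrunk <- coef(m1)$treat_id[, "Estimate", "Intercept"]

dat %>%
  group_by(treat_id) %>%
  summarize(p_obs = mean(y)) %>%
  mutate(p_shrunk = plogis(shrunk[treat_id])) # match levels by name

The shrunken estimates should sit noticeably closer to 0.4 than the raw proportions do, which is exactly the partial pooling at work.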

Thanks, Dalton. You're right that the model looks good. I guess I'm just surprised that the individual random effect estimates bear such a weak relationship to the observed data.

I dug a little more, and I'm guessing that ~35 observations per treatment just isn't much to learn from.

Here’s a plot of the random effect estimates and observed data from the code above:

And here’s the same plot, but with the model fit to data where there are ~100 observations per treatment.

So I guess that's it? There just isn't much information in ~35 observations, and this framework isn't going to let those estimates stray far from the population mean.
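
That squares with a back-of-the-envelope calculation. Under a normal approximation, partial pooling keeps roughly a fraction tau^2 / (tau^2 + v) of each group's deviation from the mean, where tau is the random-intercept SD and v ≈ 1/(n·p·(1−p)) is the delta-method sampling variance of the group's logit. A rough sketch (this approximation is illustrative, not what brms computes exactly):

# approximate weight a group's own data gets under partial pooling
shrink_weight <- function(n, tau = 0.23, p = 0.4) {
  v <- 1 / (n * p * (1 - p))  # sampling variance of logit(p_hat)
  tau^2 / (tau^2 + v)
}

shrink_weight(35)   # ~0.31: each estimate keeps ~30% of its deviation
shrink_weight(100)  # ~0.56: more data per group, less shrinkage

That's consistent with the two plots: roughly tripling the per-group sample size about doubles how much of each group's own signal survives the pooling.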

Code for the plots above:

dat %>%
  data_grid(treat_id) %>%
  add_epred_draws(m1) %>%
  median_qi() %>%
  left_join(dat %>%
              group_by(treat_id) %>%
              summarize(prop = mean(y), n = n()),
            by = "treat_id") %>%
  ggplot(aes(x = prop, y = .epred)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1) +
  ylim(.2, .6) +
  xlim(.2, .6) +
  ggtitle('~35 obs per treatment level') -> figa

dat2 <- tibble(treat_id = as.character(1:100),
               n = floor(rnorm(100, 100, 5)), # ~100 observations per treatment
               prop = rnorm(100, .4, .05)) %>%
  mutate(y = map2(n, prop, generate_data)) %>%
  unnest(cols = c(y)) %>%
  select(treat_id, y) %>%
  bind_rows(tibble(treat_id = "0", y = rbinom(1500, 1, prob = .4)))

m1b <- brm(y ~ 1 + (1|treat_id),
           family = bernoulli,
           data = dat2)

dat2 %>%
  data_grid(treat_id) %>%
  add_epred_draws(m1b) %>%
  median_qi() %>%
  left_join(dat2 %>%
              group_by(treat_id) %>%
              summarize(prop = mean(y), n = n()),
            by = "treat_id") %>%
  ggplot(aes(x = prop, y = .epred)) +
  geom_point() +
  geom_abline(intercept = 0, slope = 1) +
  ylim(.2, .6) +
  xlim(.2, .6) +
  ggtitle('~100 obs per treatment level') -> figb