I have a relatively simple model that is behaving in ways I can’t quite understand. I have a treatment with about 100 levels, each with around 35 observations, plus a control category with many more observations (about 1500). I’m interested in the variability among these treatment levels. My initial idea was to fit a random-intercept-only model, but I haven’t been able to tame the shrinkage of the random intercepts: they are all shrunk to near the population mean. I have tried relaxing the prior on the group-level standard deviation, but it doesn’t have much of an effect.
If I respecify the model so that all the treatment levels are fixed effects, I get estimates that align with the observed data. But I think some partial pooling is worthwhile here.
Is there something I can do to get more sensible random-effect estimates from the multilevel model? Some shrinkage is good, I think, but I’d also like more variability than I have been able to get so far. Thanks!
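library(tidyverse)  # tibble, dplyr, purrr, tidyr, ggplot2
library(brms)       # brm(), set_prior()
library(tidybayes)  # add_epred_draws()
library(modelr)     # data_grid()

set.seed(42)

# Simulate data: n Bernoulli draws with success probability prob
generate_data <- function(n, prob) {
  tibble(y = rbinom(n, 1, prob))
}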
# 100 treatment levels with ~35 observations each, plus 1500 controls (treat_id = 0)
dat <- tibble(treat_id = 1:100,
              n = floor(rnorm(100, 35, 5)),
              prop = rnorm(100, .4, .05)) %>%
  mutate(y = map2(n, prop, generate_data)) %>%
  unnest(cols = c(y)) %>%
  select(treat_id, y) %>%
  bind_rows(data.frame(treat_id = 0, y = rbinom(1500, 1, prob = .4))) # adding in the controls

# random intercepts
m1 <- brm(y ~ 1 + (1 | treat_id),
          family = bernoulli,
          data = dat)

# with weak priors
m_weakprior <- brm(y ~ 1 + (1 | treat_id),
                   family = bernoulli,
                   data = dat,
                   prior = set_prior('student_t(3, 0, 10)',
                                     class = 'sd',
                                     group = 'treat_id'))
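
# fixed effects: treat_id must be a factor here, otherwise the numeric
# treat_id is fit as a single linear slope rather than one effect per level
m2 <- brm(y ~ treat_id,
          family = bernoulli,
          data = mutate(dat, treat_id = factor(treat_id)))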

# plots: posterior expected values vs. observed proportions for 12 random treatments
treats <- sample(1:100, 12)

obs_dat <- dat %>%
  filter(treat_id %in% treats) %>%
  group_by(treat_id) %>%
  summarize(est = mean(y))

# random intercepts
post <- dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m1)

fig1 <- ggplot(post, aes(x = .epred)) +
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id)

# random intercepts with weak prior
post <- dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m_weakprior)

fig2 <- ggplot(post, aes(x = .epred)) +
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id)

# fixed effects
post <- dat %>%
  data_grid(treat_id) %>%
  filter(treat_id %in% treats) %>%
  add_epred_draws(m2)

fig3 <- ggplot(post, aes(x = .epred)) +
  geom_density() +
  geom_vline(data = obs_dat, aes(xintercept = est)) +
  facet_wrap(~treat_id)
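
As a rough check on how much shrinkage to expect from this simulation, here is a back-of-the-envelope sketch in base R. It is not what brms does internally: it uses a normal approximation on the probability scale and plugs in the true simulation values (0.4, 0.05, about 35 per level) as assumptions. The names `var_within`, `var_between`, `weight_raw`, and `p_pooled` are made up for illustration.

```r
# Rough, non-Bayesian gauge of expected shrinkage on the probability scale.
# All inputs below are assumptions that mirror the simulation above.
set.seed(42)
n_groups <- 100    # number of treatment levels
n_per    <- 35     # approximate observations per treatment level
p_mean   <- 0.4    # population mean proportion
p_sd     <- 0.05   # between-treatment sd used in the simulation

# One observed proportion per treatment level
p_true <- rnorm(n_groups, p_mean, p_sd)
p_obs  <- rbinom(n_groups, n_per, p_true) / n_per

# Sampling variance of one observed proportion vs. between-treatment variance
var_within  <- p_mean * (1 - p_mean) / n_per
var_between <- p_sd^2

# Normal-approximation pooling weight: the share of each raw deviation
# from the population mean that survives partial pooling
weight_raw <- var_between / (var_between + var_within)

# Partially pooled estimates under that approximation
p_pooled <- p_mean + weight_raw * (p_obs - p_mean)

print(round(c(var_within = var_within,
              var_between = var_between,
              weight_raw = weight_raw), 4))
print(round(c(sd_obs = sd(p_obs), sd_pooled = sd(p_pooled)), 3))
```

With these assumed inputs the weight comes out near 0.27, so under this approximation only about a quarter of each raw deviation from the population mean would be kept. If that carries over to the logit-scale model, strong shrinkage may be driven by the data themselves (small between-treatment variance relative to sampling noise at n of about 35) rather than by the prior.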