Priors and the reference level in multinomial models

Thanks @paul.buerkner. I think I have figured out most of what was confusing me.

My understanding is this: in a multinomial model with X discrete outcomes and a fixed number of trials (N), the prior predictive counts in every draw have to add up to N. If my prior on the intercepts is tight around zero, the trials are spread fairly evenly between the outcomes and each outcome's count is centred on roughly N/X. If the prior has a larger variance, so that some very large intercepts are plausible, a single outcome tends to take most of the trials in any given draw. The typical (median) count for each outcome then falls below N/X, and its prior predictive interval has a long tail towards higher values.
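This can be checked by simulating the prior predictive distribution directly. The sketch below assumes an intercept-only model with `refcat = NA` and the same `normal(0, prior_sd)` prior on every category's intercept, and uses base R's `rmultinom()`; `softmax()` and `simulate_prior_counts()` are hypothetical helpers written for this example, not brms functions. It compares the mean and median count of the first outcome under a narrow and a wide prior.

```r
# Prior predictive simulation for an intercept-only multinomial model in which
# every category has its own intercept (as with refcat = NA), each given a
# normal(0, prior_sd) prior. softmax() and simulate_prior_counts() are
# hypothetical helpers written for this example.
softmax <- function(eta) {
  z <- exp(eta - max(eta))  # subtract the max for numerical stability
  z / sum(z)
}

simulate_prior_counts <- function(n_sims, n_trials, n_cats, prior_sd) {
  sims <- replicate(n_sims, {
    eta <- rnorm(n_cats, mean = 0, sd = prior_sd)            # intercepts from the prior
    rmultinom(1, size = n_trials, prob = softmax(eta))[, 1]  # counts given the intercepts
  })
  t(sims)  # one row per simulated data set, one column per category
}

set.seed(123)
narrow <- simulate_prior_counts(4000, n_trials = 150, n_cats = 3, prior_sd = 0.1)
wide   <- simulate_prior_counts(4000, n_trials = 150, n_cats = 3, prior_sd = 5)

# Mean and median count for the first outcome under each prior (N/X = 50 here)
first_outcome <- rbind(
  narrow = c(mean = mean(narrow[, 1]), median = median(narrow[, 1])),
  wide   = c(mean = mean(wide[, 1]),   median = median(wide[, 1]))
)
print(first_outcome)
```

With the wide prior most simulated data sets give the first outcome only a small share of the trials, which pulls its median well below N/X, while a minority give it nearly all of them, which produces the long right tail. Because the prior is identical for every category in this sketch, the mean count stays close to N/X by symmetry.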

I have a simpler follow-up:
If I fit a multinomial model with `refcat = NA`, what is the appropriate transformation to get from the intercepts in the posterior draws to the expected probabilities or expected counts?
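For reference, the multinomial family in brms uses a (multivariate) logit link, i.e. the softmax. Writing $b_k$ for the intercept of category $k$, $K$ for the number of categories and $n_i$ for the number of trials (`size`) in observation $i$, an intercept-only model implies the following. With a reference category, that category's $b_k$ is simply fixed at 0; with `refcat = NA` all $K$ intercepts are estimated.

$$
p_k = \frac{\exp(b_k)}{\sum_{j=1}^{K} \exp(b_j)}, \qquad
\mathbb{E}[y_{ik}] = n_i \, p_k .
$$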

Here’s an example based on your prior example:

```r
library(brms)

N <- 15
dat <- data.frame(
  y1 = rbinom(N, 50, 0.9), y2 = rbinom(N, 150, 0.5),
  y3 = rbinom(N, 150, 0.2), x = rnorm(N)
)
dat$size <- with(dat, y1 + y2 + y3)
dat$y <- with(dat, cbind(y1, y2, y3))
prior <- prior(normal(0, 1), class = "Intercept")
# refcat = NA: no reference category, every outcome gets its own intercept
fit <- brm(bf(y | trials(size) ~ 1), data = dat,
           family = multinomial(refcat = NA), prior = prior)
```

If we look at the draws with `tidy_draws()` from tidybayes:

```
tidy_draws(fit) %>% select(-contains("__"))
# A tibble: 4,000 x 6
   .chain .iteration .draw b_muy1_Intercept b_muy2_Intercept b_muy3_Intercept
    <int>      <int> <int>            <dbl>            <dbl>            <dbl>
 1      1          1     1          0.567              1.08            0.0644
 2      1          2     2          0.538              1.10            0.0724
 3      1          3     3          0.189              0.697          -0.284 
 4      1          4     4          0.231              0.800          -0.251 
 5      1          5     5         -0.240              0.301          -0.618 
 6      1          6     6         -0.206              0.292          -0.637 
 7      1          7     7         -0.0975             0.383          -0.614 
 8      1          8     8         -0.0840             0.471          -0.590 
 9      1          9     9         -0.00945            0.526          -0.465 
10      1         10    10          0.176              0.695          -0.266 
# … with 3,990 more rows
```
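With `refcat = NA` all three intercepts are free parameters, and adding the same constant to each of them leaves the category probabilities unchanged, so along that direction the model is pinned down only by the prior. That is why whole draws can drift up or down together (compare draws 1 and 5 above) without changing the implied probabilities much. A base R sketch that applies a row-wise softmax to the ten draws printed above, copied by hand; `softmax_rows()` is a hypothetical helper, not a brms or tidybayes function:

```r
# Row-wise softmax: each row holds one posterior draw of the three intercepts
softmax_rows <- function(eta) {
  eta <- eta - apply(eta, 1, max)  # subtract each row's max for numerical stability
  z <- exp(eta)
  z / rowSums(z)                   # normalise each draw so its probabilities sum to 1
}

# The ten draws of b_muy1_Intercept, b_muy2_Intercept, b_muy3_Intercept shown above
intercepts <- matrix(c(
   0.567,    1.08,   0.0644,
   0.538,    1.10,   0.0724,
   0.189,    0.697, -0.284,
   0.231,    0.800, -0.251,
  -0.240,    0.301, -0.618,
  -0.206,    0.292, -0.637,
  -0.0975,   0.383, -0.614,
  -0.0840,   0.471, -0.590,
  -0.00945,  0.526, -0.465,
   0.176,    0.695, -0.266
), ncol = 3, byrow = TRUE, dimnames = list(NULL, c("y1", "y2", "y3")))

probs <- softmax_rows(intercepts)
round(probs, 3)

# Draws 1 and 5: the intercepts differ by about 0.7 to 0.8, the probabilities far less
intercepts[1, ] - intercepts[5, ]
probs[1, ] - probs[5, ]

# Adding the same constant to every intercept leaves the probabilities unchanged
all.equal(probs, softmax_rows(intercepts + 2))
```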

I want to double-check that I understand how to get from those intercept values to the values returned by `fitted()`:

```
fitted(fit)
, , P(Y = y1)

      Estimate Est.Error     Q2.5    Q97.5
 [1,] 42.07248  1.337674 39.50588 44.75072
 [2,] 48.68386  1.547880 45.71394 51.78297
 [3,] 43.27455  1.375893 40.63462 46.02931
 [4,] 45.37817  1.442777 42.60991 48.26684
 [5,] 39.36782  1.251680 36.96621 41.87388
 [6,] 43.27455  1.375893 40.63462 46.02931
 [7,] 42.67351  1.356783 40.07025 45.39001
 [8,] 48.68386  1.547880 45.71394 51.78297
 [9,] 47.18128  1.500106 44.30302 50.18473
[10,] 48.98438  1.557434 45.99613 52.10262
[11,] 45.07765  1.433222 42.32772 47.94720
[12,] 44.77713  1.423667 42.04554 47.62755
[13,] 44.77713  1.423667 42.04554 47.62755
[14,] 47.78231  1.519215 44.86739 50.82403
[15,] 49.58542  1.576544 46.56050 52.74192
```
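Since this is an intercept-only model, every row of the `fitted()` output should be the number of trials times one and the same probability. A quick base R check, using the `Estimate` column above and the `size` values listed in the tibble further down (both copied by hand):

```r
# Estimate column of fitted(fit) for P(Y = y1), copied from the output above
estimate_y1 <- c(42.07248, 48.68386, 43.27455, 45.37817, 39.36782,
                 43.27455, 42.67351, 48.68386, 47.18128, 48.98438,
                 45.07765, 44.77713, 44.77713, 47.78231, 49.58542)

# Number of trials per row (dat$size), copied from the tibble later in the post
size <- c(140, 162, 144, 151, 131, 144, 142, 162, 157, 163,
          150, 149, 149, 159, 165)

# In an intercept-only model this ratio is the posterior mean of P(Y = y1)
# and should be the same for every row
p_y1 <- estimate_y1 / size
range(p_y1)
```

The ratio comes out at about 0.3005 for every row, so the per-row differences in `fitted()` come entirely from `size`.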

Based on this, I think the answer is something like:

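```r
library(dplyr)
library(tidybayes)
# One row per posterior draw, so no grouping is needed. After this step the
# b_muy*_Intercept columns hold the softmax probabilities P(Y = y1), P(Y = y2), P(Y = y3).
draws_temp <- tidy_draws(fit) %>% select(-contains("__")) %>%
  mutate(ExpIntercepts = exp(b_muy1_Intercept) + exp(b_muy2_Intercept) + exp(b_muy3_Intercept)) %>%
  mutate(b_muy1_Intercept = exp(b_muy1_Intercept) / ExpIntercepts,
         b_muy2_Intercept = exp(b_muy2_Intercept) / ExpIntercepts,
         b_muy3_Intercept = exp(b_muy3_Intercept) / ExpIntercepts)
```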

This should give P(Y = y1), P(Y = y2) and P(Y = y3) for each draw. Multiplying them by the number of trials gives the expected counts that `fitted()` reports, which can then be averaged over draws:

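```r
# Posterior mean of size * P(Y = yk) for every observation; the column names mirror
# fitted(), but the values are expected counts, not probabilities
dat %>% select(size) %>%
  rowwise() %>%
  mutate("P(Y=y1)" = mean(size * draws_temp$b_muy1_Intercept),
         "P(Y=y2)" = mean(size * draws_temp$b_muy2_Intercept),
         "P(Y=y3)" = mean(size * draws_temp$b_muy3_Intercept))
```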

```
# A tibble: 15 x 4
    size `P(Y=y1)` `P(Y=y2)` `P(Y=y3)`
   <int>     <dbl>     <dbl>     <dbl>
 1   140      42.1      71.3      26.6
 2   162      48.7      82.5      30.8
 3   144      43.3      73.3      27.4
 4   151      45.4      76.9      28.7
 5   131      39.4      66.7      24.9
 6   144      43.3      73.3      27.4
 7   142      42.7      72.3      27.0
 8   162      48.7      82.5      30.8
 9   157      47.2      80.0      29.9
10   163      49.0      83.0      31.0
11   150      45.1      76.4      28.5
12   149      44.8      75.9      28.3
13   149      44.8      75.9      28.3
14   159      47.8      81.0      30.2
15   165      49.6      84.0      31.4
```
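The tibble reproduces the `Estimate` column of `fitted()` (42.1 versus 42.07248 for the first row). `fitted()` also reports the posterior standard deviation (`Est.Error`) and the 2.5% and 97.5% quantiles, which come from the same per-draw expected counts. A base R sketch; `summarise_counts()` is a hypothetical helper, and the probability draws are toy values from `rbeta()`, not output from the fit above:

```r
# Summarise the expected counts size * p over posterior draws, mirroring the
# Estimate, Est.Error, Q2.5 and Q97.5 columns of fitted()
summarise_counts <- function(p_draws, size) {
  t(sapply(size, function(n) {
    counts <- n * p_draws  # expected count for this observation, one value per draw
    c(Estimate  = mean(counts),
      Est.Error = sd(counts),
      Q2.5      = unname(quantile(counts, 0.025)),
      Q97.5     = unname(quantile(counts, 0.975)))
  }))
}

# Toy stand-in for the per-draw P(Y = y1); with the real fit this would be the
# transformed b_muy1_Intercept column of draws_temp
set.seed(1)
p_y1 <- rbeta(4000, 30, 70)

res <- summarise_counts(p_y1, size = c(140, 162, 144))
round(res, 2)
```

Because every observation shares the same probabilities in an intercept-only model, all four columns scale with `size`, which matches the constant ratios in the `fitted()` output.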

Thanks, and thanks for brms, it’s amazing.