Hierarchical Dirichlet Toy Model

Dear Stan community,
I’d like to invite you to design a model for a toy problem. Imagine you have N hypothetical 3-sided dice, each with faces 1, 2 and 3 and an unknown vector of weights [w_1, w_2, w_3]. You want to know the population average of the weights over all these dice; let’s call it [w_{pop1}, w_{pop2}, w_{pop3}]. You start your experiment by rolling each die a certain number of times, but you are not a very rigorous experimenter, so each die ended up with a different number of rolls. You observe that the proportions of 1s, 2s and 3s are quite different from die to die, which should reflect the different weights. Here’s a simple drawing of the problem:

I understand that a Dirichlet distribution should be used in this problem, but I’m not quite sure how to implement it so that it reflects the hierarchy and, most importantly, what kind of prior should be set. Any insights much appreciated!


Morning!

This is outside my area, but I’ll give this a bump. Here is a model, “Adding levels to a hierarchical Dirichlet-Multinomial model”, that’s likely too complex for what you want to do, but it might be possible to strip it down to its bare bones.

Hey there,

I guess there are multiple ways one could go about this.

So it seems the main “problem” is how to incorporate the information from all those rolls without the rolls of just one die overpowering all the other dice (while still propagating uncertainty).

I think it’s easiest just to simulate the experiment. I started from the population parameters and drew each die’s “true” probabilities.

> POP_N <- 1000
> POP_AVGS <- c(0.2, 0.3, 0.5)
> POP_HET <- 5
> dice_pop <- gtools::rdirichlet(POP_N, POP_HET*POP_AVGS)
> round(dice_pop[1:5,], 2)
     [,1] [,2] [,3]
[1,] 0.00 0.39 0.60
[2,] 0.28 0.16 0.56
[3,] 0.03 0.50 0.47
[4,] 0.40 0.17 0.43
[5,] 0.03 0.26 0.71
> round(colMeans(dice_pop), 2)
[1] 0.2 0.3 0.5
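If you’d rather not depend on gtools, a Dirichlet draw can be sketched in base R by normalizing independent gamma variates with shapes alpha = POP_HET * POP_AVGS. The mean of Dirichlet(alpha) is alpha / sum(alpha), which is why the column means above recover POP_AVGS, and POP_HET (the total concentration) controls how far individual dice scatter around that mean. The helper rdirichlet_base is a hypothetical name used only for this sketch, not a function from any package.

```r
# Hypothetical base-R replacement for gtools::rdirichlet:
# draw Gamma(alpha_k, 1) variates and normalise each row to sum to 1.
rdirichlet_base <- function(n, alpha) {
  k <- length(alpha)
  # column k of the matrix holds n draws with shape alpha[k]
  g <- matrix(rgamma(n * k, shape = rep(alpha, each = n)), nrow = n, ncol = k)
  g / rowSums(g)
}

set.seed(1)
POP_AVGS <- c(0.2, 0.3, 0.5)
POP_HET <- 5
alpha <- POP_HET * POP_AVGS            # concentration vector (1, 1.5, 2.5)
draws <- rdirichlet_base(100000, alpha)

# Monte Carlo means should be close to alpha / sum(alpha) = POP_AVGS
round(colMeans(draws), 2)
# Spread around the mean: Var = m * (1 - m) / (POP_HET + 1) for a face with mean m,
# so a larger POP_HET makes the dice more alike
round(apply(draws, 2, var), 3)
```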

Then pick a few of them at random, and randomly draw a number of rolls for each die (I took your numbers and added 1 and 50 to the list).

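SAMPLE_N <- 20
# pick SAMPLE_N dice at random from the simulated population
dice <- dice_pop[sample(1:POP_N, SAMPLE_N), ]
# give each die a randomly chosen number of rolls
rolls <- sample(c(1, 50, 500, 1000, 2500), SAMPLE_N, replace = TRUE)

# roll each die: one multinomial draw of face counts per die
Y <- array(dim = c(SAMPLE_N, 3))
for(i in 1:SAMPLE_N){
  Y[i,] <- rmultinom(1, rolls[i], dice[i,])
}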

So here’s the data we have to work with, from which we want to infer the population averages.

> Y
      [,1] [,2] [,3]
 [1,]   85 1078 1337
 [2,]  325  115 2060
 [3,]   11    8   31
 [4,]    3   24   23
 [5,]    0    1    0
 [6,]    2    9   39
 [7,]  124  758  118
 [8,]   46   49  405
 [9,]  450  287  263
[10,]   42  619  339
[11,]  278   73  149
[12,]    0    0    1
[13,]    9  373  618
[14,]    0    0    1
[15,]   15  297  188
[16,]   16   14   20
[17,]  447  289  264
[18,]    3   32   15
[19,]    1    0    0
[20,]    2   12   36

We can plug this data into a multinomial model and estimate the probabilities for each die. Dice with few rolls will have larger uncertainty in their estimates.
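Because each p_die[d] in the Stan model below has an independent dirichlet(1, 1, 1) prior, and the p_pop statement is a normalized density in p_pop that integrates to 1 whatever the dice are, each die’s marginal posterior is conjugate: Dirichlet(1 + Y[d, ]). A small base-R sketch makes the point about uncertainty concrete, using rows 1, 3 and 5 of the table above (2500, 50 and 1 roll). The helper dirichlet_summary is hypothetical and written only for illustration.

```r
# Hypothetical helper: posterior summary for one die under a Dirichlet(prior)
# prior and multinomial counts y. The posterior is Dirichlet(prior + y).
dirichlet_summary <- function(y, prior = c(1, 1, 1)) {
  alpha <- prior + y
  a0 <- sum(alpha)
  m <- alpha / a0                                       # posterior means
  s <- sqrt(alpha * (a0 - alpha) / (a0^2 * (a0 + 1)))   # posterior sds
  rbind(mean = m, sd = s)
}

# Rows 1, 3 and 5 of Y from the table above
Y_sub <- rbind(c(85, 1078, 1337),   # 2500 rolls
               c(11,    8,   31),   #   50 rolls
               c( 0,    1,    0))   #    1 roll

post <- lapply(seq_len(nrow(Y_sub)), function(d) dirichlet_summary(Y_sub[d, ]))
# Posterior sd of the face-1 probability grows as the number of rolls shrinks
sapply(post, function(s) s["sd", 1])
```

For die 1 this gives a face-1 mean of 86/2503 ≈ 0.0344 and an sd of about 0.00364, matching the p_die[1,1] row of the MCMC summary further down, so the closed form is a handy sanity check on the fit.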

library(cmdstanr)

mod <- cmdstan_model(
  write_stan_file("
data{
  int<lower=1> dice;
  array[dice, 3] int<lower=0> Y;
}
parameters{
  simplex[3] p_pop;
  array[dice] simplex[3] p_die;
}
model{
  // concentration for p_pop: a prior of 1 per face plus every die's simplex
  vector[3] y = rep_vector(1, 3);
  for (d in 1:dice){
    p_die[d] ~ dirichlet(rep_vector(1, 3));
    Y[d] ~ multinomial(p_die[d]);
    y += p_die[d];
  }
  
  p_pop ~ dirichlet(y);
  
}

  ")
)

posterior <- mod$sample(data = list(dice = SAMPLE_N, Y = Y))

To get the population distribution, I just summed up all the per-die simplexes and used the sum (plus 1 for each face) as the concentration of a Dirichlet for p_pop. I did this because each simplex sums to 1, so this way every die adds one pseudo-count (one observation) to the population distribution. One can also provide priors for the outcomes of the rolls as well as for the population of dice; I’ve set both to dirichlet(1).

> posterior$summary()
# A tibble: 64 x 10
   variable         mean     median      sd     mad         q5        q95  rhat ess_bulk ess_tail
   <chr>           <dbl>      <dbl>   <dbl>   <dbl>      <dbl>      <dbl> <dbl>    <dbl>    <dbl>
 1 lp__       -9358.     -9358.     4.90    4.97    -9367.     -9351.      1.00    1465.    1909.
 2 p_pop[1]       0.216      0.209  0.0864  0.0894      0.0923     0.373   1.00    5504.    2834.
 3 p_pop[2]       0.343      0.339  0.0982  0.102       0.190      0.514   1.00    6130.    2990.
 4 p_pop[3]       0.440      0.438  0.101   0.102       0.276      0.610   1.00    6105.    3079.
 5 p_die[1,1]     0.0344     0.0343 0.00364 0.00354     0.0286     0.0405  1.00    6322.    2693.
 6 p_die[2,1]     0.130      0.130  0.00668 0.00675     0.120      0.141   1.00    6547.    3175.
 7 p_die[3,1]     0.226      0.223  0.0558  0.0562      0.141      0.324   1.00    6300.    2918.
 8 p_die[4,1]     0.0755     0.0699 0.0361  0.0343      0.0270     0.143   1.00    5675.    2327.
 9 p_die[5,1]     0.246      0.200  0.190   0.186       0.0169     0.626   1.00    5124.    2491.
10 p_die[6,1]     0.0571     0.0508 0.0328  0.0304      0.0154     0.119   1.00    5805.    2739.
# … with 54 more rows
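Since p_pop is never used as a parameter of another distribution in this model, the dice do not inform each other; conditional on the dice, p_pop is Dirichlet(1 + sum of the p_die simplexes), and that concentration always totals 3 + dice. Averaging over the conjugate per-die posteriors therefore gives p_pop’s posterior mean in closed form. The base-R sketch below computes it from the Y table above and, for contrast, the naive estimate that pools every roll, where the dice with 1000 or 2500 rolls carry most of the weight. Variable names are my own, for illustration.

```r
# Face counts Y from the simulated experiment above (20 dice x 3 faces)
Y <- matrix(c(
   85, 1078, 1337,   325,  115, 2060,    11,    8,   31,     3,   24,   23,
    0,    1,    0,     2,    9,   39,   124,  758,  118,    46,   49,  405,
  450,  287,  263,    42,  619,  339,   278,   73,  149,     0,    0,    1,
    9,  373,  618,     0,    0,    1,    15,  297,  188,    16,   14,   20,
  447,  289,  264,     3,   32,   15,     1,    0,    0,     2,   12,   36),
  ncol = 3, byrow = TRUE)
D <- nrow(Y)

# Posterior mean of each die's simplex under its Dirichlet(1, 1, 1) prior
p_die_mean <- (1 + Y) / (3 + rowSums(Y))

# Given the dice, p_pop ~ Dirichlet(1 + colSums(p_die)); that concentration
# always totals 3 + D, so each die counts as exactly one pseudo-observation
p_pop_mean <- (1 + colSums(p_die_mean)) / (3 + D)

# Naive alternative: pool every roll, so heavily rolled dice dominate
pooled <- colSums(Y) / sum(Y)

round(rbind(p_pop_mean, pooled), 3)
```

The first row should agree with the MCMC means for p_pop up to Monte Carlo error (for face 1, about 0.2165 versus 0.216), while the pooled face-1 estimate is 1859/11804 ≈ 0.157. Because p_pop’s concentration is fixed at 3 + D, only more dice, not more rolls per die, can narrow its posterior, which is consistent with the wide sds (roughly 0.09 to 0.10) reported for p_pop above.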

Kind of works… Maybe that’s a good starting point for more elaborate stuff?
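One more elaborate direction, sketched here under my own assumptions rather than as a definitive answer: make the hierarchy explicit with p_die[d] ~ Dirichlet(kappa * p_pop), so that p_pop is the population mean and kappa plays the role POP_HET played in the simulation. Integrating out each p_die gives a Dirichlet-multinomial likelihood. The base-R sketch below fits p_pop and kappa by maximizing that marginal likelihood with optim, an empirical-Bayes shortcut swapped in for illustration; a fully Bayesian Stan version could instead put priors on p_pop and kappa and sample them. The functions dm_loglik and fit_dm are hypothetical helpers.

```r
# Hypothetical helper: Dirichlet-multinomial log-likelihood of a count matrix Y
# (dice x faces) with p_die ~ Dirichlet(kappa * p) integrated out.
# The multinomial coefficient is dropped because it does not depend on p or kappa.
dm_loglik <- function(Y, p, kappa) {
  alpha <- kappa * p
  n <- rowSums(Y)
  sum(lgamma(kappa) - lgamma(n + kappa)) +
    sum(lgamma(sweep(Y, 2, alpha, "+"))) - nrow(Y) * sum(lgamma(alpha))
}

# Hypothetical helper: maximise the marginal likelihood over p (softmax with the
# first logit fixed at 0) and kappa (on the log scale) using Nelder-Mead.
fit_dm <- function(Y) {
  to_p <- function(theta) {
    eta <- c(0, theta[1:2])
    exp(eta) / sum(exp(eta))
  }
  neg_ll <- function(theta) -dm_loglik(Y, to_p(theta), exp(theta[3]))
  opt <- optim(c(0, 0, 0), neg_ll, control = list(maxit = 2000))
  list(p_pop = to_p(opt$par), kappa = exp(opt$par[3]))
}

# Simulate 200 dice in the same spirit as above, then fit
set.seed(42)
POP_AVGS <- c(0.2, 0.3, 0.5)
POP_HET <- 5
Y_sim <- t(sapply(1:200, function(d) {
  g <- rgamma(3, shape = POP_HET * POP_AVGS)   # Dirichlet draw via normalised gammas
  rmultinom(1, size = sample(c(1, 50, 500), 1), prob = g / sum(g))
}))
fit <- fit_dm(Y_sim)
round(fit$p_pop, 2)   # should be near POP_AVGS
round(fit$kappa, 1)   # should be near POP_HET
```

Given fitted values, each die’s posterior is Dirichlet(kappa * p_pop + Y[d, ]) with mean (kappa * p_pop + Y[d, ]) / (kappa + n_d), so dice with few rolls are shrunk most strongly towards p_pop. That partial pooling is usually what a hierarchical model is after.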

Interesting question! :)
Cheers,
Max
