Multinomial distribution

I am not sure how Stan calculates the multinomial distribution.
I was not familiar with the multinomial distribution, so I modelled it by reducing it to binomials.
My question is whether the following two models are equivalent in Stan.

In my model, there is a random variable X=(X_1,X_2,X_3,X_4), X_i\geq0, \sum X_i =n.
Let p=(p_1,p_2,p_3,p_4), p_i >0, \sum p_i =1.

Then, the multinomial distribution is

1st model

X \sim \text{Multinomial}(p)

On the other hand, we can reduce the above to binomial distributions as follows

2nd model

X_3 \sim \text{Binomial}(p_3,n)

X_2 \sim \text{Binomial}(\frac{p_2}{1-p_3},n-X_3)

X_1 \sim \text{Binomial}(\frac{p_1}{1-p_3-p_2},n-X_3-X_2)

X_4:= n-X_1-X_2-X_3.
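
As a sanity check on this reduction, the product of the three binomial mass functions can be expanded by hand. This is only a sketch using the standard binomial pmf, where Binomial(q, m) means success probability q and m trials as in the notation above, x_4 = n - x_1 - x_2 - x_3, and p_4 = 1 - p_1 - p_2 - p_3:

```latex
\begin{aligned}
&\binom{n}{x_3} p_3^{x_3} (1-p_3)^{n-x_3}
 \cdot \binom{n-x_3}{x_2} \left(\frac{p_2}{1-p_3}\right)^{x_2}
       \left(\frac{1-p_3-p_2}{1-p_3}\right)^{n-x_3-x_2} \\
&\quad \cdot \binom{n-x_3-x_2}{x_1} \left(\frac{p_1}{1-p_3-p_2}\right)^{x_1}
       \left(\frac{p_4}{1-p_3-p_2}\right)^{x_4} \\
&= \frac{n!}{x_1!\,x_2!\,x_3!\,x_4!}\; p_1^{x_1} p_2^{x_2} p_3^{x_3} p_4^{x_4}.
\end{aligned}
```

The powers of (1-p_3) and (1-p_3-p_2) cancel because n-x_3 = x_2 + (n-x_3-x_2) and n-x_3-x_2 = x_1 + x_4, and the binomial coefficients telescope. The right-hand side is the multinomial mass function, and nothing in the argument depends on which component is taken first.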

This reduction is also described, e.g., in ?rmultinom in the stats package.

I use both of the above in my model, and the behaviour of HMC is not the same. E.g., for some data, one of them converges but the other does not.

Is the multinomial distribution in Stan defined using the above reduction to binomials? Namely, are the above two equivalent?

Are you sure the parameterizations are in the same order?

This is what I got and the log densities in R evaluate to the same thing:

prob = c(0.1, 0.2, 0.3, 0.4)

N = 20

x = rmultinom(1, N, prob)

dmultinom(x, prob = prob, log = TRUE)

dbinom(x[1], N, prob[1], log = TRUE) +
  dbinom(x[2], N - x[1], prob[2] / (1 - prob[1]), log = TRUE) +
  dbinom(x[3], N - x[1] - x[2], prob[3] / (1 - prob[1] - prob[2]), log = TRUE)

Note that my binomials are taken in a different order from yours.
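
To check that the order itself does not change the log density, here is a minimal sketch in base R. The helper `seq_binom_lpmf` and its `ord` argument are hypothetical names introduced only for this illustration; the last category in `ord` is treated as determined by the others.

```r
# Hypothetical helper: log pmf of x under sequential binomials taken in
# the order given by `ord` (the last entry of `ord` is not sampled,
# since it is fixed by the remaining counts).
seq_binom_lpmf <- function(x, prob, ord) {
  n_left <- sum(x)   # trials not yet allocated
  p_left <- 1        # probability mass not yet allocated
  lp <- 0
  for (i in ord[-length(ord)]) {
    lp <- lp + dbinom(x[i], n_left, prob[i] / p_left, log = TRUE)
    n_left <- n_left - x[i]
    p_left <- p_left - prob[i]
  }
  lp
}

set.seed(1)
prob <- c(0.1, 0.2, 0.3, 0.4)
x <- as.vector(rmultinom(1, 20, prob))

lp_multi <- dmultinom(x, prob = prob, log = TRUE)
lp_123 <- seq_binom_lpmf(x, prob, c(1, 2, 3, 4))  # X_1, then X_2, then X_3
lp_321 <- seq_binom_lpmf(x, prob, c(3, 2, 1, 4))  # X_3, then X_2, then X_1

# Compare both orderings against the multinomial log pmf
all.equal(lp_multi, lp_123)
all.equal(lp_multi, lp_321)
```

Both comparisons should agree up to floating-point tolerance, so a difference in sampler behaviour would have to come from something other than the ordering of the binomials.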

Thank you for the reply.

I know that the reduction of a multinomial to binomials depends on the order of the X_i.

I wonder whether there is some order in which Stan reduces a multinomial to binomials. The convergence of Stan (e.g., in \hat{R} and divergent transitions) differs between “~ multinomial()” and many “~ binomial()” statements in my order.

Can I say that the multinomial differs from the binomials? Or are they essentially the same?

I initially thought these were different, but coincidentally found the same reduction in ?rmultinom() in the stats package. Is this widely used in statistics or also used in Stan (e.g., multinomial_rng() and sampling statement)?

I tried this out and it’s working okay for me.

Here is the model I used:

data {
  array[4] int y;  // counts per category (array syntax required by recent Stan versions)
}

transformed data {
  int N = sum(y);
}

parameters {
  simplex[4] p1;
  simplex[4] p2;
}

model {
  y ~ multinomial(p1);
  
  y[1] ~ binomial(N, p2[1]);
  y[2] ~ binomial(N - y[1], p2[2] / (1 - p2[1]));
  y[3] ~ binomial(N - y[1] - y[2], p2[3] / (1 - p2[1] - p2[2]));
}

generated quantities {
  real lp1 = multinomial_lpmf(y | p1);
  real lp2 = binomial_lpmf(y[1] | N, p1[1]) +
    binomial_lpmf(y[2] | N - y[1], p1[2] / (1 - p1[1])) +
    binomial_lpmf(y[3] | N - y[1] - y[2], p1[3] / (1 - p1[1] - p1[2]));
  real diff = lp1 - lp2;
}

The output is:

> fit$print(max_rows = 12)
 variable   mean median   sd  mad     q5    q95 rhat ess_bulk ess_tail
    lp__  -66.22 -65.87 1.79 1.63 -69.64 -63.91 1.00     1901     2329
    p1[1]   0.13   0.12 0.07 0.06   0.04   0.25 1.00     5113     2879
    p1[2]   0.21   0.20 0.08 0.08   0.09   0.35 1.00     5920     2935
    p1[3]   0.29   0.29 0.09 0.09   0.15   0.44 1.00     5858     3321
    p1[4]   0.37   0.37 0.10 0.10   0.22   0.54 1.00     5418     3137
    p2[1]   0.12   0.11 0.07 0.07   0.03   0.25 1.00     5221     2453
    p2[2]   0.21   0.20 0.08 0.08   0.09   0.36 1.00     5160     2678
    p2[3]   0.29   0.29 0.09 0.10   0.15   0.46 1.00     5302     3223
    p2[4]   0.37   0.37 0.10 0.10   0.22   0.54 1.00     5934     3311
    lp1    -5.60  -5.30 1.05 0.85  -7.72  -4.46 1.00     2126     2897
    lp2    -5.60  -5.30 1.05 0.85  -7.72  -4.46 1.00     2126     2897
    diff    0.00   0.00 0.00 0.00   0.00   0.00 1.00     3902     3913

So diff is zero if I calculate the lpmf using one method or the other. Also the inferences on p1 and p2 look vaguely the same.

Here’s the R code:

library(cmdstanr)

prob = c(0.1, 0.2, 0.3, 0.4)
N = 20
y = as.vector(rmultinom(1, N, prob))

model = cmdstan_model("multinom.stan")
fit = model$sample(data = list(y = y))
fit$print(max_rows = 12)