Priors on mixing distributions - probability vs log-odds

Hello everyone,

I have been fitting finite mixture models in brms, and I have one lingering doubt about how to set sensible priors and how they are implemented. I’ve done my best to read the documentation, study the output of make_stancode() for my models, and find other relevant posts, but I’m still not sure I’m doing this correctly.

In most of the materials that I’ve learned from online, a prior is set on the mixing proportions in probability space using something like beta(4,4). When I set a prior like this, my models fit properly and I don’t have issues with non-converging chains. However, I notice that the estimates for parameters such as theta2 are not bounded between 0 and 1, so I’m assuming that this prior is being transformed to log-odds space somewhere.

This would make sense to me, but the default prior that brms implements is logistic(0,1), which is already in log-odds space. As my actual models take a long time to fit, I’ve come up with an example using the iris data set where I predict not only an outcome variable but also the mixing proportion of the second component (theta2). As in my actual models, the beta(5, 5) prior leads to a well-fitting model, but logistic(0, 0.35), which encodes similar information on the log-odds scale, leads to poorly mixed chains and other problems.
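To illustrate what I mean by “similar information”: pushing draws from logistic(0, 0.35) through the inverse logit gives a distribution over probabilities that looks much like beta(5, 5). A quick Monte Carlo sketch of my own (not from the brms docs):

```r
# Compare a logit-scale logistic(0, 0.35) prior, back-transformed with
# the inverse logit, against a probability-scale beta(5, 5) prior.
set.seed(1)
p_logistic <- plogis(rlogis(1e5, location = 0, scale = 0.35))
p_beta     <- rbeta(1e5, 5, 5)

# Both are centred on 0.5 with similar spread.
c(mean(p_logistic), mean(p_beta))
c(sd(p_logistic), sd(p_beta))
```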

My questions:

  1. Is it appropriate to use a beta distribution for setting priors on mixing proportions, as I’ve done in the beta_prior_model below?
  2. Is the theta2_Intercept in the model summary in log-odds?
  3. If I want to use the posterior of theta2_Intercept as an informative prior for new data, should I transform it back to the probability space and approximate that posterior with another beta distribution?

Please let me know if I haven’t explained things well, and thanks for your time if you’ve made it this far!

library(brms)

iris$Petal_length <- as.numeric(scale(iris$Petal.Length))
iris$Petal_width <- as.numeric(scale(iris$Petal.Width))
iris$Species <- as.factor(iris$Species)

# Setting up a mixture
two_normal_mixture <- mixture(gaussian(), gaussian(), order = TRUE)

# Setting up a model formula
model_formula <- bf(Petal_length ~ Petal_width,
                    theta2 ~ Species)

# Checking priors
get_prior(model_formula,
          data = iris,
          family = two_normal_mixture)

beta_prior_model <- brm(model_formula,
                        data = iris,
                        family = two_normal_mixture,
                        prior = c(
                          prior(normal(-1, 0.4), class = "Intercept", dpar = "mu1"),
                          prior(normal(1, 0.4), class = "Intercept", dpar = "mu2"),
                          prior(beta(5, 5), class = "Intercept", dpar = "theta2"),
                          prior(normal(0, 0.5), class = "b", dpar = "mu1"),
                          prior(normal(0, 0.5), class = "b", dpar = "mu2"),
                          prior(normal(0, 0.5), class = "b", dpar = "theta2"),
                          prior(normal(0, 1), class = "sigma1"),
                          prior(normal(0, 1), class = "sigma2")),
                        cores = 4,
                        warmup = 5000,
                        iter = 10000,
                        control = list(adapt_delta = 0.95),
                        backend = "cmdstanr")
summary(beta_prior_model)
Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: Petal_length ~ Petal_width 
         theta2 ~ Species
   Data: iris (Number of observations: 150) 
  Draws: 4 chains, each with iter = 5000; warmup = 0; thin = 1;
         total post-warmup draws = 20000

Population-Level Effects: 
                         Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept               -0.16      0.03    -0.22    -0.10 1.00     9347     7230
mu2_Intercept                0.16      0.03     0.11     0.21 1.00    21235    15456
theta2_Intercept             0.39      0.30    -0.19     0.98 1.00    11268     7346
mu1_Petal_width              0.88      0.03     0.83     0.92 1.00    13547     8542
mu2_Petal_width              1.15      0.03     1.09     1.21 1.00    17965    13337
theta2_Speciesversicolor     0.89      0.46    -0.02     1.77 1.00    11413     8607
theta2_Speciesvirginica     -0.71      0.39    -1.48     0.06 1.00    14674     9848

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     0.17      0.02     0.14     0.22 1.00    18814    12725
sigma2     0.18      0.02     0.15     0.22 1.00    18822    14294

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
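If theta2_Intercept really is on the logit scale (my assumption here), the point estimates in the summary above should back-transform to per-species mixing probabilities via plogis():

```r
# Assuming theta2 is on the logit scale, back-transform the point
# estimates above to mixing probabilities for the second component.
plogis(0.39)         # reference species (setosa): ~0.60
plogis(0.39 + 0.89)  # versicolor: ~0.78
plogis(0.39 - 0.71)  # virginica: ~0.42
```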

logistic_prior_model <- brm(model_formula,
                        data = iris,
                        family = two_normal_mixture,
                        prior = c(
                          prior(normal(-1, 0.4), class = "Intercept", dpar = "mu1"),
                          prior(normal(1, 0.4), class = "Intercept", dpar = "mu2"),
                          prior(logistic(0, 0.35), class = "Intercept", dpar = "theta2"),
                          prior(normal(0, 0.5), class = "b", dpar = "mu1"),
                          prior(normal(0, 0.5), class = "b", dpar = "mu2"),
                          prior(normal(0, 0.5), class = "b", dpar = "theta2"),
                          prior(normal(0, 1), class = "sigma1"),
                          prior(normal(0, 1), class = "sigma2")),
                        cores = 4,
                        warmup = 5000,
                        iter = 10000,
                        control = list(adapt_delta = 0.95),
                        backend = "cmdstanr")
summary(logistic_prior_model)
Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: Petal_length ~ Petal_width 
         theta2 ~ Species
   Data: iris (Number of observations: 150) 
  Draws: 4 chains, each with iter = 5000; warmup = 0; thin = 1;
         total post-warmup draws = 20000

Population-Level Effects: 
                         Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept               -0.09      0.12    -0.21     0.15 1.53        7       NA
mu2_Intercept                0.23      0.13     0.12     0.65 1.44        8       NA
theta2_Intercept            -0.38      0.81    -2.00     0.82 1.53        7       NA
mu1_Petal_width              0.94      0.10     0.84     1.15 1.53        7       NA
mu2_Petal_width              0.97      0.32     0.25     1.22 1.53        7       NA
theta2_Speciesversicolor     0.74      0.60    -0.63     1.75 1.22       13       30
theta2_Speciesvirginica      0.06      1.10    -1.31     2.35 1.53        7       NA

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     0.18      0.02     0.14     0.22 1.05       57     2982
sigma2     0.18      0.03     0.14     0.26 1.12     6459       35

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Just leaving the answer here below in case someone finds this post while looking for it. When estimating mixing proportions in brms, the estimates are indeed on the logit scale, and priors should be specified on that scale.

I was still somewhat uncertain of this after looking at the Stan code created by make_stancode(), so I ran some simple models on simulated data that convinced me:

library(brms)


set.seed(717)

# Generating data from two different normal distributions.
N <- 100
x1 <- rnorm(N, -5, 1)
x2 <- rnorm(N, 5, 1)

x <- c(x1, x2)
d <- data.frame(x)

I then fit a model with a normal(0, 1) prior on the mixing proportion, appropriate for logit space, and wide but non-exchangeable priors on the locations of the two distributions (assuming one is negative and one is positive). The model fits well, and the mixing-proportion intercept is estimated at zero, which is where it should be in log-odds for even proportions.
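As a quick sketch of what normal(0, 1) on the logit scale implies in probability space (a prior predictive check of sorts):

```r
# Push logit-scale normal(0, 1) prior draws through the inverse logit
# to see the implied prior on the mixing probability.
set.seed(1)
p <- plogis(rnorm(1e5, 0, 1))
quantile(p, c(0.025, 0.5, 0.975))  # centred on 0.5, covering roughly 0.12 to 0.88
```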

b1 <- brm(bf(x ~ 1,
             theta2 ~ 1),
          data = d,
          family = mixture(gaussian(),gaussian(), order = TRUE),
          cores = 4,
          prior = c(
            prior(exponential(1), class = "sigma1"),
            prior(exponential(1), class = "sigma2"),
            prior(normal(-5, 2.5), class = "Intercept", dpar = "mu1"),
            prior(normal(5, 2.5), class = "Intercept", dpar = "mu2"),
            prior(normal(0, 1), class = "Intercept", dpar = "theta2")
          ),
          backend = "cmdstanr"
)
summary(b1)
Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: x ~ 1 
         theta2 ~ 1
   Data: d (Number of observations: 200) 
  Draws: 4 chains, each with iter = 1000; warmup = 0; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept       -4.90      0.09    -5.08    -4.73 1.00     4126     2907
mu2_Intercept        5.01      0.11     4.79     5.22 1.00     5148     3758
theta2_Intercept    -0.00      0.14    -0.28     0.27 1.00     4333     2882

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     0.89      0.06     0.77     1.02 1.00     4694     3080
sigma2     1.09      0.08     0.96     1.26 1.00     4794     2633

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

I then fit the same model with a beta(4, 4) prior on the mixing proportion, which would be a reasonable weakly informative prior in probability space but is informative and strange in logit space, since it is bounded to (0, 1). The true mixing proportion is not recovered, and there are now a few divergent transitions.
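The reason this misbehaves, as I understand it: a beta prior has support only on (0, 1), so placed on the logit-scale intercept it forces the back-transformed mixing probability into a narrow band above 0.5:

```r
# A beta prior restricts the logit-scale intercept to (0, 1), so the
# implied mixing probability can only live in plogis((0, 1)).
plogis(c(0, 1))  # roughly 0.50 and 0.73
```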

b2 <- brm(bf(x ~ 1,
             theta2 ~ 1),
          data = d,
          family = mixture(gaussian(),gaussian(), order = TRUE),
          cores = 4,
          prior = c(
            prior(exponential(1), class = "sigma1"),
            prior(exponential(1), class = "sigma2"),
            prior(normal(-5, 2.5), class = "Intercept", dpar = "mu1"),
            prior(normal(5, 2.5), class = "Intercept", dpar = "mu2"),
            prior(beta(4, 4), class = "Intercept", dpar = "theta2")
          ),
          backend = "cmdstanr"
)
summary(b2)
Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: x ~ 1 
         theta2 ~ 1
   Data: d (Number of observations: 200) 
  Draws: 4 chains, each with iter = 1000; warmup = 0; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept       -4.90      0.09    -5.08    -4.73 1.00     3992     3236
mu2_Intercept        5.01      0.11     4.79     5.23 1.00     5524     3304
theta2_Intercept     0.23      0.09     0.08     0.42 1.00     3499     2149

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     0.89      0.07     0.77     1.03 1.00     4296     2495
sigma2     1.09      0.08     0.96     1.26 1.00     4092     2686

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning message:
There were 14 divergent transitions after warmup. Increasing adapt_delta above  may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup 

I then generated data again, this time with a mixing proportion of 0.25 for the second distribution, and ran the same two models. The weakly informative normal(0, 1) prior in logit space allows the model to recover the true mixing proportion: -1.08 in logit space is approximately 0.25 when back-transformed to a probability.
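The back-and-forth between the two scales is just the logit and its inverse:

```r
# Back-transform the estimated logit-scale intercept and, conversely,
# map the true mixing proportion onto the logit scale.
plogis(-1.08)  # ~0.25
qlogis(0.25)   # ~-1.10
```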

# Creating unbalanced data for probability of theta2 = 0.25 (probability scale)
set.seed(717)

x1 <- rnorm(150, -5, 1)
x2 <- rnorm(50, 5, 1)

x <- c(x1, x2)
d <- data.frame(x)

b3 <- brm(bf(x ~ 1,
             theta2 ~ 1),
          data = d,
          family = mixture(gaussian(),gaussian(), order = TRUE),
          cores = 4,
          prior = c(
            prior(exponential(1), class = "sigma1"),
            prior(exponential(1), class = "sigma2"),
            prior(normal(-5, 2.5), class = "Intercept", dpar = "mu1"),
            prior(normal(5, 2.5), class = "Intercept", dpar = "mu2"),
            prior(normal(0, 1), class = "Intercept", dpar = "theta2")
          ),
          backend = "cmdstanr"
)
summary(b3)
 Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: x ~ 1 
         theta2 ~ 1
   Data: d (Number of observations: 200) 
  Draws: 4 chains, each with iter = 1000; warmup = 0; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept       -5.00      0.08    -5.16    -4.84 1.00     4352     2901
mu2_Intercept        4.88      0.14     4.62     5.15 1.00     5091     3547
theta2_Intercept    -1.08      0.16    -1.40    -0.76 1.00     5389     2836

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     0.99      0.06     0.88     1.11 1.00     4558     2747
sigma2     0.96      0.10     0.79     1.18 1.00     4197     2180

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

Here the beta(4, 4) prior is really harmful, as it constrains the logit-scale intercept to (0, 1), which is far from the true value of about -1.1. Chaos ensues and the model can’t sample properly.

b4 <- brm(bf(x ~ 1,
             theta2 ~ 1),
          data = d,
          family = mixture(gaussian(),gaussian(), order = TRUE),
          cores = 4,
          prior = c(
            prior(exponential(1), class = "sigma1"),
            prior(exponential(1), class = "sigma2"),
            prior(normal(-5, 2.5), class = "Intercept", dpar = "mu1"),
            prior(normal(5, 2.5), class = "Intercept", dpar = "mu2"),
            prior(beta(4, 4), class = "Intercept", dpar = "theta2")
          ),
          backend = "cmdstanr"
)
summary(b4)
Family: mixture(gaussian, gaussian) 
  Links: mu1 = identity; sigma1 = identity; mu2 = identity; sigma2 = identity; theta1 = identity; theta2 = identity 
Formula: x ~ 1 
         theta2 ~ 1
   Data: d (Number of observations: 200) 
  Draws: 4 chains, each with iter = 1000; warmup = 0; thin = 1;
         total post-warmup draws = 4000

Population-Level Effects: 
                 Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
mu1_Intercept       -5.05      0.16    -5.52    -4.84 1.06       74       NA
mu2_Intercept       -0.03      4.92    -5.08     5.13 1.73        6       NA
theta2_Intercept     0.30      0.25     0.02     0.75 1.73        6       NA

Family Specific Parameters: 
       Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sigma1     4.51      3.55     0.90     9.19 1.74        6       NA
sigma2     0.89      0.11     0.71     1.13 1.36        9       NA

Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
Warning messages:
1: Parts of the model have not converged (some Rhats are > 1.05). Be careful when analysing the results! We recommend running more iterations and/or setting stronger priors. 
2: There were 61 divergent transitions after warmup. Increasing adapt_delta above  may help. See http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
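This failure makes sense: the true mixing proportion of 0.25 corresponds to a logit of about -1.1, where the beta(4, 4) prior places exactly zero density:

```r
# The true mixing proportion of 0.25 corresponds to a logit of about -1.1,
# which lies outside the (0, 1) support of the beta(4, 4) prior, so the
# prior rules out the true value entirely.
dbeta(qlogis(0.25), 4, 4)  # 0
```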