This code fits a Dirichlet regression with brms and then extracts fitted values for the first observation in three ways: (1) using brms::fitted.brmsfit() on the response scale, or using brms::fitted.brmsfit(scale = "linear") and then applying either (2) brms:::softmax() or (3) brms::dirichlet()$linkinv() (in both cases after adding the reference category value manually):
library(brms)
library(dplyr)
# Dirichlet regression example (modified from https://discourse.mc-stan.org/t/dirichlet-regresion-using-brms/8591/2)
bind <- function(...) cbind(...)
N <- 20
df <-
  data.frame(
    y1 = rbinom(N, 10, 0.5), y2 = rbinom(N, 10, 0.7),
    y3 = rbinom(N, 10, 0.9), x = rnorm(N)
  ) |>
  dplyr::mutate(
    size = y1 + y2 + y3,
    y1 = y1 / size,
    y2 = y2 / size,
    y3 = y3 / size
  )
df$y <- with(df, cbind(y1, y2, y3))
fit <-
  brms::brm(
    bind(y1, y2, y3) ~ x,
    df,
    family = dirichlet(),
    chains = 2,
    iter = 200,
    warmup = 100,
    seed = 4324
  )
#> Compiling Stan program...
#> Start sampling
#>
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 1).
#> Chain 1:
#> Chain 1: Gradient evaluation took 0.000193 seconds
#> Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 1.93 seconds.
#> Chain 1: Adjust your expectations accordingly!
#> Chain 1:
#> Chain 1:
#> Chain 1: WARNING: There aren't enough warmup iterations to fit the
#> Chain 1: three stages of adaptation as currently configured.
#> Chain 1: Reducing each adaptation stage to 15%/75%/10% of
#> Chain 1: the given number of warmup iterations:
#> Chain 1: init_buffer = 15
#> Chain 1: adapt_window = 75
#> Chain 1: term_buffer = 10
#> Chain 1:
#> Chain 1: Iteration: 1 / 200 [ 0%] (Warmup)
#> Chain 1: Iteration: 20 / 200 [ 10%] (Warmup)
#> Chain 1: Iteration: 40 / 200 [ 20%] (Warmup)
#> Chain 1: Iteration: 60 / 200 [ 30%] (Warmup)
#> Chain 1: Iteration: 80 / 200 [ 40%] (Warmup)
#> Chain 1: Iteration: 100 / 200 [ 50%] (Warmup)
#> Chain 1: Iteration: 101 / 200 [ 50%] (Sampling)
#> Chain 1: Iteration: 120 / 200 [ 60%] (Sampling)
#> Chain 1: Iteration: 140 / 200 [ 70%] (Sampling)
#> Chain 1: Iteration: 160 / 200 [ 80%] (Sampling)
#> Chain 1: Iteration: 180 / 200 [ 90%] (Sampling)
#> Chain 1: Iteration: 200 / 200 [100%] (Sampling)
#> Chain 1:
#> Chain 1: Elapsed Time: 0.195 seconds (Warm-up)
#> Chain 1: 0.194 seconds (Sampling)
#> Chain 1: 0.389 seconds (Total)
#> Chain 1:
#>
#> SAMPLING FOR MODEL 'anon_model' NOW (CHAIN 2).
#> Chain 2:
#> Chain 2: Gradient evaluation took 0.000113 seconds
#> Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 1.13 seconds.
#> Chain 2: Adjust your expectations accordingly!
#> Chain 2:
#> Chain 2:
#> Chain 2: WARNING: There aren't enough warmup iterations to fit the
#> Chain 2: three stages of adaptation as currently configured.
#> Chain 2: Reducing each adaptation stage to 15%/75%/10% of
#> Chain 2: the given number of warmup iterations:
#> Chain 2: init_buffer = 15
#> Chain 2: adapt_window = 75
#> Chain 2: term_buffer = 10
#> Chain 2:
#> Chain 2: Iteration: 1 / 200 [ 0%] (Warmup)
#> Chain 2: Iteration: 20 / 200 [ 10%] (Warmup)
#> Chain 2: Iteration: 40 / 200 [ 20%] (Warmup)
#> Chain 2: Iteration: 60 / 200 [ 30%] (Warmup)
#> Chain 2: Iteration: 80 / 200 [ 40%] (Warmup)
#> Chain 2: Iteration: 100 / 200 [ 50%] (Warmup)
#> Chain 2: Iteration: 101 / 200 [ 50%] (Sampling)
#> Chain 2: Iteration: 120 / 200 [ 60%] (Sampling)
#> Chain 2: Iteration: 140 / 200 [ 70%] (Sampling)
#> Chain 2: Iteration: 160 / 200 [ 80%] (Sampling)
#> Chain 2: Iteration: 180 / 200 [ 90%] (Sampling)
#> Chain 2: Iteration: 200 / 200 [100%] (Sampling)
#> Chain 2:
#> Chain 2: Elapsed Time: 0.242 seconds (Warm-up)
#> Chain 2: 0.248 seconds (Sampling)
#> Chain 2: 0.49 seconds (Total)
#> Chain 2:
#> Warning: The largest R-hat is 1.11, indicating chains have not mixed.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#r-hat
#> Warning: Bulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#bulk-ess
#> Warning: Tail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
#> Running the chains for more iterations may help. See
#> https://mc-stan.org/misc/warnings.html#tail-ess
set.seed(478)
fit_fitted <- fitted(fit, scale = "response")[1, 1, ]
set.seed(478)
fit_fitted_linear <- fitted(fit, scale = "linear")[1, 1, ]
fit_fitted_linkinv <- fit$family$linkinv(c(0, fit_fitted_linear))
fit_fitted_softmax <- brms:::softmax(c(0, fit_fitted_linear))
fit_fitted
#> P(Y = y1) P(Y = y2) P(Y = y3)
#> 0.2340334 0.2933542 0.4726124
fit_fitted_linkinv
#> eta1 eta2
#> 0.5000000 0.5564065 0.6691070
fit_fitted_softmax
#> [,1] [,2] [,3]
#> [1,] 0.2338393 0.2933084 0.4728523
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.3.1 (2023-06-16 ucrt)
#> os Windows 11 x64 (build 26100)
#> system x86_64, mingw32
#> ui RTerm
#> language (EN)
#> collate German_Germany.utf8
#> ctype German_Germany.utf8
#> tz Europe/Berlin
#> date 2025-09-25
#> pandoc 3.6.3 @ C:/Program Files/RStudio/resources/app/bin/quarto/bin/tools/ (via rmarkdown)
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#> package * version date (UTC) lib source
#> cli 3.6.1 2023-03-23 [1] CRAN (R 4.3.1)
#> digest 0.6.33 2023-07-07 [1] CRAN (R 4.3.1)
#> evaluate 1.0.3 2025-01-10 [1] CRAN (R 4.3.3)
#> fastmap 1.1.1 2023-02-24 [1] CRAN (R 4.3.1)
#> fs 1.6.3 2023-07-20 [1] CRAN (R 4.3.1)
#> glue 1.6.2 2022-02-24 [1] CRAN (R 4.3.1)
#> htmltools 0.5.6 2023-08-10 [1] CRAN (R 4.3.1)
#> knitr 1.45.13 2024-02-26 [1] Github (yihui/knitr@ad47ce5)
#> lifecycle 1.0.3 2022-10-07 [1] CRAN (R 4.3.1)
#> reprex 2.0.2 2022-08-17 [1] CRAN (R 4.3.1)
#> rlang 1.1.1 2023-04-28 [1] CRAN (R 4.3.1)
#> rmarkdown 2.29 2024-11-04 [1] CRAN (R 4.3.3)
#> rstudioapi 0.15.0 2023-07-07 [1] CRAN (R 4.3.1)
#> sessioninfo 1.2.2 2021-12-06 [1] CRAN (R 4.3.1)
#> withr 2.5.0 2022-03-03 [1] CRAN (R 4.3.1)
#> xfun 0.42 2024-02-08 [1] CRAN (R 4.3.2)
#> yaml 2.3.10 2024-07-26 [1] CRAN (R 4.3.3)
#>
#> [1] C:/Users/henni/AppData/Local/R/win-library/4.3
#> [2] C:/Program Files/R/R-4.3.1/library
#>
#> ──────────────────────────────────────────────────────────────────────────────
Created on 2025-09-25 with reprex v2.0.2
I would expect all three approaches to produce the same results (apart from random sampling variation), but they do not: (1) and (2) agree closely, whereas (3) differs clearly from both.
As far as I know, brms::dirichlet()$linkinv() is the inverse logit function, whereas the Stan code that brms generates for a Dirichlet regression uses the softmax function as the inverse link.
Is this correct? And if so, shouldn’t brms provide softmax() (or brms:::inv_link_categorical(), which automatically adds the reference category value) as the linkinv element of brms::dirichlet()?
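To make the difference between the two inverse links concrete, here is a minimal base R sketch. The values in `eta` are hypothetical and not taken from the model above, and `softmax_vec()` is an illustrative helper, not a brms function.

```r
# Hypothetical linear-predictor values for categories 2 and 3;
# category 1 is the reference and is fixed at 0.
eta <- c(0.2, 0.7)
eta_full <- c(0, eta)

# Element-wise inverse logit: each value is mapped to (0, 1) independently.
inv_logit <- plogis(eta_full)

# Softmax: the values are normalised jointly.
softmax_vec <- function(x) {
  ex <- exp(x - max(x))  # subtract the maximum for numerical stability
  ex / sum(ex)
}
probs <- softmax_vec(eta_full)

inv_logit[1]    # 0.5, because plogis(0) = 0.5
sum(inv_logit)  # greater than 1, so not a valid vector of proportions
sum(probs)      # 1 (up to floating-point error)
```

The first element of the inverse-logit result is always 0.5 when the reference category is set to 0, which matches the pattern in the fit_fitted_linkinv output above.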
I am asking because I need to generate predictions manually from the draws of a Dirichlet regression brmsfit object in order to propagate uncertainty in the predictor variable values. Unfortunately, neither softmax() nor the helper function brms:::inv_link_categorical() is exported from brms, which makes it difficult to use them in an R package.
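As a possible workaround, the inverse link could be applied by hand in base R. The sketch below assumes the linear-predictor draws are arranged as a draws x (K - 1) matrix for the non-reference categories; the helper name `softmax_ref()` and the simulated matrix `eta_draws` are hypothetical stand-ins, not brms objects or output from the model above.

```r
# Row-wise softmax with the reference category fixed at 0.
# `eta` is a draws x (K - 1) matrix of linear predictors for the
# non-reference categories (assumed layout).
softmax_ref <- function(eta) {
  eta_full <- cbind(0, eta)                       # prepend reference category
  eta_full <- eta_full - apply(eta_full, 1, max)  # stabilise exp() per row
  ex <- exp(eta_full)
  ex / rowSums(ex)
}

# Simulated stand-in for posterior draws of the linear predictor
# (100 draws, 2 non-reference categories).
set.seed(1)
eta_draws <- cbind(rnorm(100, 0.2, 0.3), rnorm(100, 0.7, 0.3))

p_draws <- softmax_ref(eta_draws)  # 100 x 3 matrix, each row sums to 1

# Posterior mean of the transformed draws ...
mean_of_softmax <- colMeans(p_draws)
# ... versus transforming the posterior mean of the linear predictor.
softmax_of_mean <- drop(softmax_ref(matrix(colMeans(eta_draws), nrow = 1)))
```

Both vectors sum to 1, but they generally differ slightly because softmax is nonlinear. This may be why approaches (1) and (2) above agree only to about three decimal places: approach (2) applies softmax to the posterior mean of the linear predictor rather than averaging over transformed draws.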
Operating System: Windows 11
Interface Version: brms 2.23.1