Interpret parameters from multinomial (categorical) logit regression

I have trouble finding help related to the interpretation of “categorical” models.

Let’s take the following simple model as an example:

library(brms)
library(ggplot2)

m <- brm(Species ~ Sepal.Length, 
         data=iris, 
         family=categorical(link="logit", refcat = "setosa"),
         algorithm = "meanfield") 
#> Warning: Pareto k diagnostic value is 0.71. Resampling is unreliable. Increasing
#> the number of draws or decreasing tol_rel_obj may help.


modelbased::estimate_relation(m, at="Sepal.Length", length = 30) |> 
  ggplot(aes(x = Sepal.Length, y = Predicted)) +
  geom_line(aes(color = Response))


m
#>  Family: categorical 
#>   Links: muversicolor = logit; muvirginica = logit 
#> Formula: Species ~ Sepal.Length 
#>    Data: iris (Number of observations: 150) 
#>   Draws: 1 chains, each with iter = 1000; warmup = 0; thin = 1;
#>          total post-warmup draws = 1000
#> 
#> Population-Level Effects: 
#>                           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
#> muversicolor_Intercept      -23.73      2.21   -28.00   -19.52 1.00     1109
#> muvirginica_Intercept       -37.77      1.94   -41.78   -34.10 1.00     1015
#> muversicolor_Sepal.Length     4.38      0.38     3.64     5.11 1.00     1091
#> muvirginica_Sepal.Length      6.63      0.33     6.00     7.31 1.00      993
#>                           Tail_ESS
#> muversicolor_Intercept         981
#> muvirginica_Intercept          994
#> muversicolor_Sepal.Length      872
#> muvirginica_Sepal.Length       994
#> 
#> Draws were sampled using variational(meanfield).

Created on 2022-02-16 by the reprex package (v2.0.1)

The model includes 4 fixed parameters: 2 suffixed with _Intercept and 2 for the effect of the Sepal.Length predictor. The parameters pertain only to the non-reference categories (in this case virginica and versicolor).

My question is: are these parameters to be all interpreted as differences from their reference counterpart? Is it correct that:

  • muversicolor_Intercept is the difference between the intercept at setosa and the one of versicolor?
  • muversicolor_Sepal.Length is the difference between the effect of Sepal.Length at setosa and the one at versicolor?

Is there any way to infer say the effect of Sepal.Length at the reference level setosa?

Thanks for any clarification as to how to interpret these parameters.

This is correct: the coefficients are on the logit (log-odds) scale, relative to the reference category. This particular parameterization constrains the reference category's intercept and slope to 0, which is why intercepts and slopes are estimated for only two categories. You then use the softmax function to transform the vector of latent scores into probabilities. See the “Identifiability” section here: 1.6 Multi-logit regression | Stan User’s Guide. I believe the point estimates are close to (though generally not identical to, and with different standard errors from) those of two separate binary logistic models, each of which codes Setosa as 0 and the other species as 1, if that helps with interpretation.
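To make the logit interpretation concrete, here is a minimal sketch in base R that plugs in the point estimates from the summary above. The `softmax` helper and the value 5.5 are my own illustrative choices, not something brms provides.

```r
# Point estimates copied from the model summary; setosa is the reference,
# so its intercept and slope are fixed at 0.
intercept <- c(setosa = 0, versicolor = -23.73, virginica = -37.77)
slope     <- c(setosa = 0, versicolor = 4.38,   virginica = 6.63)

# Hypothetical helper: map a vector of latent scores to probabilities.
# Subtracting max(eta) only improves numerical stability.
softmax <- function(eta) exp(eta - max(eta)) / sum(exp(eta - max(eta)))

x   <- 5.5                     # an arbitrary Sepal.Length value
eta <- intercept + slope * x   # latent scores = log-odds relative to setosa
p   <- softmax(eta)            # predicted class probabilities

# The log-odds of versicolor vs. setosa recovers the linear predictor.
log_odds_versicolor <- log(p["versicolor"] / p["setosa"])

# exp(slope) is the multiplicative change in the odds of a species
# (relative to setosa) per one-unit increase in Sepal.Length.
odds_ratio <- exp(slope[c("versicolor", "virginica")])
```

So each non-reference coefficient describes a contrast with setosa on the log-odds scale, not a change in that species' probability directly.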

The “effect” of Sepal.Length for Setosa is only defined relative to the other two species. Since the baseline model constrains Setosa’s coefficients to 0, you can shift the coefficients of all three categories by the same constant so that they center on some other value. But I’m not sure this gets you much.
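If what you are after is how the probability of setosa changes with Sepal.Length, one option is to look at the slope on the probability scale instead of the logit scale. The sketch below assumes that is the quantity of interest; `predict_probs` and the evaluation point 5.5 are hypothetical names and values for illustration, and the numbers are just the point estimates from the summary.

```r
intercept <- c(setosa = 0, versicolor = -23.73, virginica = -37.77)
slope     <- c(setosa = 0, versicolor = 4.38,   virginica = 6.63)

# Hypothetical helper: predicted class probabilities at Sepal.Length = x.
predict_probs <- function(x) {
  eta <- intercept + slope * x
  exp(eta - max(eta)) / sum(exp(eta - max(eta)))
}

x <- 5.5
p <- predict_probs(x)

# Analytic derivative of each probability with respect to Sepal.Length:
#   dp_k/dx = p_k * (slope_k - sum_j p_j * slope_j)
dp_analytic <- p * (slope - sum(p * slope))

# Central finite-difference check of the same quantity.
h <- 1e-6
dp_numeric <- (predict_probs(x + h) - predict_probs(x - h)) / (2 * h)

# Negative here: setosa becomes less likely as Sepal.Length grows.
dp_analytic["setosa"]
```

Unlike the logit coefficients, this slope depends on where along Sepal.Length you evaluate it, and the three slopes always sum to zero because the probabilities sum to one.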

The example below illustrates all of this. I copied your point estimates and set the coefficients for Setosa to zero. You can add an arbitrary constant to all three intercepts (and another to all three slopes) and still get the same predictions, so some constraint on the parameters is necessary to identify the model.

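library(dplyr)
library(ggplot2)

# Point estimates from the model summary; the reference category (Setosa)
# has its intercept and slope fixed at 0.
param_df <- data.frame(group = c('Setosa', 'Versicolor', 'Virginica'),
                       intercept = c(0, -23.73, -37.77),
                       slope = c(0, 4.38, 6.63))

# Every combination of species and sepal length, kept as character so the
# join key has the same type in both data frames.
fitted_df <- expand.grid(group = c('Setosa', 'Versicolor', 'Virginica'),
                         sepal_length = seq(from = 4, to = 8, by = 0.01),
                         stringsAsFactors = FALSE) %>%
  left_join(param_df, by = 'group') %>%
  # Latent score (logit) for each species at each sepal length
  mutate(logit = intercept + slope * sepal_length) %>%
  # Softmax across the three species within each sepal length
  group_by(sepal_length) %>%
  mutate(p = exp(logit) / sum(exp(logit))) %>%
  ungroup()

# Predicted probability of each species as a function of sepal length
fitted_df %>%
  ggplot(aes(sepal_length, p, group = group, color = group)) +
    geom_line() +
    scale_color_brewer(palette = 'Set1')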
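
The invariance to a common shift can also be checked numerically. This is a rough base-R sketch using the same point estimates; `softmax_probs`, the shift constants 10 and -3, and the sum-to-zero centering are arbitrary illustrative choices on my part.

```r
intercept <- c(setosa = 0, versicolor = -23.73, virginica = -37.77)
slope     <- c(setosa = 0, versicolor = 4.38,   virginica = 6.63)

# Hypothetical helper: class probabilities at one Sepal.Length value.
softmax_probs <- function(intercept, slope, x) {
  eta <- intercept + slope * x
  exp(eta - max(eta)) / sum(exp(eta - max(eta)))
}

x <- seq(4, 8, by = 0.5)

# Reference parameterization (setosa fixed at 0).
p_ref <- sapply(x, function(xi) softmax_probs(intercept, slope, xi))

# Add one arbitrary constant to every intercept and another to every slope.
p_shifted <- sapply(x, function(xi) softmax_probs(intercept + 10, slope - 3, xi))

# Sum-to-zero alternative: center the coefficients across the categories.
intercept_c <- intercept - mean(intercept)
slope_c     <- slope - mean(slope)
p_centered  <- sapply(x, function(xi) softmax_probs(intercept_c, slope_c, xi))

# Both comparisons should report that the predictions are unchanged.
all.equal(p_ref, p_shifted)
all.equal(p_ref, p_centered)

# Under the centered version setosa gets a non-zero slope, but it is still
# only meaningful relative to the other two species.
slope_c
```

Only the differences between categories are identified, which is why fixing the reference coefficients at 0 loses no information.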
