Inverse predictions from posterior

I fit a logistic regression in rstanarm to characterize the relationship between a continuous predictor X and a binary outcome Y. I can successfully sample from the posterior distribution of the parameters to produce posterior predictive plots. Is there a quick way to get an inverse prediction? That is, for a specified value of Y (say, probability = 0.5), can I produce an 80% credible interval for X?

There are two methods I’ve used for this in models where the relationship between X and the log odds of the response is linear, conditional on any other predictors (e.g., if the data contain groups that each have a different intercept or slope for X). I don’t know if there’s an easy generic solution in non-linear cases.

  1. Do some parameter munging (nowadays, using as_draws_rvars() on the model) to get the slope (b) and intercept (a) for logit(p) = a + bX, then rearrange to get X for a particular logit(p). This is fine in simple models but annoying for models with even a few other predictors that change a or b, since you have to write the equations for each.

  2. Evaluate rvar(posterior_linpred(...)) to get values of logit(p) at two different X values (eg 0 and 1) conditional on all other predictors, then use these two points and the equation of the line to calculate X for a particular logit(p). This is nice because it doesn’t require you to munge model parameters, and can easily be stuffed into a loop (or equivalent) to do the same calculation conditional on whatever combinations of non-X predictors you care about. This is the approach I usually use and recommend.

Sorry if that isn’t clear, I’d mock up an example but I’m on my phone. If that doesn’t make sense and no one else wanders along in the meantime, I may mock up an example tomorrow.

Thanks! I think I’m doing a version of #1. I used

paramCI = posterior_interval(fit, prob = 0.8)  # fit is the fitted rstanarm model object

to obtain the 10th and 90th percentiles of the intercept and slope parameters. Then I can calculate values on the y (probability) scale for a sequence of X values using those parameter bounds. But to find the X values associated with p=0.5, I have to find the Y’s nearest 0.5 and note the X value that produced each one…
That’s clunky, and the precision of my estimate depends on the granularity of my sequence of Xs.
Furthermore, I’m not even sure that the X values obtained this way actually give me an 80% credible interval on X for y=0.5.

Ah yeah, you should not do calculations like this on the bounds of intervals — you need to do the transformation on all of the draws and then only calculate intervals as the final step.
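
A minimal sketch of why, using simulated draws as a stand-in for a real posterior. The means and SDs here are made up for illustration, and real posterior draws of the intercept and slope would typically be correlated, which this ignores:

```r
# Hypothetical stand-in for posterior draws; in a real analysis these would
# come from the fitted model (e.g. as.matrix(fit)), not from rnorm().
set.seed(1)
n_draws = 4000
a = rnorm(n_draws, mean = -0.65, sd = 0.24)  # intercept draws
b = rnorm(n_draws, mean = 3.6, sd = 0.54)    # slope draws

# Transform every draw first: the x at which p = 0.5, i.e. logit(p) = 0
x_50 = (qlogis(0.5) - a) / b

# Only then summarise: 80% interval from the 10% and 90% quantiles of x
ci_draws = quantile(x_50, probs = c(0.1, 0.9))

# For contrast: plugging the parameter interval bounds into the same formula
a_bounds = quantile(a, probs = c(0.1, 0.9))
b_bounds = quantile(b, probs = c(0.1, 0.9))
ci_bounds = (qlogis(0.5) - a_bounds) / b_bounds

ci_draws
ci_bounds
```

The two results differ because a quantile of a function of the parameters is not, in general, that function applied to the parameters' quantiles.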

I’ll demo both methods now using posterior::rvar, which makes it easy to write out the algebraic operations on posterior draws the same way we would write them on normal variables in R.

First, a demo dataset with two groups, each having different intercepts and slopes:

library(ggplot2)

set.seed(1234)

n = 300

x = seq(-1, 1, length.out = n)
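# the length-2 vectors below are recycled along x, so observations
# alternate between groups "a" and "b"
group = c("a", "b")
intercept = c(-0.5, 0.5)
slope = c(3, 6)
log_odds = intercept + slope * x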
prob = plogis(log_odds)
y = rbinom(n, 1, prob)

df = data.frame(x, group, prob, y)

df |>
  ggplot(aes(x, y)) +
  geom_point() +
  geom_line(aes(y = prob)) +
  facet_grid(~ group)

Fit an rstanarm logistic regression (summarise_draws() is from posterior):

library(rstanarm)
library(posterior)

m = stan_glm(y ~ x*group, data = df, family = binomial)
summarise_draws(m)
## # A tibble: 4 × 10
##   variable      mean median    sd   mad     q5    q95  rhat ess_bulk ess_tail
##   <chr>        <dbl>  <dbl> <dbl> <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
## 1 (Intercept) -0.650 -0.647 0.238 0.243 -1.04  -0.268  1.00    3778.    3006.
## 2 x            3.59   3.56  0.543 0.545  2.75   4.53   1.00    2987.    2838.
## 3 groupb       1.11   1.11  0.379 0.385  0.508  1.74   1.00    3811.    2525.
## 4 x:groupb     2.47   2.42  1.11  1.09   0.760  4.35   1.00    2937.    2691.

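The estimates are approximately correct: the true values are -0.5 for the intercept and 3 for the slope in group "a", with group "b" shifting the intercept by 1 and the slope by 3.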

Method 1

Get the parameters for the slope and intercept for one group (for simplicity "a" since it’s the reference group) and use those to find the x value at p = 0.8.

First we extract model parameters as rvars:

d = as_draws_rvars(m)
d
## # A draws_rvars: 1000 iterations, 4 chains, and 4 variables
## $(Intercept): rvar<1000,4>[1] mean ± sd:
## [1] -0.65 ± 0.24 
## 
## $x: rvar<1000,4>[1] mean ± sd:
## [1] 3.6 ± 0.54 
## 
## $groupb: rvar<1000,4>[1] mean ± sd:
## [1] 1.1 ± 0.38 
## 
## $x:groupb: rvar<1000,4>[1] mean ± sd:
## [1] 2.5 ± 1.1 

Each rvar above is actually an array containing all draws for that variable from the model, and when we do arithmetic operations on the object those operations are applied to all draws, creating a new rvar. So we can pull out the slope and intercept for group "a" from d and calculate the corresponding x value:

intercept_a = d$`(Intercept)`
slope_a = d$x

x_80_a = (qlogis(0.8) - intercept_a) / slope_a
x_80_a
## rvar<1000,4>[1] mean ± sd:
## [1] 0.58 ± 0.095 

summarise_draws works on this as well:

summarise_draws(x_80_a)
## # A tibble: 1 × 10
##   variable  mean median     sd    mad    q5   q95  rhat ess_bulk ess_tail
##   <chr>    <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl> <dbl>    <dbl>    <dbl>
## 1 x_80_a   0.578  0.571 0.0949 0.0891 0.437 0.746  1.00    4530.    2636.

Personally I don’t like this method, since we’d have to do slightly different calculations for group "b", motivating method 2.
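
For reference, a sketch of what the group "b" version would look like. It assumes `d = as_draws_rvars(m)` from above and relies on "a" being the reference level in `y ~ x*group`, so the group "b" intercept and slope are sums of coefficients:

```r
# Continues from `d = as_draws_rvars(m)` above (not standalone).
# Group "b" adds the groupb offset to the intercept and the
# x:groupb interaction to the slope.
intercept_b = d$`(Intercept)` + d$groupb
slope_b = d$x + d$`x:groupb`

# Same rearrangement as for group "a"
x_80_b = (qlogis(0.8) - intercept_b) / slope_b
summarise_draws(x_80_b)
```

Every additional predictor or interaction adds more terms to these sums, which is what makes this approach tedious in larger models.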

Method 2

Use the values of logit(p) at x = 0 and x = 1 for all groups (obtained via posterior_linpred()) to calculate the slope and intercept, and then apply Method 1 on all groups simultaneously:

pred_df = data.frame(group = c("a", "b"))
pred_df$intercept = rvar(posterior_linpred(m, newdata = cbind(pred_df, x = 0)))
pred_df$slope = rvar(posterior_linpred(m, newdata = cbind(pred_df, x = 1))) - pred_df$intercept
pred_df$x_80 = with(pred_df, (qlogis(0.8) - intercept) / slope)
pred_df
##   group    intercept      slope         x_80
## 1     a -0.65 ± 0.24 3.6 ± 0.54 0.58 ± 0.095
## 2     b  0.46 ± 0.29 6.1 ± 0.96 0.16 ± 0.060

We can again do summarise_draws:

summarise_draws(pred_df$x_80)
## # A tibble: 2 × 10
##   variable         mean median     sd    mad     q5   q95  rhat ess_bulk ess_tail
##   <chr>           <dbl>  <dbl>  <dbl>  <dbl>  <dbl> <dbl> <dbl>    <dbl>    <dbl>
## 1 pred_df$x_80[1] 0.578  0.571 0.0949 0.0891 0.437  0.746  1.00    4502.    2562.
## 2 pred_df$x_80[2] 0.158  0.153 0.0599 0.0586 0.0650 0.263  1.00    3303.    2722.

Or do a quick visual check with a plot. The vertical lines won’t cross the plotted curve at exactly p = 0.8, because the curve shown is the true curve, not the estimated one; you could plot the latter using posterior_epred() plus ggdist::stat_lineribbon(); for examples of that see this vignette.

df |>
  ggplot() +
  geom_point(aes(x, y)) +
  ggdist::stat_slab(aes(xdist = x_80), y = 0, data = pred_df, scale = 0.5) +
  geom_vline(aes(xintercept = mean(x_80)), data = pred_df) +
  geom_line(aes(x, y = prob)) +
  facet_grid(~ group)
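
A rough sketch of the estimated curve, assuming the fitted model `m` from above and a ggdist version that supports the `ydist` aesthetic. The grid `curve_df` and its 101 points per group are arbitrary choices made for illustration:

```r
library(ggplot2)
library(posterior)

# Continues from the model `m` fitted above (not standalone).
# Prediction grid: 101 x values for each group
curve_df = expand.grid(x = seq(-1, 1, length.out = 101), group = c("a", "b"))

# posterior_epred() returns a draws-by-rows matrix of probabilities;
# rvar() turns it into one random variable per row of the grid
curve_df$p = rvar(rstanarm::posterior_epred(m, newdata = curve_df))

ggplot(curve_df, aes(x = x, ydist = p)) +
  ggdist::stat_lineribbon() +
  scale_fill_brewer() +
  facet_grid(~ group)
```

The ribbons show posterior uncertainty in the fitted probability at each x, which is the vertical counterpart of the horizontal uncertainty in x_80.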

Thanks @mjskay ! I was able to use your proposed methods to get the values I needed. Much obliged. I don’t know why I haven’t been using the posterior package. That certainly helps.
