Modelling accuracy of a test across its range

Imagine we have a new test with possible values ranging from 0 to 100. We apply this test to predict a binary outcome status (e.g. alive or dead). We want to estimate the new test's accuracy (e.g. sensitivity and specificity, with uncertainty) for predicting our outcome at any given test value.

Thanks!

``` r
#required libraries
library(tidyverse)
library(brms)
library(tidybayes)

#load some data
d <- ISLR::Default

#prepare data
#note rescaling balance to between 0:100, just to match my real data
d <- d %>%
  mutate(default = factor(case_when(
    default=="Yes" ~ 1,
    default=="No" ~ 0))) %>%
  mutate(balance_z = round((balance - min(d$balance)) / (max(d$balance) - min(d$balance)) * 100))

#plot raw data
d %>%
  ggplot() +
  geom_jitter(aes(x=default, y=balance_z, colour=default), alpha=0.3)


#regression model to predict default from balance_z
m1 <- brm(
  default ~ balance_z,
  data = d,
  family=bernoulli(),
  backend="cmdstanr"
)
#set up a dataset containing all values of balance_z we want
nd <- data.frame(balance_z = seq(0, 100, 1))
```

Created on 2023-08-01 with reprex v2.0.2

But here is where I am stuck, and grateful for any suggestions:

#Now we want to get posterior samples for each value of balance_z
#BUT HOW TO CALCULATE SENSITIVITY AND SPECIFICITY FOR EACH VALUE OF BALANCE_Z?
post <- add_linpred_draws(m1, newdata=nd) %>%
  mutate(
    sens = ?,
    spec = ?
  )

# then would want to summarise e.g. 
post_sum <- post %>%
  pivot_longer(cols=c(sens, spec)) %>%
  ungroup() %>%
  group_by(balance_z, name) %>%
  mean_qi(value)

#and plot e.g
post_sum %>%
  ggplot(aes(x=balance_z, y=value, ymin=.lower, ymax=.upper)) +
  geom_ribbon(aes(fill=name), alpha=0.3) +
  geom_line(aes(colour=name)) +
  facet_grid(name~.)

Additional information:

  • Operating System: macOS 13.4.1
  • brms Version: 2.19.0

I’m afraid I don’t know brms, but I work with these kinds of problems all the time.

You can calculate in-sample sensitivities and specificities using posterior predictive inference. Suppose z_i \in \{ 0, 1 \} is the true category for item i. Then if my system makes hard predictions, sensitivity is just the fraction of predictions that are right when z_i = 1 and specificity the same for z_i = 0. So calculate this sensitivity and specificity for each posterior draw, then average them.

If your system makes probabilistic predictions for z_i you can do the same thing in expectation. If you predict 0.8 for item i and z_i = 1, then you get a 0.8 contribution to the numerator of the total for sensitivity. If instead, z_i = 0, then you get a 0.2 contribution to the numerator for specificity.
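
For concreteness, here is a minimal sketch of that in-expectation calculation in R (not from the original reply; it assumes the m1 fit and data d from the first post):

``` r
# Sketch: per-draw expected sensitivity and specificity, assuming m1 and d
# from the first post. posterior_epred() returns a draws x observations
# matrix of predicted Pr(Y = 1); each row is one posterior draw.
p <- posterior_epred(m1, ndraws = 1000)
z <- d$default == "1"  # true class labels

# A prediction of 0.8 for a true 1 contributes 0.8 to the sensitivity
# numerator; a prediction of 0.8 for a true 0 contributes 0.2 (= 1 - 0.8)
# to the specificity numerator.
sens_draws <- rowSums(p[, z, drop = FALSE]) / sum(z)
spec_draws <- rowSums(1 - p[, !z, drop = FALSE]) / sum(!z)

# Average over draws; quantiles give uncertainty intervals
c(sensitivity = mean(sens_draws), specificity = mean(spec_draws))
quantile(sens_draws, c(0.025, 0.975))
quantile(spec_draws, c(0.025, 0.975))
```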

I’d be inclined to be a bit more non-parametric than a linear regression for this problem. With only 100 possible test values, you could be completely non-parametric with a monotonicity constraint (a higher test value makes a positive outcome more likely). That’s probably not easy to code in brms, though.

Have you done a plot of test score vs. actual category? Was that the salmon and teal plot (ggplot defaults to 1980s backpack colors!)? I couldn’t parse the axis labels or what the horizontal range represents (jitter, I’d guess). I think this’d be easier to parse as histograms or density plots—then you can see if there’s non-linearity at play.


Of interest, brms can specify this model with really concise syntax: `y ~ mo(x)`

Paper:
https://psyarxiv.com/9qkhj/

Vignette:
https://cran.r-project.org/web/packages/brms/vignettes/brms_monotonic.html


Thanks so much Bob and Jacob - very helpful.

Good point about the distribution. I have updated the plot to show densities, and with some non-default ggplot2 colours! I have also updated the model to specify monotonic effects.

library(tidyverse)
library(brms)
library(tidybayes)

#load some data
d <- ISLR::Default

#prepare data
#note rescaling balance to between 0:100, just to match my real data
d <- d %>%
  mutate(default = factor(case_when(
    default=="Yes" ~ 1,
    default=="No" ~ 0))) %>%
  mutate(balance_z = round((balance - min(d$balance)) / (max(d$balance) - min(d$balance)) * 100))

#density plot of distributions

d %>%
  ggplot(aes(x=balance_z, fill=default, colour=default)) +
  geom_density(alpha=0.7) +
  facet_grid(default~.) +
  scale_fill_manual(values = c("mediumseagreen", "purple")) +
  scale_colour_manual(values = c("mediumseagreen", "purple"))


#fit the model
m2 <- brm(
  default ~ mo(balance_z),
  data = d,
  family=bernoulli(),
  backend="cmdstanr",
  chains=4,
  cores=4
)
summary(m2)
#>  Family: bernoulli 
#>   Links: mu = logit 
#> Formula: default ~ mo(balance_z) 
#>    Data: d (Number of observations: 10000) 
#>   Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
#>          total post-warmup draws = 4000
#> 
#> Population-Level Effects: 
#>             Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> Intercept     -10.67      1.03   -12.91    -8.89 1.00     3439     2430
#> mobalance_z     0.14      0.01     0.12     0.17 1.00     3529     2559
#> 
#> Simplex Parameters: 
#>                   Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
#> mobalance_z1[1]       0.01      0.01     0.00     0.04 1.00     4685     2187
#> mobalance_z1[2]       0.01      0.01     0.00     0.04 1.00     4290     2124
#> mobalance_z1[3]       0.01      0.01     0.00     0.04 1.00     5518     2337
#> mobalance_z1[4]       0.01      0.01     0.00     0.04 1.00     4602     2163
#> mobalance_z1[5]       0.01      0.01     0.00     0.03 1.00     5005     2367
#> ...                  (rows [6] through [95] omitted; all estimates between 0.00 and 0.02, all Rhat = 1.00)
#> mobalance_z1[96]      0.01      0.01     0.00     0.04 1.00     4953     2087
#> mobalance_z1[97]      0.01      0.01     0.00     0.04 1.00     5246     2313
#> mobalance_z1[98]      0.01      0.01     0.00     0.04 1.00     5414     2866
#> mobalance_z1[99]      0.01      0.01     0.00     0.04 1.00     5882     3083
#> mobalance_z1[100]     0.01      0.01     0.00     0.04 1.00     5982     2579
#> 
#> Draws were sampled using sample(hmc). For each parameter, Bulk_ESS
#> and Tail_ESS are effective sample size measures, and Rhat is the potential
#> scale reduction factor on split chains (at convergence, Rhat = 1).
plot(m2)

pp_check(m2, type = "bars")
#> Using 10 posterior draws for ppc type 'bars' by default.

Created on 2023-08-02 with reprex v2.0.2

And this is what the Stan code for the model looks like, via `stancode(m2)`:

// generated with brms 2.19.0
functions {
  /* compute monotonic effects
   * Args:
   *   scale: a simplex parameter
   *   i: index to sum over the simplex
   * Returns:
   *   a scalar between 0 and rows(scale)
   */
  real mo(vector scale, int i) {
    if (i == 0) {
      return 0;
    } else {
      return rows(scale) * sum(scale[1 : i]);
    }
  }
}
data {
  int<lower=1> N; // total number of observations
  array[N] int Y; // response variable
  int<lower=1> Ksp; // number of special effects terms
  int<lower=1> Imo; // number of monotonic variables
  array[Imo] int<lower=1> Jmo; // length of simplexes
  array[N] int Xmo_1; // monotonic variable
  vector[Jmo[1]] con_simo_1; // prior concentration of monotonic simplex
  int prior_only; // should the likelihood be ignored?
}
transformed data {
  
}
parameters {
  real Intercept; // temporary intercept for centered predictors
  simplex[Jmo[1]] simo_1; // monotonic simplex
  vector[Ksp] bsp; // special effects coefficients
}
transformed parameters {
  real lprior = 0; // prior contributions to the log posterior
  lprior += student_t_lpdf(Intercept | 3, 0, 2.5);
  lprior += dirichlet_lpdf(simo_1 | con_simo_1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept;
    for (n in 1 : N) {
      // add more terms to the linear predictor
      mu[n] += bsp[1] * mo(simo_1, Xmo_1[n]);
    }
    target += bernoulli_logit_lpmf(Y | mu);
  }
  // priors including constants
  target += lprior;
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept;
}

This suggestion sounds exactly like what I am trying to do:

@Bob_Carpenter
If your system makes probabilistic predictions for z_i you can do the same thing in expectation. If you predict 0.8 for item i and z_i = 1, then you get a 0.8 contribution to the numerator of the total for sensitivity. If instead, z_i = 0, then you get a 0.2 contribution to the numerator for specificity.

But I am really not sure how to actually code this correctly. Any suggestions very gratefully received.

Thanks!

Just hoping someone has an example of how to code these calculations. Thanks!

Here is how I would code the calculations for the sensitivity and specificity of a probabilistic predictive model.

First the overall sensitivity and specificity:

# Calculate posterior expected values E(Y | balance_z, zi) for each observed data point.
fitted_data <- brms_summary(mod, data)  # brms_summary() is a small helper, defined further down this thread

# sensitivity = E{Y | z = 1}
# specificity = 1 - E{Y | z = 0}
fitted_data %>%
  summarise(
    sensitivity = sum(if_else(default == 1, meanY, 0)) / sum(default == 1),
    specificity = sum(if_else(default == 0, 1 - meanY, 0)) / sum(default == 0)
  )

And then sensitivity & specificity by balance_z:

# Calculate posterior expected values E(Y | balance_z, zi) for each possible combination of balance_z and z in {0, 1}.
newdata <- crossing(balance_z = seq(0, 100, 1), default = c(0, 1))
fitted_newdata <- brms_summary(mod, newdata) 

# sensitivity = E{Y | balance_z, z = 1} = E{Y | balance_z}
# specificity = 1 - E{Y | balance_z, z = 0} = 1 - E{Y | balance_z}
fitted_newdata %>%
  group_by(
    balance_z
  ) %>%
  summarise(
    sensitivity = sum(if_else(default == 1, meanY, 0)) / sum(default == 1),
    specificity = sum(if_else(default == 0, 1 - meanY, 0)) / sum(default == 0)
  )

When you plot the sensitivity and specificity curves as a function of balance_z, you’ll notice that the “balance-specific sensitivity” is the regression curve E(Y | balance_z) and the “balance-specific specificity” is its mirror image 1 - E(Y | balance_z). This happens because – given a value balance_z – the model always makes the same probabilistic predictions for the events Y = 1 and Y = 0.


Thanks so much!

Just checking a few things here.

  1. I couldn’t find a brms_summary() function, so I have rewritten the code using tidybayes::add_fitted_draws(). Hoping this doesn’t introduce any new errors! (add_epred_draws() should give the same answer, as per the warning in the code output below.)

  2. I am not 100% convinced that this approach is correct. I would expect that, at a low value of balance_z, sensitivity would be high and specificity correspondingly low.

However, summarising the model with your approach seems to give the opposite.

# Calculate posterior expected values E(Y | balance_z, zi) for each observed data point.
fitted_data <- d %>%
  add_fitted_draws(m1)
#> Warning: `fitted_draws` and `add_fitted_draws` are deprecated as their names were confusing.
#> - Use [add_]epred_draws() to get the expectation of the posterior predictive.
#> - Use [add_]linpred_draws() to get the distribution of the linear predictor.
#> - For example, you used [add_]fitted_draws(..., scale = "response"), which
#>   means you most likely want [add_]epred_draws(...).
#> NOTE: When updating to the new functions, note that the `model` parameter is now
#>   named `object` and the `n` parameter is now named `ndraws`.

# sensitivity = E{Y | z = 1}
# specificity = 1 - E{Y | z = 0}
fitted_data %>%
  ungroup() %>%
  summarise(
    sensitivity = sum(if_else(default == 1, .value, 0)) / sum(default == 1),
    specificity = sum(if_else(default == 0, 1 - .value, 0)) / sum(default == 0)
  )
#> # A tibble: 1 × 2
#>   sensitivity specificity
#>         <dbl>       <dbl>
#> 1       0.348       0.978


# Calculate posterior expected values E(Y | balance_z, zi) for each possible combination of balance_z and z in {0, 1}.
newdata <- crossing(balance_z = seq(0, 100, 1), default = c(0, 1))
fitted_newdata <- add_fitted_draws(m1, newdata=newdata)
#> Warning: `fitted_draws` and `add_fitted_draws` are deprecated as their names were confusing.
#> - Use [add_]epred_draws() to get the expectation of the posterior predictive.
#> - Use [add_]linpred_draws() to get the distribution of the linear predictor.
#> - For example, you used [add_]fitted_draws(..., scale = "response"), which
#>   means you most likely want [add_]epred_draws(...).
#> NOTE: When updating to the new functions, note that the `model` parameter is now
#>   named `object` and the `n` parameter is now named `ndraws`.

# sensitivity = E{Y | balance_z, z = 1} = E{Y | balance_z}
# specificity = 1 - E{Y | balance_z, z = 0} = 1 - E{Y | balance_z}
results <- fitted_newdata %>%
  group_by(
    balance_z
  ) %>%
  summarise(
    sensitivity = sum(if_else(default == 1, .value, 0)) / sum(default == 1),
    specificity = sum(if_else(default == 0, 1 - .value, 0)) / sum(default == 0)
  )

#plot the results
results %>%
  pivot_longer(cols = c(sensitivity, specificity)) %>%
  ggplot() +
  geom_line(aes(x=balance_z, y=value, colour=name)) +
  facet_grid(name~.) +
  scale_colour_manual(values = c("mediumseagreen", "purple"), name="")

Created on 2023-08-29 with reprex v2.0.2

  3. I am also keen to estimate uncertainty around sensitivity and specificity, and any suggestions for this are also gratefully received.

I tend to mix up sensitivity and specificity, so I’ll start with the definition: sensitivity = TP / P (the proportion of actual positives that are correctly identified), and the probabilistic version at a given balance_z is E{Y=1 | balance_z}.

The example data has no 1s at low balance_z (your data might be different though), so let’s say E{Y=1 | balance_z is low} = eps, since the model won’t predict exactly 0. So the sensitivity at low balance_z is eps? In other words, this regression won’t predict a 1 at low balance_z. If the true Y happens to be 1 (and balance_z is low), then the model will make a mistake most of the time. Does this sound right?
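
A quick numerical check of this point (a sketch, assuming the m2 fit from above):

``` r
# Predicted Pr(Y = 1) at low balance_z values: these should all be close
# to 0 (the "eps" above), i.e. the model essentially never predicts a 1
# in this region.
nd_low <- data.frame(balance_z = 0:10)
round(colMeans(posterior_epred(m2, newdata = nd_low)), 4)
```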

brms_summary() is a helper function to calculate the posterior means from posterior_epred() and add them to the input data frame. I gave it an official-looking name, unfortunately.

brms_summary <- function(model, data) {
  summary <- model %>%
    posterior_epred(
      newdata = data
    ) %>%
    summarise_draws("mean") %>%
    rename(meanY = mean)
  bind_cols(data, summary)
}

Thanks so much for clarifying, and the brms_summary() function works.

Still not 100% clear on the output though. I guess, to explain my logic in a little more detail:

We are using our index test (balance_z, ranging from 0:100) as a decision tool to detect the presence of default.

Imagine we set balance_z to have a threshold of >=5, with all people with a balance_z score of >=5 then predicted to default. In this case, the sensitivity is high (very few true positive default cases are missed), but specificity is low (we also flag lots of people as balance_z positive who don’t default).

Alternatively, if we set our threshold for balance_z to >=95, then our specificity will be high, with most true negatives correctly classified.
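
As a quick sanity check of this logic, here is a sketch of the hard-classification version, using d from above (illustrative only):

``` r
# Sketch: hard-threshold sensitivity/specificity straight from the data,
# no model involved, to check the threshold intuition.
hard_acc <- function(t) {
  pred  <- d$balance_z >= t   # predicted positive at threshold t
  truth <- d$default == "1"   # observed default
  c(threshold   = t,
    sensitivity = mean(pred[truth]),    # TP / (TP + FN)
    specificity = mean(!pred[!truth]))  # TN / (TN + FP)
}
rbind(hard_acc(5), hard_acc(95))
```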

However, in the output from the brms_summary() function that I graphed above, at lower values of balance_z, sensitivity is low, and at higher values of balance_z, specificity is low.

Hope this makes sense. Apologies if I am completely confused, but I have been struggling with how to do this correctly for a while, and it is really great to have your insights.


I understood your question to be about a calculation for each value of balance_z. For example:

(sensitivity | balance_z =  5) = ? # assume balance_z is 5, so very low
(sensitivity | balance_z = 95) = ? # assume balance_z is 95, so very high

This is different from:

(sensitivity | balance_z >=  5) = ? # assume balance_z can be anywhere from 5 to 100 (so from very low to very high & anything in between)
(sensitivity | balance_z >= 95) = ? # assume balance_z is between 95 and 100, so very high

Yes - apologies for my lack of clarity; the second is what I am trying to do.

Thanks

@desislava Based on your very helpful suggestions and code above (thanks!), I have edited my code to calculate sensitivity and specificity based on balance_z being equal to or above the selected threshold value.

I am hoping this is now correct!

The only thing I would now be grateful for assistance with is how to additionally calculate uncertainty bands around the mean estimates of sensitivity and specificity.

#brms_summary function (from above)
brms_summary <- function(model, data) {
  summary <- model %>%
    posterior_epred(
      newdata = data
    ) %>%
    summarise_draws("mean") %>%
    rename(meanY = mean)
  bind_cols(data, summary)
}

#Set an example threshold for balance_z
balance_z_threshold <- 60



#test it out for one threshold value of `balance_z`
brms_summary(m1, d) %>%
  summarise(
    sensitivity = sum(if_else((balance_z >= balance_z_threshold & default==1), meanY, 0)) / (sum(default==1)),
    specificity = sum(if_else((balance_z <  balance_z_threshold & default==0), 1-meanY, 0)) / (sum(default==0))
  ) %>%
    mutate(balance_z_threshold=balance_z_threshold)
#>   sensitivity specificity balance_z_threshold
#> 1   0.3286925   0.9421319                  60


#Now wrap this up into a function
accuracy_func <- function(data, model, threshold){
  brms_summary(model, data) %>%
  summarise(
    sensitivity = sum(if_else((balance_z >= {{threshold}} & default==1), meanY, 0)) / (sum(default==1)),
    specificity = sum(if_else((balance_z <  {{threshold}} & default==0), 1-meanY, 0)) / (sum(default==0))
  ) %>%
    mutate(balance_z_threshold={{threshold}})
}

#test it out for one value of balance_z_threshold
accuracy_func(data=d, model=m1, threshold=60)
#>   sensitivity specificity balance_z_threshold
#> 1   0.3286925   0.9421319                  60

#make a vector of thresholds
threshold_vec <- seq(from=0, to=100, by=1)

#run the function for all values of `balance_z_threshold`
all_results <- map_df(threshold_vec, ~accuracy_func(data = d, model = m1, threshold = .x))

#now plot
all_results %>%
  pivot_longer(cols = c(sensitivity, specificity)) %>%
  ggplot() +
  geom_line(aes(x=balance_z_threshold, y=value, colour = name)) +
  facet_grid(name~.) +
  scale_colour_manual(values = c("mediumseagreen", "purple"), name="")

Created on 2023-08-29 with reprex v2.0.2


With balance_z >= {{threshold}} you condition on balance_z >= threshold in the numerator of the sensitivity/specificity formulas. You need to apply the condition to the denominator as well. Something like:

sensitivity = sum(if_else((balance_z >= {{threshold}} & default==1), meanY, 0)) / (sum(balance_z >= {{threshold}} & default==1))
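
and, following the same pattern, presumably the specificity becomes (my completion of the analogy, not verbatim from the reply):

specificity = sum(if_else((balance_z < {{threshold}} & default==0), 1-meanY, 0)) / (sum(balance_z < {{threshold}} & default==0))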

It’s also important to consider that you use the data to estimate the distribution of balance_z values when Y = 1 and when Y = 0. (That’s implied in the calculation above.) You’ll get different sensitivity & specificity curves with the same predictive model E(Y | balance_z) but a different sample of 0s and 1s across the range of balance_z.
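
For instance (an illustrative sketch, not from the thread; d_enriched is an invented name): refit the accuracy curves on a case-enriched subsample of d and compare; the model m1 is unchanged, but the curves shift with the case mix.

``` r
# Sketch: same model, different case mix. Oversample defaults relative to
# non-defaults, then recompute the accuracy curves with the (corrected)
# accuracy_func from above.
set.seed(1)
d_enriched <- bind_rows(
  filter(d, default == "1"),
  slice_sample(filter(d, default == "0"), n = 1000)
)
enriched_results <- map_df(threshold_vec,
                           ~accuracy_func(data = d_enriched, model = m1, threshold = .x))
```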


Thanks so much! That is working brilliantly now - much appreciated!

For your second point, if I understand you correctly, we can’t construct a newdata dataframe to predict over, as the distribution of 0s and 1s will not reflect the underlying structure of the dataset. So we couldn’t use this to predict over:

newdata <- crossing(balance_z = seq(0, 100, 1), default = c(0, 1))

I wonder then if it is possible to construct a dataset to predict over with weights to indicate the distribution of 0s and 1s, and use the weights argument of summarise_draws() to correctly calculate summary estimates of sensitivity and specificity, something like this:

#calculate weights for prediction dataframe 
d_props <- d %>% 
  group_by(balance_z, default) %>% 
  count() %>%
  ungroup() %>%
  complete(balance_z, default, fill=list(n=0)) %>%
  mutate(prop = n/sum(n)) 

#function to summarise weighted draws
brms_summary_weighted <- function(model, data, weights) {
  summary <- model %>%
    posterior_epred(
      newdata = data,
      .args = list(w = {{weights}})
    ) %>%
    summarise_draws("mean", ~quantile(.x, probs = c(.025, .975))) %>%
    rename(meanY = mean,
           mean2_5 = `2.5%`,
           mean97_5 = `97.5%`)
  bind_cols(data, summary)
}

#run the function
brms_summary_weighted(model = m1, data = d_props, weights = n)

Note: Really not sure if this is sensible or not! For the weights argument, I specified n as the weighting variable, but not sure if it should be prop (if indeed this approach makes sense).

Thanks again for all the insights - I am really learning a lot!


Exactly. If you have a different sample new_d (perhaps representing a different population), you can apply the model m1 to new_d and you’ll get somewhat different sensitivity and specificity curves. But d0 = crossing(balance_z, default) represents no meaningful population, so there is no point in the calculation accuracy_func(m1, d0). Perhaps not such an interesting point to make.

And no need to weight the posterior draws: d_props is the data you used to fit the logistic regression. (Actually it’s less data, because you lose information when you normalize counts into frequencies: it’s the difference between 1/10 and 100/1000 when you are trying to estimate proportions.)
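
A quick illustration of that information loss (a sketch, with a flat Beta(1, 1) prior):

``` r
# Posterior 95% intervals for a proportion: 1 success in 10 trials vs
# 100 successes in 1000 trials. Same point estimate, very different width.
qbeta(c(0.025, 0.975), 1 + 1,   1 + 9)    # wide
qbeta(c(0.025, 0.975), 1 + 100, 1 + 900)  # narrow
```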


Yes - that makes sense. And thank you so much for all the assistance here - really learned a lot!


Just to close this off, I have put all of this example code together here, in case it is useful for others. (Hopefully I haven’t introduced any errors!)

Thanks again for all the assistance.

#required libraries
library(tidyverse)
library(brms)
library(tidybayes)
library(marginaleffects)

#load some data
d <- ISLR::Default

#prepare data
#note rescaling balance to between 0:100, just to match my real data
d <- d %>%
  mutate(default = factor(case_when(
    default=="Yes" ~ 1,
    default=="No" ~ 0))) %>%
  mutate(balance_z = round((balance - min(d$balance)) / (max(d$balance) - min(d$balance)) * 100)) %>%
  mutate(pid = row_number())

#what does the distribution of threshold scores look like?
d %>%
  ggplot() +
  geom_density(aes(x=balance_z, colour=default, fill=default), alpha=0.5) +
  scale_fill_manual(values = c("mediumseagreen", "purple")) +
  scale_colour_manual(values = c("mediumseagreen", "purple"))


#regression model
m1 <- brm(
  default ~ mo(balance_z),
  data = d,
  family=bernoulli(),
  backend="cmdstanr",
  cores=4,
  chains=4
)

#function to summarise draws
brms_summary <- function(model, data) {
  summary <- model %>%
    posterior_epred(
      newdata = data
    ) %>%
    summarise_draws("mean", ~quantile(.x, probs = c(.025, .975))) %>%
    rename(meanY = mean,
           mean2_5 = `2.5%`,
           mean97_5 = `97.5%`)
  bind_cols(data, summary)
}

#sensitivity and specificity function
#calculates sensitivity and specificity against a reference standard, for predictor values at or above a particular threshold
#TP=True Positive
#FP=False Positive
#TN=True Negative
#FN=False Negative
#Sensitivity= TP/(TP+FN)
#Specificity= TN/(TN+FP)

sens_spec_func <- function(data, outcome, predictor, threshold){
  data %>%
    summarise(
      #sensitivity calculations: TPs/(TPs+FNs)
      sensitivity_mean =
        sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), meanY, 0)) /
        (sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), meanY, 0)) +
         sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 1), meanY, 0))),

      sensitivity_lb =
        sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), mean2_5, 0)) /
        (sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), mean2_5, 0)) +
         sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 1), mean2_5, 0))),

      sensitivity_ub =
        sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), mean97_5, 0)) /
        (sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 1), mean97_5, 0)) +
         sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 1), mean97_5, 0))),

      #specificity calculations: TNs/(TNs+FPs)
      specificity_mean =
        sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), meanY, 0)) /
        (sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), meanY, 0)) +
         sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 0), meanY, 0))),

      specificity_lb =
        sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), mean2_5, 0)) /
        (sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), mean2_5, 0)) +
         sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 0), mean2_5, 0))),

      specificity_ub =
        sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), mean97_5, 0)) /
        (sum(if_else(({{predictor}} <  {{threshold}} & {{outcome}} == 0), mean97_5, 0)) +
         sum(if_else(({{predictor}} >= {{threshold}} & {{outcome}} == 0), mean97_5, 0)))
    )
}



#test it out for one threshold value of `balance_z`

#Set an example threshold for balance_z
balance_z_threshold <- 50

#run the function
brms_summary(m1, d) %>%
    sens_spec_func(data=., outcome=default, predictor=balance_z, threshold=balance_z_threshold) %>%
    mutate(balance_z_threshold=balance_z_threshold)
#>   sensitivity_mean sensitivity_lb sensitivity_ub specificity_mean
#> 1        0.9959157      0.9969205      0.9948712        0.1649944
#>   specificity_lb specificity_ub balance_z_threshold
#> 1      0.1294146      0.1963148                  50


#Now wrap all of this up into a single function
accuracy_func <- function(data, model, predictor, outcome, threshold){
  brms_summary(model, data) %>%
    sens_spec_func(data=., outcome={{outcome}}, predictor={{predictor}}, threshold={{threshold}}) %>%
    mutate(balance_z_threshold={{threshold}})
}

#test it out for one value of balance_z_threshold
accuracy_func(data=d, model=m1, predictor=balance_z, outcome=default, threshold=50)
#>   sensitivity_mean sensitivity_lb sensitivity_ub specificity_mean
#> 1        0.9959157      0.9969205      0.9948712        0.1649944
#>   specificity_lb specificity_ub balance_z_threshold
#> 1      0.1294146      0.1963148                  50


#make a vector of thresholds
threshold_vec <- seq(from=0, to=100, by=1)

#run the function for all values of `balance_z_threshold`
all_results <- map_df(threshold_vec, ~accuracy_func(data = d, model = m1, predictor=balance_z, outcome=default, threshold = .x))

#now plot
all_results %>%
  ggplot() +
  geom_ribbon(aes(x=balance_z_threshold, ymin=sensitivity_lb, ymax=sensitivity_ub, fill="Sensitivity"), alpha=0.5) +
  geom_ribbon(aes(x=balance_z_threshold, ymin=specificity_lb, ymax=specificity_ub, fill="Specificity"), alpha=0.5) +
  geom_line(aes(x=balance_z_threshold, y=sensitivity_mean, colour="Sensitivity")) +
  geom_line(aes(x=balance_z_threshold, y=specificity_mean, colour="Specificity")) +
  scale_colour_manual(values = c("mediumseagreen", "purple"), name="") +
  scale_fill_manual(values = c("mediumseagreen", "purple"), name="") +
  scale_y_continuous(labels=scales::percent, name="")

Created on 2023-09-05 with reprex v2.0.2
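
As a possible refinement of the uncertainty bands (a sketch, not part of the thread): rather than forming ratios of the pointwise 2.5%/97.5% bounds, one can compute sensitivity and specificity within each posterior draw and then take quantiles across draws, so the band reflects the joint posterior directly. Here true positives are weighted by p and true negatives by 1 - p, as in the earlier posts:

``` r
# Sketch: draw-wise sensitivity/specificity at each threshold, assuming
# the m1 fit and data d from the code block above.
p  <- posterior_epred(m1, newdata = d, ndraws = 1000)  # draws x N matrix of Pr(Y = 1)
y1 <- d$default == "1"

draw_acc <- map_df(0:100, function(t) {
  pos <- d$balance_z >= t
  tibble(
    threshold = t,
    sens = rowSums(p[, pos & y1, drop = FALSE]) /
           rowSums(p[, y1, drop = FALSE]),
    spec = rowSums((1 - p)[, !pos & !y1, drop = FALSE]) /
           rowSums((1 - p)[, !y1, drop = FALSE])
  )
})

draw_acc %>%
  group_by(threshold) %>%
  summarise(across(c(sens, spec),
                   list(mean = mean,
                        lb   = ~ quantile(.x, 0.025),
                        ub   = ~ quantile(.x, 0.975))))
```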
