PSIS-LOO classification accuracy for categorical outcomes

This example shows how to compute approximate LOO classification accuracy for binary outcomes. However, I don't seem to be savvy enough to replicate that when the response has more than two categories. (I know it's not the best metric, but my audience will probably want to see it.) Here's how far I get, using an über-simple example dataset:

Gators <- read.table("http://www.stat.ufl.edu/~aa/cat/data/Alligators.dat", header = TRUE)
require(brms)
require(loo)  # for E_loo() below
mymod.bayes <- brm(y ~ x, data = Gators, family = categorical, prior = prior(normal(0,4), class = b))
myloo <- loo(mymod.bayes, save_psis = TRUE)
preds <- posterior_epred(mymod.bayes)
pred <- colMeans(preds)
loopreds <- E_loo(preds, myloo$psis_object, type = "mean", log_ratios = -log_lik(mymod.bayes))$value

Error in E_loo.default(preds, myloo$psis_object, type = "mean", log_ratios = -log_lik(mymod.bayes)) :
  length(x) == dim(psis_object)[1] is not TRUE

Can this actually be done, or are the functions not yet applicable to multicategory responses?


In your case, posterior_epred() returns a 3-dimensional array. Therefore, you need to do something like this:

loopreds <- apply(preds, 3, function(x){
  E_loo(x, myloo$psis_object, type = "mean", log_ratios = -log_lik(mymod.bayes))$value
})
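
If you want to double-check that the dimensions line up, a quick optional sanity check could look like this (assuming the objects created above):

# preds: posterior draws x observations x response categories
# loopreds: observations x response categories
stopifnot(nrow(loopreds) == nrow(mymod.bayes$data),
          ncol(loopreds) == dim(preds)[3])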

However, note that the LOO probabilities from loopreds might not sum to 1 (within each row of loopreds). Therefore, at least the following check should be made if you want to compute the LOO-accuracy analogously to the binary example you mentioned:

stopifnot(isTRUE(all.equal(
  rowSums(loopreds),
  rep(1, nrow(loopreds)),
  tolerance = 1e-14
)))

Then you can proceed like this:

# which.max() gives category indices (1, 2, ...); if the observed response is stored
# as category labels instead, map the indices to those labels (via the column names of `loopreds`):
y_loopredict <- colnames(loopreds)[apply(loopreds, 1, which.max)]
table(y_loopredict, y = mymod.bayes$data$y, useNA = "ifany")
( looacc <- mean(y_loopredict == mymod.bayes$data$y) )

Alternatively, you may compute an “individual LOO-accuracy”:

y_predict <- apply(preds, c(1, 2), which.max)  # predicted category index per draw and observation
# Compare against the observed response coded as category indices (assuming the third
# dimension of `preds` is named with the response categories):
y_obs_idx <- match(mymod.bayes$data$y, dimnames(preds)[[3]])
iacc_raw <- sweep(y_predict, 2, y_obs_idx, "==")
mode(iacc_raw) <- "numeric"
looiacc <- E_loo(iacc_raw, myloo$psis_object, type = "mean", log_ratios = -log_lik(mymod.bayes))$value
# Now plot `looiacc` or do other analyses, e.g.:
quantile(looiacc)

Besides, I can only recommend checking the Pareto k-values returned by E_loo() (i.e., not only using element value, but also element pareto_k of E_loo()'s output object).
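
For example, something along these lines should work (again assuming the objects created above):

# Pareto k-values from E_loo(), one column per response category:
loo_pareto_ks <- apply(preds, 3, function(x) {
  E_loo(x, myloo$psis_object, type = "mean",
        log_ratios = -log_lik(mymod.bayes))$pareto_k
})
# As a rough rule of thumb, values above 0.7 indicate unreliable PSIS estimates:
summary(as.vector(loo_pareto_ks))
sum(loo_pareto_ks > 0.7)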

Also note that your object pred is not necessary for the LOO computations.


Many thanks for all that slick code! This may turn out to be even more helpful than I dared hope. Here are a couple of newbie follow-up questions, if I may:

  1. Are the contents of your loopreds object the PSIS-LOO fitted values of the observations? In other words, are the three category probabilities in each row now based on the estimated “loo posterior”, which corresponds approximately to the posterior distribution that would be obtained by refitting the model with that observation left out?

If so, then that object will enable computing not only an approximate LOO classification accuracy but also LOO estimates of numerous more sophisticated model-performance metrics, which would be hugely valuable.

  2. Could you (or anyone) please explain, or provide a link to an explanation of, what “individual LOO-accuracy” is? The last code block flies a bit over my head, but I have a suspicion that it’s something which may prove useful.

To 1.:

Yes, each row of loopreds gives the posterior predictive distribution for that observation when leaving it out of the model fitting procedure (approximately, due to PSIS). The possible lack of the sum-to-one property has to be kept in mind, though (see also point 2 below).

I’m happy this is valuable for you in other ways, too.

To 2.:

For other data and a different model, I found that the rows of loopreds occasionally did not sum to 1. That made me question the reliability of loopreds for the purpose of using it as the (PSIS-)LOO posterior predictive distribution (in this case). And that’s why I calculated what I termed “individual LOO-accuracy”. Perhaps “pointwise LOO-accuracy” would be a better name. In that approach, you first determine the maximum-probability outcome category for each observation (individual) and each posterior draw. Then, you check whether that maximum-probability category was the observed one and assign 1 for “yes” and 0 for “no”. This is what iacc_raw contains. And then, for each individual, you average those zeros and ones over the posterior draws, taking the PSIS-LOO weights into account. So you get something like $\text{E}_{\theta \mid y_{(-i)}}\!\left(\mathbf{1}(\hat{y}_i = y_i) \mid \theta\right) = \text{P}(\hat{y}_i = y_i \mid y_{(-i)})$, if you know what I mean (if not, tell me so).
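
To make those steps more concrete, here is a tiny toy illustration of the same idea with made-up numbers (one observation, four posterior draws; in the real computation, E_loo() additionally applies the PSIS-LOO weights when averaging over the draws):

# Made-up category probabilities: 4 posterior draws x 3 categories for one observation
p_draws <- matrix(c(0.5, 0.3, 0.2,
                    0.4, 0.4, 0.2,
                    0.2, 0.6, 0.2,
                    0.3, 0.3, 0.4),
                  nrow = 4, byrow = TRUE)
y_obs_idx <- 1                                     # observed category (as an index)
hit <- apply(p_draws, 1, which.max) == y_obs_idx   # was the max-probability category the observed one?
mean(hit)  # unweighted average over draws; E_loo() would use the PSIS-LOO weights instead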


Thanks. This is great! I have just one question about the sum-to-one property. It’s been my understanding that categorical/multinomial models impose the sum-to-one constraint on their fitted values by default (using something called the softmax function), so that one need not worry about it.

I just now managed to compute my first serious LOO metric for a real model using your apply() trick. When inspecting the LOO fitted values, I find that rowSums() returns all 1s, as it should. However, calling rowSums(myLOOfit) == 1 returns several FALSEs, even though the row sums in question are reported as 1. They are reported as 1 even when I do round(rowSums(myLOOfit), digits = 10). I am thus inclined to suspect that it’s just a rounding issue, hence not a cause for concern.

Yes, that should only be a minor numerical inaccuracy. What I was talking about were row sums which were really far away from 1.

For checking equality with numerical inaccuracies taken into account, I would recommend the check proposed above:

stopifnot(isTRUE(all.equal(
  rowSums(loopreds),
  rep(1, nrow(loopreds)),
  tolerance = 1e-14
)))

where you can probably increase tolerance slightly if it fails with 1e-14 (the default tolerance is ca. 1.49e-8 on my machine).
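
Just to illustrate why a comparison with == can fail even though the printed value looks exact (a generic floating-point example, unrelated to your model):

x <- 0.1 + 0.2
print(x)                   # prints 0.3
x == 0.3                   # FALSE: exact comparison of floating-point numbers
isTRUE(all.equal(x, 0.3))  # TRUE: comparison within a tolerance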


Got it. Thanks a lot for all your help!

Note that by default, cmdstanr writes its CSV output with only 6 significant digits. See the related issue “generate_quantities can fail with simplex” (stan-dev/cmdstanr#420) on GitHub.


Thanks for the information. But in that case, I had encountered sums which were much closer to zero than to one. Unfortunately, I can’t reproduce that issue anymore, so back then it might also have been a coding error on my part. In any case, I thought adding the check couldn’t hurt.