Using Loo 2 for grouped k-fold cv

loo

#1

If I understand correctly, one should be able to use the kfold-helpers in the Loo 2 package to do a grouped k-fold cv. However, I do not understand exactly how to do it.

As a toy example, say that I have an experiment with some participants, and I am interested in whether answering correctly influences their response time.

I define the two models:

# Define null model: random intercept per participant only
library(rstanarm)
rtm0 <- stan_glmer(rt ~ 1 + (1 | id),
                   data = data, family = gaussian)

# Define model with one predictor (accuracy)
rtm1 <- update(rtm0, . ~ . + acc)

The classic way of comparing them would then be

### Compare with classic loo
loo_rtm0 <- loo(rtm0)
loo_rtm1 <- loo(rtm1)
compare(loo_rtm0,loo_rtm1)

But if I understand correctly, this tells us how well the models predict one additional trial from the participants already in the data. The more interesting question is how well they predict a new participant.

To test this, I hope to use the kfold_split_stratified function something like this:

### Compare with grouped k-fold cv
kfold_split_stratified(K = nrparticipants, x = data$id)

But I am not sure what to do afterwards. Going through the vignettes and some googling has not helped so far.

Any help would be appreciated. The specific code can be found here and the dataset, E1_gbr.csv can be found here.


#2

See http://mc-stan.org/rstanarm/reference/loo.stanreg.html and try

folds <- kfold_split_stratified(K = nrparticipants, x = data$id)
kfold(rtm0, folds = folds)
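For orientation, `folds` is an integer vector with one fold index per observation. Below is a minimal base R sketch of a leave-one-participant-out assignment, assuming a made-up data frame (`toy` and its values are hypothetical). If I read the loo documentation correctly, `kfold_split_grouped()` keeps all observations of a group in the same fold, whereas `kfold_split_stratified()` balances the groups across folds, so it is worth checking which helper matches the "new participant" question.

```r
# Hypothetical toy data: 3 participants with 4 trials each
toy <- data.frame(
  id = rep(c("p1", "p2", "p3"), each = 4),
  rt = c(0.51, 0.48, 0.55, 0.60, 0.72, 0.70, 0.69, 0.75, 0.40, 0.42, 0.39, 0.45)
)

# One fold per participant: every row of an id gets the same fold index
folds_grouped <- as.integer(factor(toy$id))

# The fold vector must have one entry per observation
length(folds_grouped) == nrow(toy)

# Cross-tabulate to confirm each fold holds out exactly one participant
table(toy$id, folds_grouped)
```

With this kind of assignment, each refit leaves out every trial of one participant, which is what predicting for an unseen participant requires.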

Unfortunately, we haven’t had time to write case studies or vignettes for kfold. Note that in your case, even if kfold works, the compare function will compute the sd of the difference the same way it does for loo, whereas it would be better to compute an sd that also respects the grouping structure.


#3

Thanks for taking time to help, Aki. As always, I very much appreciate this.

Looking at your link and comment, I can define and run kfold in a non-grouped manner like this (5 fold for computational reasons):

kfng_rtm0 <- kfold(rtm0, K = 5)
kfng_rtm1 <- kfold(rtm1, K = 5)
compare(kfng_rtm0,kfng_rtm1)

And get a result that seems close to using the compare function on loo.

However when I try to run it in a grouped manner like so:

folds <- kfold_split_stratified(K = nrparticipants, x = data$id)
kf_rtm0 <- kfold(rtm0, folds = folds)

I am getting the following error:

Error: length(folds) == N is not TRUE

My first intuition was to add something like N = nrparticipants, but that returns
unused argument (n = nrparticipants)

Is there an obvious way to get the code above to work?


#4

There seems to be a bug (can you make an issue and @jonah will fix it?)
This seems to work

kf_rtm0 <- kfold(rtm0, K = nrparticipants, folds = folds)

kf_rtm0$pointwise has the individual elpds. Sum the elpds belonging to each id, so that you end up with as many summed elpds as there are unique ids. Then, when you make the pairwise comparison of models, use the diff and the sd of the diff of these summed elpds.
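A minimal sketch of that summation in base R, using made-up numbers. The vectors `elpd_m0` and `elpd_m1` are hypothetical stand-ins for the pointwise elpds of the two models; in practice they would be taken from the `pointwise` element of each kfold result, and which column to use should be checked for the installed version.

```r
# Hypothetical pointwise elpd values for two models: 3 ids, 2 trials each
id      <- rep(c("a", "b", "c"), each = 2)
elpd_m0 <- c(-1.2, -0.8, -2.0, -1.5, -0.9, -1.1)
elpd_m1 <- c(-1.0, -0.9, -1.7, -1.6, -1.0, -0.8)

# Sum the pointwise elpds within each id
sum_m0 <- tapply(elpd_m0, id, sum)
sum_m1 <- tapply(elpd_m1, id, sum)

# Per-id differences: one value per unique id
diff_by_id <- sum_m1 - sum_m0

# Total difference and the sd of the per-id differences
elpd_diff <- sum(diff_by_id)
sd_diff   <- sd(diff_by_id)
```

The total difference is unchanged by the grouping; only the spread used for the uncertainty is now computed over ids rather than over single observations.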


#5

Tried to use the code:

kf_rtm0 <- kfold(rtm0, K = nrparticipants, folds = folds)

But it returns the same error:

Error: length(folds) == N is not TRUE

I will happily report it on GitHub, provided that it isn’t just me doing something wrong?


#6

I’m quite certain that the problem is in rstanarm code, as we haven’t had time to test it properly and some fixes are only in develop branch in github.

Are you using rstanarm from CRAN? It’s likely that CRAN rstanarm version is different from the develop version in github. If you are brave enough, you can try

devtools::install_github("stan-dev/rstanarm", build_vignettes = FALSE)

#7

@simon.dp Thanks for sharing the code and the data. I’ll try to reproduce this and get back to you.


#8

Ok, this is actually related to the issue in the post

although it might not seem obviously related. This should already be fixed on GitHub but until the next CRAN release you can get around the error by dropping the variables from data that you’re not using in the model (or more precisely, any variable not used in the model that also has NAs).

In your case, try this:

data <- data[, c("rt", "id", "acc")]
rtm0 <- stan_glmer(..., data = data)

nrparticipants <- length(unique(data$id))
folds <- kfold_split_stratified(K = nrparticipants, x = data$id)
kf_rtm0 <- kfold(rtm0, K = nrparticipants, folds = folds)

When the next CRAN release comes out the data <- data[, c("rt", "id", "acc")] line shouldn’t be necessary.
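To see why unused columns containing NAs can matter, note that the number of complete rows depends on which columns are considered. The sketch below only illustrates that mismatch on made-up data (`toy` and the `note` column are hypothetical); it is not a description of rstanarm internals.

```r
# Hypothetical data: `note` is not used in the model but contains NAs
toy <- data.frame(
  rt   = c(0.5, 0.6, 0.7, 0.8),
  id   = c("p1", "p1", "p2", "p2"),
  acc  = c(1, 0, 1, 1),
  note = c("x", NA, NA, "y")
)

# Which columns contain missing values?
colSums(is.na(toy))

# Keep only the model variables, as in the workaround above
model_vars <- c("rt", "id", "acc")
toy_model  <- toy[, model_vars]

# Complete rows when all columns count vs. when only model variables count
n_all   <- nrow(na.omit(toy))
n_model <- nrow(na.omit(toy_model))
c(all_columns = n_all, model_columns = n_model)
```

If the two counts differ, a fold vector built from one version of the data will not have the length expected for the other.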


#9

Thanks for your feedback, Jonah and Aki.

I am currently running rstanarm_2.17.4 and will try installing the develop version next, but first here is some feedback that may or may not be useful to you.

Using Jonah’s workaround:

data <- data[, c("rt", "id", "acc")]
...
folds <- kfold_split_stratified(K = nrparticipants, x = data$id)
kf_rtm0 <- kfold(rtm0, K = nrparticipants, folds = folds)

Returns:

Error in UseMethod(“kfold”) :
no applicable method for ‘kfold’ applied to an object of class “c(‘stanreg’, ‘glm’, ‘lm’, ‘lmerMod’)”

I am getting the same error using
kfold(rtm0, K = 3)
Which worked fine before.

However, the loo function seems to work well
loo(rtm0)

If I remember correctly, the loo package should also be able to handle models fitted with brms, so I tested that too.

I defined the models equivalently to the stan models:

brms0 <- brm(rt ~ 1 + (1 | id), data = data, family = gaussian)
brms1 <- brm(rt ~ 1 + (1 | id) + acc, data = data, family = gaussian)

Using these two functions worked fine

loo(brms0)
kfold(brms0, K = 3)

But using

folds <- kfold_split_stratified(K = nrparticipants, x = data$id)
kf_rtm0 <- kfold(brms0, folds = folds)

Returns the following:

Fitting model 1 out of 10
Start sampling
starting worker pid=29740 on localhost:11099 at 08:10:21.589
starting worker pid=29752 on localhost:11099 at 08:10:21.829
starting worker pid=29764 on localhost:11099 at 08:10:22.072
starting worker pid=29776 on localhost:11099 at 08:10:22.309
Error in checkForRemoteErrors(val) :
4 nodes produced errors; first error: passing unknown arguments: folds.

I will report back when I have tested the develop version.


#10

Slightly embarrassed to admit this, but I am not sure if I’ve installed the develop version or not…

I ran this code:

if (!require(devtools)) {
  install.packages("devtools")
  library(devtools)
}
install_github("stan-dev/rstanarm", build_vignettes = FALSE)

And as far as I could see, the installation was successful:

** building package indices
** installing vignettes
** testing if installed package can be loaded
* DONE (rstanarm)

However, when I run sessionInfo() afterwards, I still get “rstanarm_2.17.4” as the attached package.

When trying to run the kfold function, I still get:

Error in UseMethod(“kfold”) :
no applicable method for ‘kfold’ applied to an object of class “c(‘stanreg’, ‘glm’, ‘lm’, ‘lmerMod’)”


#11

The develop version keeps the same version number until just before a new release is made.

Is some other package masking the kfold function? Try

environment(kfold)
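As a base R illustration of how masking shows up, the sketch below uses `sd` as a stand-in for `kfold` (the masking function is made up). In the situation above, the analogous fix would presumably be a namespace-qualified call such as `rstanarm::kfold(...)`, assuming the installed rstanarm version exports `kfold`.

```r
# Where does the function currently found under this name come from?
environmentName(environment(sd))

# A later definition with the same name masks the original
sd <- function(x) "masked version"
sd(c(1, 2, 3))

# A namespace-qualified call still reaches the original function
stats::sd(c(1, 2, 3))

# Remove the masking definition again
rm(sd)
```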

#12

Your intuition was correct - apparently brms took over that function.

Now it at least seems like everything works. I have redone the code for another example (the chimps dataset from rethinking) as there are fewer individuals and it is more readily available if others should be in the same situation as I am.

My only question for now (hopefully) is whether I am doing the comparison correctly.

compare(kf_m0,kf_m1)

returns

elpd_diff se
-3.8 1.0

And my rough code

# Sum the pointwise elpd values within each fold
m0par <- m1par <- numeric(nractors)
for (i in seq_len(nractors)) {
  m0par[i] <- sum(kf_m0$pointwise[folds == i, ])
  m1par[i] <- sum(kf_m1$pointwise[folds == i, ])
}

round(sum(m1par-m0par),2) # Diff
round(sd(m1par-m0par),2) # se

Returns

-3.77
0.9

So it could at least look to me like they are relatively comparable and that in this case, we should conclude that a model including condition seems to be worse in predicting the outcome for a new chimp. Do you agree with this approach + interpretation?


#13

The sd should be multiplied by sqrt(nractors) (similar to eq 24 for LOO, where it is multiplied by sqrt(n)), so you will have a bit more uncertainty, reflecting the fact that the variation between actors is larger than the variation within actors.
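A minimal sketch of the scaled standard error, assuming a made-up vector of per-actor elpd differences (`diff_by_actor` is hypothetical and is not the chimps data):

```r
# Hypothetical per-actor elpd differences (model 1 minus model 0)
diff_by_actor <- c(-0.9, -0.2, -1.1, 0.3, -0.8, -0.6, -0.4)
n_actors <- length(diff_by_actor)

# Total elpd difference over actors
elpd_diff <- sum(diff_by_actor)

# Standard error of the summed difference: sqrt(n) times the sd
se_diff <- sqrt(n_actors) * sd(diff_by_actor)

round(c(elpd_diff = elpd_diff, se = se_diff), 2)
```

The scaling is needed because the quantity of interest is a sum over actors, so its variance grows with the number of terms.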


#14

I forgot the link for eq 24 in https://arxiv.org/abs/1507.04544


#15

This seems to work well. Thanks for taking your time to help me, @avehtari and @jonah.