Additional arguments to loo_subsample() not passed

Main message

I am trying to apply loo_subsample() to a model with several parameter arrays of differing dimensionality. Following the procedure outlined in this vignette, I wrote a log-likelihood function whose draws argument accepts the output of extract(stanfit), which is a list. This log-likelihood function works fine when passed to loo_i(), but loo_subsample() itself throws the following error:

Error in .ndraws.default(draws) : 
  .ndraws() has not been implemented for objects of class 'list'

The error occurs when the line

checkmate::assert_int(loo_approximation_draws, lower = 1, upper = .ndraws(draws), null.ok = TRUE)

is called. I tried to solve this error by instead adding to my log-likelihood function separate (array) arguments for each set of parameters. This again works with loo_i() but results in another error in loo_subsample():

Error in .llfun(data_i = data[i, , drop = FALSE], draws = point_est) : 
  argument "beta" is missing, with no default

This seems to occur after loo_subsample() calls elpd_loo_approximation(), which does not seem to make use of loo_subsample()'s ... argument.
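To make the suspected failure mode concrete, here is a minimal base-R sketch of how a dropped ... leads to this kind of error. All function names below are hypothetical stand-ins; this is not loo's source code, only an illustration of the mechanism I think is at play.

```r
# Hypothetical stand-ins; NOT the actual loo internals.
llfun <- function(data_i, draws, beta) {
  # `beta` has no default, so it must arrive through the call chain.
  draws * beta + data_i
}

inner_helper <- function(llfun, data_i, draws) {
  # No `...` in the signature, so extra arguments can never reach llfun().
  llfun(data_i = data_i, draws = draws)
}

outer_wrapper <- function(llfun, data_i, draws, ...) {
  # `...` is accepted here but not forwarded to inner_helper().
  inner_helper(llfun, data_i = data_i, draws = draws)
}

# Capture the error message instead of stopping the script.
msg <- tryCatch(
  outer_wrapper(llfun, data_i = 1, draws = 2, beta = 3),
  error = function(e) conditionMessage(e)
)
print(msg)
```

Even though beta is supplied to the outer function, it is lost one level down, so the log-likelihood function complains that it is missing.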

I would appreciate any thoughts on how to get around this problem. Thanks!
~Devin


More details

The following modification of the code in the loo vignette produces a similar error:

library(rstan)
library(loo)

logistic_stan <- "//
data {
  int<lower=0> N;             // number of data points
  int<lower=0> P;             // number of predictors (including intercept)
  matrix[N,P] X;              // predictors (including 1s for intercept)
  int<lower=0,upper=1> y[N];  // binary outcome 
}
parameters {
  vector[P] beta;
  vector[P] alpha;
}
model {
  beta ~ normal(0, 1);
  alpha ~ normal(0, 1);
  y ~ bernoulli_logit(X * beta);
}
"

llfun_logistic <- function(data_i, draws, log = TRUE, alpha) {
    if (missing(alpha)) stop("No alpha found.")
    x_i <- as.matrix(data_i[, which(grepl(colnames(data_i), pattern = "X")),
                            drop=FALSE])
    logit_pred <- draws %*% t(x_i)
    dbinom(x = data_i$y, size = 1, prob = 1/(1 + exp(-logit_pred)), log = log)
}

url <- "http://stat.columbia.edu/~gelman/arm/examples/arsenic/wells.dat"
wells <- read.table(url)
wells$dist100 <- with(wells, dist / 100)
X <- model.matrix(~ dist100 + arsenic, wells)
standata <- list(y = wells$switch, X = X, N = nrow(X), P = ncol(X))  # as in the loo vignette

stan_mod <- stan_model(model_code = logistic_stan)

fit_1 <- sampling(stan_mod, data = standata, seed = 4711)

print(fit_1, pars = c("beta", "alpha"))

parameter_draws_1 <- extract(fit_1)$beta
stan_df_1 <- as.data.frame(standata)

r_eff <- relative_eff(llfun_logistic, 
                      log = FALSE,
                      chain_id = rep(1:4, each = 1000), 
                      data = stan_df_1, 
                      draws = parameter_draws_1,
                      alpha = extract(fit_1)$alpha,
                      cores = 2)

loo_i(i = 1, llfun_logistic, r_eff = r_eff, data = stan_df_1,
      draws = parameter_draws_1, alpha = extract(fit_1)$alpha)

set.seed(4711)
loo_ss_1 <-
  loo_subsample(
    llfun_logistic,
    draws = parameter_draws_1,
    alpha = extract(fit_1)$alpha,
    data = stan_df_1,
    observations = 100, # take a subsample of size 100
    cores = 2,
    r_eff = r_eff
  )
  • Operating System: MacOS 11.2.3
  • RStan Version: 2.21.2
  • Output of devtools::session_info("rstan"):
─ Session info ───────────────────────────────────────────────────────────────
 setting  value                       
 version  R version 4.0.3 (2020-10-10)
 os       macOS Big Sur 10.16         
 system   x86_64, darwin17.0          
 ui       X11                         
 language (EN)                        
 collate  en_US.UTF-8                 
 ctype    en_US.UTF-8                 
 tz       America/New_York            
 date     2021-04-26                  

─ Packages ───────────────────────────────────────────────────────────────────
 package      * version   date       lib source        
 backports      1.2.1     2020-12-09 [1] CRAN (R 4.0.2)
 BH             1.75.0-0  2021-01-11 [1] CRAN (R 4.0.2)
 callr          3.7.0     2021-04-20 [1] CRAN (R 4.0.3)
 checkmate      2.0.0     2020-02-06 [1] CRAN (R 4.0.2)
 cli            2.4.0     2021-04-05 [1] CRAN (R 4.0.2)
 colorspace     2.0-0     2020-11-11 [1] CRAN (R 4.0.2)
 crayon         1.4.1     2021-02-08 [1] CRAN (R 4.0.2)
 curl           4.3       2019-12-02 [1] CRAN (R 4.0.1)
 desc           1.3.0     2021-03-05 [1] CRAN (R 4.0.2)
 digest         0.6.27    2020-10-24 [1] CRAN (R 4.0.2)
 ellipsis       0.3.1     2020-05-15 [1] CRAN (R 4.0.2)
 fansi          0.4.2     2021-01-15 [1] CRAN (R 4.0.2)
 farver         2.1.0     2021-02-28 [1] CRAN (R 4.0.2)
 ggplot2      * 3.3.3     2020-12-30 [1] CRAN (R 4.0.2)
 glue           1.4.2     2020-08-27 [1] CRAN (R 4.0.2)
 gridExtra      2.3       2017-09-09 [1] CRAN (R 4.0.2)
 gtable         0.3.0     2019-03-25 [1] CRAN (R 4.0.2)
 inline         0.3.17    2020-12-01 [1] CRAN (R 4.0.2)
 isoband        0.2.4     2021-03-03 [1] CRAN (R 4.0.2)
 jsonlite       1.7.2     2020-12-09 [1] CRAN (R 4.0.2)
 labeling       0.4.2     2020-10-20 [1] CRAN (R 4.0.2)
 lattice        0.20-41   2020-04-02 [1] CRAN (R 4.0.3)
 lifecycle      1.0.0     2021-02-15 [1] CRAN (R 4.0.3)
 loo          * 2.4.1     2020-12-09 [1] CRAN (R 4.0.2)
 magrittr       2.0.1     2020-11-17 [1] CRAN (R 4.0.2)
 MASS           7.3-53.1  2021-02-12 [1] CRAN (R 4.0.2)
 Matrix       * 1.3-2     2021-01-06 [1] CRAN (R 4.0.2)
 matrixStats    0.58.0    2021-01-29 [1] CRAN (R 4.0.2)
 mgcv           1.8-35    2021-04-18 [1] CRAN (R 4.0.3)
 munsell        0.5.0     2018-06-12 [1] CRAN (R 4.0.2)
 nlme           3.1-152   2021-02-04 [1] CRAN (R 4.0.3)
 pillar         1.6.0     2021-04-13 [1] CRAN (R 4.0.3)
 pkgbuild       1.2.0     2020-12-15 [1] CRAN (R 4.0.2)
 pkgconfig      2.0.3     2019-09-22 [1] CRAN (R 4.0.2)
 prettyunits    1.1.1     2020-01-24 [1] CRAN (R 4.0.2)
 processx       3.5.1     2021-04-04 [1] CRAN (R 4.0.2)
 ps             1.6.0     2021-02-28 [1] CRAN (R 4.0.2)
 R6             2.5.0     2020-10-28 [1] CRAN (R 4.0.2)
 RColorBrewer   1.1-2     2014-12-07 [1] CRAN (R 4.0.2)
 Rcpp           1.0.6     2021-01-15 [1] CRAN (R 4.0.2)
 RcppEigen      0.3.3.9.1 2020-12-17 [1] CRAN (R 4.0.2)
 RcppParallel   5.1.2     2021-04-15 [1] CRAN (R 4.0.3)
 rlang          0.4.10    2020-12-30 [1] CRAN (R 4.0.2)
 rprojroot      2.0.2     2020-11-15 [1] CRAN (R 4.0.2)
 rstan        * 2.21.2    2020-07-27 [1] CRAN (R 4.0.2)
 scales         1.1.1     2020-05-11 [1] CRAN (R 4.0.2)
 StanHeaders  * 2.21.0-7  2020-12-17 [1] CRAN (R 4.0.2)
 tibble       * 3.1.1     2021-04-18 [1] CRAN (R 4.0.3)
 utf8           1.2.1     2021-03-12 [1] CRAN (R 4.0.2)
 V8             3.4.1     2021-04-23 [1] CRAN (R 4.0.2)
 vctrs          0.3.7     2021-03-29 [1] CRAN (R 4.0.2)
 viridisLite    0.4.0     2021-04-13 [1] CRAN (R 4.0.3)
 withr          2.4.2     2021-04-18 [1] CRAN (R 4.0.3)

[1] /Library/Frameworks/R.framework/Versions/4.0/Resources/library

Ping @mans_magnusson and @jonah

Thanks! I'll try to take a look at it this week!

Hi!

I have now looked into this problem. Just as you say, the ellipsis (...) that works with loo_i() cannot currently be used with loo_subsample(). We should probably make this clearer in the vignette and documentation.

There are two solutions I think might be of relevance:

  1. The simplest solution is to include all your parameters in the draws object, which should work equally well, since the log-likelihood function then receives every parameter through draws rather than through the ellipsis. Here is how I got your code to work (change this part in your code above). I essentially put alpha in the draws object.
llfun_logistic <- function(data_i, draws, log = TRUE, alpha) {
  # if (missing(alpha)) stop("No alpha found.")
  x_i <- as.matrix(data_i[, which(grepl(colnames(data_i), pattern = "X")),
                          drop=FALSE])
  logit_pred <- draws[, c("beta[1]", "beta[2]", "beta[3]")] %*% t(x_i)
  dbinom(x = data_i$y, size = 1, prob = 1/(1 + exp(-logit_pred)), log = log)
}

parameter_draws_1 <- cbind(extract(fit_1)$beta, extract(fit_1)$alpha)
colnames(parameter_draws_1) <- c("beta[1]", "beta[2]", "beta[3]", "alpha[1]", "alpha[2]", "alpha[3]")

  2. .ndraws() is a function that only needs to return the number of draws. I think if you define how to compute the number of draws from your list, it might work:
.ndraws.list <- function(x) {
# How to extract the number of draws from your list here
}

However, I could not test whether this works with the code above, since I don’t have your draws list object.
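As a rough sketch of what the body of such a method might look like: this assumes the list comes from extract() with its default settings, so that the first dimension of every element indexes posterior draws. The draws_list object below is made up for illustration, and I have not checked whether loo_subsample() would need further list methods beyond this one.

```r
# Made-up stand-in for extract(fit_1): a list of arrays whose first
# dimension indexes posterior draws (assumed layout, not real output).
n_draws <- 4000
draws_list <- list(
  beta  = matrix(rnorm(n_draws * 3), nrow = n_draws),
  alpha = matrix(rnorm(n_draws * 3), nrow = n_draws),
  lp__  = rnorm(n_draws)
)

# Count draws in a list of parameter arrays.
.ndraws.list <- function(x) {
  # NROW() gives the first dimension for arrays and the length for vectors.
  n <- vapply(x, NROW, integer(1))
  # All parameters should agree on the number of draws.
  stopifnot(length(unique(n)) == 1L)
  unname(n[1])
}

n_found <- .ndraws.list(draws_list)
print(n_found)
```

The consistency check is there so that a list with mismatched elements fails early instead of silently reporting the size of the first element.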

I hope this helped. Otherwise, just let me know, and I will try to help out!
