Extract log-likelihood function from rstanarm model

Hello,

I recently read the arXiv paper https://arxiv.org/abs/2002.09633 about Bayesian survival models in rstanarm.

I am interested in applying approximate leave-one-out cross-validation based on the loo_approximate_posterior function of the loo package to some of the survival models available in rstanarm.

Based on the papers I read, the approximate LOO method behind the loo_approximate_posterior function speeds up the computation of LOO in two ways. First, the posterior is approximated (e.g. by a Laplace, meanfield or fullrank approximation), and the importance weights are corrected for using such an approximation. Thanks to the importance weighting, the full posterior (or its approximation) only needs to be computed once to obtain the leave-one-out posteriors required for LOO. Second, probability-proportional-to-size subsampling is used so that only a subset of the LOO posteriors has to be evaluated instead of all of them, which speeds up the computation further, in particular for datasets with a large sample size n.
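Written out, the weight correction can be sketched as follows (my own notation, assuming draws θ^(s), s = 1, ..., S, from an approximation g(θ), and the unnormalized log posterior log p(θ | y) that `log_p` refers to):

$$
\log r_i^{(s)} = \log p(\theta^{(s)} \mid y) - \log g(\theta^{(s)}) - \log p(y_i \mid \theta^{(s)}), \qquad \theta^{(s)} \sim g(\theta)
$$

Here `log_p` supplies the first term and `log_g` the second. If the draws come from the exact posterior (e.g. MCMC), the first two terms cancel up to a constant and the usual PSIS-LOO ratio 1 / p(y_i | θ^(s)) is recovered; the ratios are then Pareto-smoothed and self-normalized into weights.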

So far, I followed the LOO vignette about large data at https://cran.r-project.org/web/packages/loo/vignettes/loo2-large-data.html. In the vignette, an example is given where a Laplace posterior approximation is used:

# Approximate LOO-CV using PSIS-LOO with posterior approximations
fit_laplace <- optimizing(stan_mod, data = standata, draws = 2000,
                          importance_resampling = TRUE)
parameter_draws_laplace <- fit_laplace$theta_tilde # draws from approximate posterior
log_p <- fit_laplace$log_p # log density of the posterior
log_g <- fit_laplace$log_g # log density of the approximation

set.seed(4711)
loo_ap_ss_1 <-
  loo_subsample(
  x = llfun_logistic,
  draws = parameter_draws_laplace,
  data = stan_df_1,
  log_p = log_p,
  log_g = log_g,
  observations = 100
)
print(loo_ap_ss_1)

To correct for using an approximation of the full posterior, the arguments log_p and log_g need to be specified. In addition, the log-likelihood function llfun_logistic needs to be supplied.
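The vignette's llfun_logistic is not reproduced in this post. As a minimal sketch of the contract that loo_subsample() expects (data_i is one row of `data`, draws is an S x P matrix of parameter draws, and the return value is a length-S vector of log-likelihood values), here is a logistic-regression version; the column names y, X1 and X2 are hypothetical and chosen only for illustration.

```r
# Sketch of a log-likelihood function in the form loo_subsample() expects.
# Hypothetical data layout: binary outcome `y`, predictors `X1` and `X2`;
# `draws` is an S x 2 matrix whose columns correspond to c("X1", "X2").
llfun_logistic_sketch <- function(data_i, draws) {
  x_i <- as.matrix(data_i[, c("X1", "X2"), drop = FALSE])    # 1 x P row of predictors
  eta <- as.vector(draws %*% t(x_i))                          # S linear predictors
  dbinom(data_i$y, size = 1, prob = plogis(eta), log = TRUE)  # S log-likelihood values
}

# Usage on one toy observation and two toy draws
draws_toy <- rbind(c(0, 0), c(1, 2))
data_toy <- data.frame(y = 1, X1 = 1, X2 = 0.5)
llfun_logistic_sketch(data_toy, draws_toy)
```

Whatever the model, the key point is that the function only ever sees one observation and the full matrix of draws, which is what makes subsampling possible.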

My question now is the following:

  1. The vignette uses rstan, whereas the survival models above are fit with rstanarm. While it would theoretically be possible to extract the raw Stan code from the rstanarm models and write the log-likelihood functions manually, this is quite tedious (if possible at all). Is it possible to somehow extract the log-likelihood function (not the matrix) from an rstanarm model and pass it to the loo_subsample function?
  2. The vignette uses a Laplace approximation via the optimizing function to approximate the posterior distribution. The vb function provides fullrank and meanfield algorithms, too. To use the meanfield or fullrank approximate posterior in the loo_subsample function, I would need to extract the log_p and log_g values. Is there a way to do this? I could not find anything in the documentation, whereas the optimizing function provides both directly.

Thanks in advance,

Riko

Hi!

It has been a long time since I worked with survival models, but from Eq. (20) I assume that the likelihood factorizes over observations, so this should work with subsampling.

I think I may have some code where I added subsampling to rstanarm from this paper:

I'll try to check my old repo to see how I solved it.

Also, use the difference estimator in loo subsampling instead of the HH (Hansen-Hurwitz) estimator. It's probably better for you (see the paper above).
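To make the difference estimator concrete: as I understand it, it takes a cheap approximation of every elpd_loo_i, computes exact PSIS values only on a simple random subsample, and corrects the approximate total by the scaled subsample differences. Below is a base-R sketch of the point estimate only (the variance estimate that loo also reports is omitted; the function and argument names are my own). In loo_subsample() itself this estimator is selected with estimator = "diff_srs", which is the default, while "hh_pps" selects the Hansen-Hurwitz estimator.

```r
# Difference estimator of elpd_loo from a simple random subsample (point estimate).
# approx_all: cheap approximations of elpd_loo_i for all n observations
# idx:        indices of the m subsampled observations
# elpd_sub:   PSIS elpd_loo_i values computed for those m observations
diff_estimator <- function(approx_all, idx, elpd_sub) {
  n <- length(approx_all)
  m <- length(idx)
  # total of the approximations, corrected by the scaled subsample differences
  sum(approx_all) + n / m * sum(elpd_sub - approx_all[idx])
}

# Toy example with n = 4 observations and a subsample of m = 2
diff_estimator(approx_all = c(-1, -2, -3, -4), idx = c(2, 4), elpd_sub = c(-2.5, -4.5))
```

The better the cheap approximation, the smaller the differences and hence the variance, which is why this estimator tends to need far fewer exact evaluations than the Hansen-Hurwitz one.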

/Måns

I have now quickly checked my old code. I extracted the log-likelihood function from within the rstanarm package for the very specific models at hand. Hence, it worked for that project, but it is not suitable for production/reuse.

I think it would be better to do this through the rstanarm public API. I looked it up, and it seems the log-likelihood for new observations can be obtained from log_lik() via the newdata argument. So you would just supply a row of your data as newdata to get the log_lik values for that observation. I'm not sure whether this works with loo_subsample, though; I do not remember.

It is also possible (as far as I can tell) to get the log-likelihood function directly using rstanarm:::ll_fun(object). This means that you would sidestep the rstanarm API, so you would need to check that it works as expected (e.g. compare with log_lik for a small subsample). Here is the code (lines 71-89, 105-111, and 169-184):

Could you try and see if this would work?

I have also opened an issue to try to get this into rstanarm here:

/Måns


I tinkered a little and tried your solution. The good news is that I found a way to use the loo_subsample() function for an exponential model. The bad news is that the solution probably does not scale well, and I could not get rstanarm:::ll_fun() to work.

The first step is fitting the model to some simulated data. Here I use a single treatment covariate; in practice, dozens to hundreds of covariates are more realistic. I omitted some code, but the actual simulation used 50 covariates, of which 10 are relevant.

# Step 1: Load rstanarm in dev_mode
devtools::dev_mode(on = TRUE)
library(rstanarm)
library(survival) # provides Surv() used in the model formula

# Step 2: Simulate treatment covariate
library(simsurv)
set.seed(999111)
Nobs <- 1000

covs <- data.frame(id = 1:Nobs,
                   trt = rbinom(Nobs,1,0.5))

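dat <- simsurv(dist = "exp",
               lambdas = 0.1,
               betas = c(trt = -0.5),
               x = covs,
               maxt = 10)
# simsurv() returns only id, eventtime and status, so add the covariates back
dat <- merge(covs, dat, by = "id")
head(dat)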

# Step 3: Fit exponential model to data
stanmodel <- stan_surv(
  formula = Surv(eventtime, status) ~ trt,
  data = dat,
  basehaz = "exp",
  chains = 2,
  cores = 4,
  seed = 42,
  iter = 2000
)

Now, first to the solution I got working: I changed the log-likelihood function passed to loo_subsample() so that it searches for the row data_i in the data frame dat, which is defined globally outside of the function's scope.

This search is probably quite inefficient, but it sidesteps the problem of having no analytic access to the log-likelihood by simply calling rstanarm's log_lik() function with the newdata argument, as you proposed.

library(dplyr)

# Log-likelihood function for MCMC posterior model
llfun_stanmodel <- function(data_i, draws) {
  for(i in 1:nrow(dat)){ 
    s = sum((dat[i,] == data_i)[1,])
    if (s==ncol(dat)){ 
      data_i_index = dat[i,]$id
    }
  }
  rstanarm::log_lik(stanmodel, newdata = dat[data_i_index,])
}

After that, I extract the posterior draws of the fitted model, check that the loo_i() function used in loo_subsample() works as desired, and run the subsampling function for LOO-CV:

# PSIS-LOO-CV with subsampling, using the MCMC draws
parameter_draws_mcmc <- as.matrix(stanmodel)

# check loo_i function
library("loo")
loo_i(1, llfun_stanmodel, data = dat, draws = parameter_draws_mcmc)

set.seed(4711)
start.time <- Sys.time()
loo_ss_mcmc <-
  loo_subsample(
    llfun_stanmodel,
    draws = parameter_draws_mcmc,
    data = dat,
    observations = 100 # take a subsample of size 100
  )

print(loo_ss_mcmc)
end.time <- Sys.time()
time.taken <- end.time - start.time

For 1000 observations and 50 covariates, running loo_subsample() still took about 20 minutes on my machine with a subsample of 500 observations, and about 17 minutes with a subsample of 100. I suspect this is because of the inefficient implementation of the log-likelihood, which has to run a search every time. Any suggestions on how to speed up my modified log-likelihood?
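Two things that may help with the speed question, offered as a sketch rather than a tested fix. First, loo_subsample() already passes the i-th row of `data` as data_i, so (assuming `id` equals the row number, as in the simulation above) dat[data_i_index, ] is just data_i and the search can be dropped; if a lookup is really needed, match() on the unique id column avoids the row-by-row comparison. Second, if I read the loo documentation correctly, the default loo_approximation = "plpd" evaluates the log-likelihood function for all n observations (at a point estimate of the draws) to build the approximations used for subsampling. Because log_lik() ignores the draws argument, each of those n calls evaluates all posterior draws, which could explain why 100 and 500 subsampled observations take almost the same time. The names llfun_stanmodel_direct and row_of are hypothetical.

```r
# Sketch (untested against stan_surv): loo_subsample() hands the function one
# row of `data`, so that row can be passed to log_lik() directly.
llfun_stanmodel_direct <- function(data_i, draws) {
  # `draws` is unused: log_lik() evaluates the draws stored in `stanmodel`.
  rstanarm::log_lik(stanmodel, newdata = data_i)
}

# If a row index is still needed, look it up via the unique `id` column
# instead of comparing every row of `dat` with data_i.
row_of <- function(data_i, dat) match(data_i$id, dat$id)

# Toy check: the second row has id 1
dat_toy <- data.frame(id = c(3, 1, 2), eventtime = c(0.5, 2.0, 1.2))
row_of(dat_toy[2, ], dat_toy)
```

Defining llfun_stanmodel_direct does not call rstanarm; it only needs `stanmodel` to exist when loo_subsample() invokes it.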

Now to the problem with this solution: it does not work with approximate posterior inference such as Laplace or variational inference. When I modify the log-likelihood function to use a Laplace or variational posterior, R throws an error, because rstanarm's log_lik() is only implemented for MCMC posteriors:

# Meanfield variational inference approximation
stanmodel_mf <- stan_surv(
  formula = Surv(eventtime, status) ~ trt,
  algorithm = "meanfield",
  data = dat,
  basehaz = "exp",
  seed = 42,
  iter = 4000
)

# Log-likelihood function for meanfield variational inference posterior model
llfun_stanmodel_mf <- function(data_i, draws) {
  for(i in 1:nrow(dat)){ 
    s = sum((dat[i,] == data_i)[1,])
    if (s==ncol(dat)){ 
      data_i_index = dat[i,]$id
    }
  }
  rstanarm::log_lik(stanmodel_mf, newdata = dat[data_i_index,]) # note: stanmodel_mf instead of stanmodel now
}

# LOO-CV with meanfield variational inference draws
parameter_draws_mf <- as.matrix(stanmodel_mf, pars = "trt")

# check loo_i function
loo_i(1, llfun_stanmodel_mf, data = dat, draws = parameter_draws_mf)

Even the loo_i() check already fails with this error. I think there is no easy solution, because log_lik() is not implemented for approximate posteriors.

Now to your solution via rstanarm:::ll_fun(object). I copied and pasted the whole code you linked into my script. When I try to run the function, I run into problems. First, some of the functions it calls are not found, e.g. the validate_stanreg_object(x) call inside ll_fun(). How can I include these functions?

However, after commenting that call out, I was able to obtain the likelihood function. The problem is that the function uses various parameters, and without knowing what these parameters are and how they relate to each other, there is little hope of reverse engineering the likelihood this way. For some parameters, the meaning is easy to infer; for example, there is a variable eta, which is the linear predictor. For others, it is (for me) very unclear what they are.
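On including the missing helpers: a function whose source is copy-pasted into a script lives in the global environment, so it cannot see a package's unexported functions such as validate_stanreg_object(). In R, one standard workaround is to give the copy the package namespace as its enclosing environment with environment(f) <- asNamespace("rstanarm"), or to skip the copy and use the original via rstanarm:::ll_fun, which already lives in that namespace. Both sidestep the public API, so the results should still be compared with log_lik(). The toy example uses hypothetical names and a plain environment standing in for the namespace, so it runs without rstanarm.

```r
# Toy stand-in for a package namespace containing an unexported helper.
pkg_ns <- new.env()
pkg_ns$validate_object <- function(x) {
  stopifnot(is.numeric(x))
  x
}

# A "copied" function that calls the helper by its bare name.
copied_fun <- function(object) {
  object <- validate_object(object)
  object * 2
}

# Evaluated from the global environment, validate_object() is not found.
res <- tryCatch(copied_fun(21), error = function(e) NA)

# After re-parenting, the helper is found just as it would be inside the package.
environment(copied_fun) <- pkg_ns
copied_fun(21)

# For rstanarm this would be (untested):
# environment(my_copy_of_ll_fun) <- asNamespace("rstanarm")
# or use the original function object directly:
# ll_fun <- rstanarm:::ll_fun
```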

A quick side note: of course, writing down the likelihood function for an exponential or Weibull survival model should be no problem at all, but for the more complex models, such as M-spline or B-spline models, this quickly becomes difficult.
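Since the exponential case is easy to write down, here is a hedged sketch of an analytic log-likelihood that actually uses the draws argument, so it could in principle be evaluated on MCMC, variational or Laplace draws alike. It assumes the proportional hazards form h_i(t) = exp(eta_i) with eta_i = b_0 + b_trt * trt_i, and that as.matrix() on the fitted model returns columns named "(Intercept)" and "trt"; both are assumptions about stan_surv that should be checked by comparing against rstanarm::log_lik() on the MCMC fit for a few observations. For a right-censored observation with time t_i and event indicator d_i, the log-likelihood is d_i * eta_i - t_i * exp(eta_i).

```r
# Hypothetical analytic log-likelihood for a right-censored exponential model.
# ASSUMPTIONS: hazard h_i(t) = exp(eta_i); `draws` has columns "(Intercept)"
# and "trt"; `data_i` has columns trt, eventtime, status (1 = event, 0 = censored).
llfun_exp <- function(data_i, draws) {
  eta <- draws[, "(Intercept)"] + draws[, "trt"] * data_i$trt  # S log-hazards
  # event: log f(t) = eta - t * exp(eta); censored: log S(t) = -t * exp(eta)
  data_i$status * eta - data_i$eventtime * exp(eta)
}

# Toy check: two draws, the second corresponding to a hazard rate of 2
draws_toy <- cbind("(Intercept)" = c(0, log(2)), trt = c(0, 0))
obs_toy <- data.frame(trt = 1, eventtime = 1.5, status = 1)
llfun_exp(obs_toy, draws_toy)
```

Note that passing variational draws without log_p and log_g would treat them as exact posterior draws, so question 2 from the first post still applies.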

Another quick side note: I was only able to run variational inference (e.g. meanfield) for the exponential model via stan_surv(). It seems that variational inference and Laplace approximations are not yet implemented for most other models, so I guess there is currently little hope of approximating the posterior to speed up LOO-CV for those survival models.

Let me know what you think of my solution. I am particularly unsatisfied with how long loo_subsample() takes to run. Maybe this is simply due to the computational cost of MCMC for these survival models, and it would already be much faster if approximations were available. Or maybe it is due to the inefficient search, which becomes costly especially when the sample size is large.