Updating model based on new data via PSIS

This forum has amassed several questions over the years (e.g. here, here, here, here, here, here) about whether we can efficiently update a Stan model after collecting a modest number of new data points, without refitting the whole thing.

When this came up most recently, it occurred to me that such updating should in general be readily achievable via PSIS (Pareto-smoothed importance sampling). In many cases this should work quite well: more data typically tightens the posterior, so the old posterior (the proposal distribution) will have more mass in the tails than the new target posterior. Maybe this approach is obvious, but it doesn’t seem to have been mentioned in the posts linked above.
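To make the idea concrete, here is a toy illustration (my own, not part of any package): the updated posterior satisfies p(theta | y_old, y_new) ∝ p(theta | y_old) * p(y_new | theta), so draws from the old posterior can be importance-reweighted by the likelihood of the new data. The sketch below uses plain importance resampling (PSIS additionally smooths the weights) and checks the result against the exact conjugate posterior for a normal mean with known sd = 1 and a flat prior:

```r
set.seed(1)
y_old <- rnorm(50, mean = 2)
y_new <- rnorm(20, mean = 2)

# draws from the "old" posterior: N(mean(y_old), 1/sqrt(n_old))
n_old <- length(y_old)
draws <- rnorm(5e4, mean(y_old), 1 / sqrt(n_old))

# log importance ratios: log-likelihood of the new data at each draw
log_w <- sapply(draws, function(m) sum(dnorm(y_new, m, sd = 1, log = TRUE)))
w <- exp(log_w - max(log_w))

# resample with replacement according to the weights
updated <- sample(draws, 5e4, replace = TRUE, prob = w)

# exact updated posterior: N(mean(c(y_old, y_new)), 1/sqrt(n_old + n_new))
y_all <- c(y_old, y_new)
round(c(is_mean = mean(updated), exact_mean = mean(y_all)), 3)
round(c(is_sd = sd(updated), exact_sd = 1 / sqrt(length(y_all))), 3)
```

The resampled mean and sd land very close to the exact conjugate values, because the old posterior is wider than (and overlaps heavily with) the new one.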

In case it’s useful to anyone, I wrote a function (and quickly wrapped it in an R package; the package is very much not ready for primetime) that uses PSIS to update a brms model based on new data. The approach is to compute the Pareto-smoothed importance weights reflecting the new data, and then to resample (with replacement) the draws inside the brmsfit object using those weights. We end up with a new brmsfit object that we can manipulate exactly like the old one, but whose draws reflect the updated posterior. The resampling step incurs a loss of information, but as far as I know it is necessary to enable summarizing and post-processing the updated model via standard brms tools.
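For the curious, the core mechanics can be sketched in a few lines. This is an illustration with simulated draws and a made-up regression likelihood, not the actual package internals; in practice the draws and the pointwise log-likelihood of the new data would come from the brmsfit:

```r
library(loo)

set.seed(1)
# toy posterior draws for a regression y ~ normal(b_Intercept + b_x * x, 1)
S <- 4000
draws <- cbind(b_Intercept = rnorm(S, 0, 0.3), b_x = rnorm(S, 1, 0.3))

# new data, and the log-likelihood of each new point at each draw (S x N_new)
x_new <- rnorm(20)
y_new <- rnorm(20, x_new)
log_lik_new <- sapply(seq_along(y_new), function(n) {
  dnorm(y_new[n], draws[, "b_Intercept"] + draws[, "b_x"] * x_new[n],
        sd = 1, log = TRUE)
})

# Pareto-smooth the importance ratios (old posterior -> updated posterior);
# the log-ratio for each draw is the summed log-likelihood of the new data
psis_fit <- psis(rowSums(log_lik_new), r_eff = NA)
w <- weights(psis_fit, log = FALSE)  # normalized, smoothed weights

# resample the draws with replacement to represent the updated posterior
updated_draws <- draws[sample(S, S, replace = TRUE, prob = w), ]
colMeans(updated_draws)
```

The last resampling line is what makes the result usable by standard post-processing tools that expect unweighted draws.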

I’ve called the function and package upsis for “Updating with PSIS” and for a way to efficiently “upsize” the data (get it?).

So you can do something like

remotes::install_github("jsocolar/upsis")
library(brms)
library(upsis)
set.seed(1)

# generate initial data
x <- rnorm(10)
y <- rnorm(10, x)
df <- data.frame(x = x, y = y)

# generate additional data
x2 <- rnorm(20)
y2 <- rnorm(20, x2)
df2 <- data.frame(x = x2, y = y2)

# fit initial model
fit <- brm(y ~ x, data = df, backend = "cmdstanr")

# fit updated model
fit2 <- upsis(fit, data_add = df2)

summary(fit)
summary(fit2$updated_model)

Note that using upsis is not always exactly the same as fitting a new brms model with the updated dataset, because brms sometimes uses data-dependent priors and/or model structures (e.g. the positions of knots in splines) that do not get updated by upsis.

In tinkering on this, I was pleasantly surprised by how easy it was to use the loo package to perform PSIS for purposes other than LOO-CV. Big props to @avehtari @jonah and company for this awesome tool! The trickiest part of all this (which still requires some code cleanup and refactoring to make it less brittle) was figuring out how to extract, resample, and reinsert draws into a stanfit object.

upsis should be easily extensible to objects from rstanarm and any other object that provides a way to calculate the log likelihood of observing some set of new data.

12 Likes

Thanks for sharing @jsocolar. This would be possible in raw Stan, too, if you can access the PSIS weights.

Have you explored how far this works? That is, how big can the new data get relative to the old data before the variance is too high for PSIS to work? I realize it’s problem dependent; I’ve only used this for LOO, where it’s one observation in the new data set.

1 Like

Yup, it’d be easy in raw Stan as well, as long as you have a method for computing the log-likelihood of the new data. I haven’t pushed this very far (and haven’t checked the output for correctness!), but in some cases, like the simple regression I investigated above, it runs without complaining about the Pareto k diagnostic up to roughly an order-of-magnitude increase in dataset size. What really matters, of course, isn’t the increase in dataset size per se, but rather how substantially the posterior mass moves around. In many cases, large increases in the dataset size don’t actually change the posterior to an unmanageable degree. And unless we get unlucky, increasing the dataset size tends to concentrate the new posterior within part of the region of high (original) posterior mass, where we are likely to have a decent pool of samples from the proposal distribution.

As soon as we have lots of parameters that are getting strongly informed by only a few observations each, I imagine much smaller increases in the dataset size could become problematic.
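A quick empirical probe of this (my own toy example: flat prior, normal mean with known sd = 1) shows that the Pareto-k diagnostic tracks how far the posterior moves, not the raw size of the new dataset:

```r
library(loo)

set.seed(1)
n_old <- 100
y_old <- rnorm(n_old, mean = 2)
# draws from the "old" posterior
draws <- rnorm(4000, mean(y_old), 1 / sqrt(n_old))

pareto_k <- function(y_new) {
  # log importance ratio for each draw = log-likelihood of the new data
  log_ratios <- sapply(draws, function(m) sum(dnorm(y_new, m, sd = 1, log = TRUE)))
  psis(log_ratios, r_eff = NA)$diagnostics$pareto_k
}

# 10x more data from the SAME distribution: the posterior concentrates but
# barely moves, so k stays manageable
k_same <- pareto_k(rnorm(1000, mean = 2))

# a modest amount of data from a SHIFTED distribution: the posterior moves a
# lot, and k blows up
k_shifted <- pareto_k(rnorm(100, mean = 3))

c(k_same = k_same, k_shifted = k_shifted)
```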

That’s a better framing than mine in terms of data size: it’s really about how informative the new data is. Of course, if the new data isn’t very informative, you might not want to bother updating in an application.

In the limit, importance resampling or reweighting like this turns into sequential Monte Carlo, if you add a bit of jitter.

As long as there isn’t so much new data that the region you want to sample from collapses to a tiny part of the existing posterior. Even for a standard normal, doubling the number of observations concentrates the posterior by a factor of 1 / sqrt(2) in every dimension. But if we have 1000 observations, adding 100 more isn’t going to do much.
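Spelling out that arithmetic (flat prior, normal mean with known variance, so the posterior sd scales as 1/sqrt(n)):

```r
# shrinkage factor of the posterior sd when n_new observations are added to
# n_old, for a normal mean with known variance and a flat prior
shrink <- function(n_old, n_new) sqrt(n_old / (n_old + n_new))

shrink(1000, 1000)  # doubling the data: 1/sqrt(2), about 0.707
shrink(1000, 100)   # adding 10% more: about 0.953, hardly any concentration
```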

And this of course brings up the issue of whether the data’s even stationary to begin with!

1 Like

I’ve been away and busy, so missed commenting in the previous threads.

Leave-future-out cross-validation (paper, code) does exactly this, and eventually, when the posterior has changed too much, MCMC is re-run.

Cool! This we didn’t have. This should make it easier to use LFO-CV with brms models, too (ping @paul.buerkner)

Support for weighted draws is coming to posterior package, which would make it easy to get summaries and many post-processing results, but doing re-sampling is definitely easier (I hope you used the stratified approach).

It should be even easier using posterior package, as all PSIS functions have moved there with easier interfaces and additional diagnostics.

posterior package should be able to help there a bit, too.

5 Likes

I stratified by chain, so that if the resampled chains don’t converge to the same stationary distribution, standard diagnostics and summary outputs will pick that up. Perhaps unsurprisingly, in a very small amount of experimentation I’ve found that I tend to start getting convergence warnings from the outputs at approximately the same time as I start getting Pareto-k warnings from PSIS.

Should I be stratifying in additional ways?

I meant stratified resampling algorithm by Kitagawa (1996) as implemented, e.g., in posterior package function resample_draws(). It reduces the variance in the resampling.

I had not thought about doing the resampling independently for each chain, but that makes perfect sense, so that Rhat is still available. ESS will be slightly wrong, as the repeated draws will partially break the autocorrelation structure, but when importance sampling starts to fail, the Rhat part inside the ESS computation dominates anyway.
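A sketch of the per-chain stratified resampling with the posterior package (the draws and weights here are simulated stand-ins; in practice the log-weights would be the PSIS-smoothed log-weights computed from the new data’s log-likelihood). Resampling each chain separately keeps Rhat meaningful afterwards, as discussed above:

```r
library(posterior)

set.seed(1)
# toy draws: 1000 iterations x 4 chains x 2 variables
arr <- array(rnorm(1000 * 4 * 2), dim = c(1000, 4, 2),
             dimnames = list(NULL, NULL, c("alpha", "beta")))
d <- as_draws_array(arr)

# toy log importance weights, one per draw
lw <- matrix(rnorm(1000 * 4, sd = 0.1), nrow = 1000)

# stratified resampling (Kitagawa 1996), one chain at a time
resampled <- lapply(1:4, function(ch) {
  d_ch <- subset_draws(d, chain = ch)
  w_ch <- exp(lw[, ch] - max(lw[, ch]))  # stabilize before exponentiating
  resample_draws(d_ch, weights = w_ch, method = "stratified")
})
d_new <- do.call(bind_draws, c(resampled, list(along = "chain")))

summarise_draws(d_new, "mean", "rhat")
```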

1 Like

Thanks for the pointer. I’ve updated the package to use posterior for the stratified resampling, still iterating over chains one at a time.

Just a note that I fixed a bug in the resampling (more specifically, in splitting apart and re-combining the warmup draws and the resampled post-warmup draws when the stanfit contains warmup draws). I believe everything is working correctly now.

1 Like

Possibly a silly question, but how might one write raw Stan to compute the log-likelihood for new data? The log-likelihood for the initial data can just be added to the generated quantities block (e.g., below). For new data, the only method that comes to mind is to write a separate Stan file that takes both the draws and the new observations as data.[1] My assumption is that this is excessive/not necessary, since brms::log_lik() is able to take newdata without specifying a new model.

// log likelihood of initial data
generated quantities {
  vector[N] log_lik;
  for (n in 1:N) {
    log_lik[n] = binomial_logit_lpmf(Y[n] | K[n], mu[n]);
  }
}

  1. You could do this outside of Stan, too. Stan just gets the benefit of running in C++ rather than R/Python. You’d just need to set the sampler to run only 1 iteration, since all the draws are getting passed in.

Actually, from digging through the brms code in a bit more detail, I think brms is effectively doing what I would’ve originally guessed: evaluating the log-likelihood of the new data using the posterior draws, but doing so outside of the original Stan model. It’s just doing some magic under the hood to dynamically create a function that recreates the linear predictor.
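For anyone who does want to stay inside Stan: one route that avoids a separate sampling run is standalone generated quantities, which cmdstanr exposes as $generate_quantities(). You write a second Stan program whose parameters block matches the original model but whose data block carries the new observations, and feed it the already-sampled draws. A sketch, loosely based on the binomial example above (all names here are illustrative, not from the original model):

```
// log_lik_new.stan: parameters block must match the original model;
// the data block holds the NEW observations
data {
  int<lower=0> N_new;
  array[N_new] int<lower=0> K_new;
  array[N_new] int<lower=0> Y_new;
  vector[N_new] x_new;
}
parameters {
  real alpha;
  real beta;
}
generated quantities {
  vector[N_new] log_lik_new;
  for (n in 1:N_new) {
    log_lik_new[n] = binomial_logit_lpmf(Y_new[n] | K_new[n],
                                         alpha + beta * x_new[n]);
  }
}
```

Then, assuming fit is the CmdStanMCMC object from the original model and new_data_list holds the new observations:

```
gq_mod <- cmdstanr::cmdstan_model("log_lik_new.stan")
gq <- gq_mod$generate_quantities(fitted_params = fit, data = new_data_list)
```

This evaluates the generated quantities once per existing draw, with no new sampling.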