Best way to do prediction on new data? R rstan stanfit


I am new to Stan, so I’d really appreciate any help. I use R 3.4.1 with rstan 2.14.1.

I have a stanfit object in R, fitted via rstan::stan. How do I make predictions on new data?

Options so far:

  • Include the new data in the call to the stan function, as on page 160 of the Stan reference manual (2.16.0). However, this requires the new data to be available at training time, which is not the case for me. I do not want to refit the model whenever new data for prediction arrive, since fitting takes a while.
  • Extract the posterior parameter draws into R and do the prediction in R with the new data. However, this requires me to duplicate the model structure in R.
  • Muck around with the algorithm = "Fixed_param" option in rstan::stan. It doesn’t do what I want; it is probably not meant for this use case?

The user experience I am after is:

  • Fit the model in R, e.g.: fit <- stan(…)
  • Make posterior predictions using the new data and the posterior parameter draws in fit, e.g.: new_data_pred_samples <- predict(fit, newdata)
  • Then get point estimates and interval bounds for the predictions from new_data_pred_samples

Any ideas?


This is your best bet if you are using rstan::stan. If you were using the rstanarm or brms packages, there is a posterior_predict method with a newdata argument that does exactly this.

Note also that if you coded your model in Stan, you can access it from R without recoding: here

Thanks everyone. I’ve decided to duplicate model structure in R.

The expose_stan_functions functionality seemed promising, but I couldn’t find any useful examples of it being used for my use case. Most of the R code I ended up writing was wrapper code that pulls draws from the stanfit object and the Stan data lists anyway.

It is a shame that rstan doesn’t have a posterior_predict method like brms and rstanarm do.
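Once the draws are in R, that wrapper code mostly reduces to matrix algebra. A minimal sketch for a linear regression with normal errors, assuming the draws have been pulled into a matrix with columns named `beta[1]`, ..., `beta[K]` and `sigma` (for example with `as.matrix(fit)` on a stanfit whose Stan program declares `vector[K] beta` and `real<lower=0> sigma`). The helper names `predict_draws` and `summarise_pred` are hypothetical, and the draws below are simulated stand-ins, not output from a real fit.

```r
# Hypothetical helper: posterior predictive draws for a linear-normal model.
# draws: S x P matrix with columns "beta[1]", ..., "beta[K]" and "sigma"
# X_new: N_new x K design matrix for the new observations
predict_draws <- function(draws, X_new, include_noise = TRUE) {
  K <- ncol(X_new)
  beta <- draws[, paste0("beta[", seq_len(K), "]"), drop = FALSE]  # S x K
  mu <- beta %*% t(X_new)                                           # S x N_new
  if (!include_noise) return(mu)  # draws of the expected value only
  # one normal draw per (posterior draw, new observation); sigma has length S
  # and is recycled down the columns, so row s uses sigma[s]
  mu + matrix(rnorm(length(mu), mean = 0, sd = draws[, "sigma"]), nrow = nrow(mu))
}

# Point estimates and 90% intervals, one row per new observation
summarise_pred <- function(pred) {
  data.frame(mean  = colMeans(pred),
             lower = apply(pred, 2, quantile, probs = 0.05),
             upper = apply(pred, 2, quantile, probs = 0.95))
}

# Simulated stand-in draws (in practice: draws <- as.matrix(fit))
set.seed(1)
S <- 1000
draws <- cbind("beta[1]" = rnorm(S, 2, 0.1),
               "beta[2]" = rnorm(S, 10, 0.5),
               sigma     = abs(rnorm(S, 10, 0.5)))
X_new <- cbind(x = 21:23, intercept = 1)
pred <- predict_draws(draws, X_new)
summarise_pred(pred)
```

Passing `include_noise = FALSE` gives draws of the expected value (the analogue of a fitted-values prediction), while the default adds observation noise to give posterior predictive draws.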

rstan::stan does not know what the model is, so it cannot do the duplication needed for a posterior_predict method.


I had a similar problem a while ago and I too lamented the lack of something as easy as a predict()-type solution. If you’re not yet familiar with it, generating posterior predicted values using McElreath’s rethinking R package is just as easy. The rethinking package is a front end for Stan, and I can’t say enough positive things about the package or McElreath’s book. If you pursue this route, see the link() function in chapter 4.

Until there’s a similar solution available directly in rstan, I think you’ve got to do this manually. But rather than using expose_stan_functions(), I think it’s easier and more transparent to do this entirely outside of Stan, in R. Here’s an example using the rats model that comes with rstan. Assume you’ve fit that model in Stan and have a resulting stanfit object called “rats1”. Assume also that your interest is in predicting average rat weight for new data: days 21, 22, and 23.

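# generate posterior draws of the Stan model coefficients (one row per draw)
param.sample <- as.matrix(rats1, pars = c("mu_alpha", "mu_beta"))

# assume data2 is your new data frame: three ages, one day apart, at which the rats are weighed
data2 <- data.frame(x = c(21, 22, 23))
# xbar must be the centring constant from the fitted data (ages 8, 15, 22, 29, 36),
# not the mean of the new ages
data2$xbar <- mean(c(8, 15, 22, 29, 36))

# posterior predicted values for average rat weight using the new data frame (data2);
# the rats model uses mu_alpha + mu_beta * (x - xbar)
result.sample <- t(apply(param.sample, 1, function(draw) draw["mu_alpha"] + draw["mu_beta"] * (data2$x - data2$xbar)))

# summary: posterior mean and posterior standard deviation for each new day
result.summary <- apply(result.sample, 2, mean)
result.summary.std <- apply(result.sample, 2, sd)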

I’m guessing your interest would be in result.summary and in the posterior standard deviations around it (result.summary.std).
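The apply() loop can also be written as a single outer product, which makes the one-row-per-draw structure explicit and adds 90% intervals. A sketch, assuming `param.sample` is a matrix with columns `mu_alpha` and `mu_beta`; the simulated draws below use arbitrary values and only stand in for `as.matrix(rats1, pars = c("mu_alpha", "mu_beta"))`. As in the code above, this gives the population-average weight, without the observation noise `sigma_y`.

```r
# Stand-in draws with arbitrary values; in practice
# param.sample <- as.matrix(rats1, pars = c("mu_alpha", "mu_beta"))
set.seed(42)
S <- 4000
param.sample <- cbind(mu_alpha = rnorm(S, 240, 3), mu_beta = rnorm(S, 6, 0.1))

x_train <- c(8, 15, 22, 29, 36)  # ages (days) in the rats data
xbar <- mean(x_train)            # 22: the centring constant used in the fit
x_new <- c(21, 22, 23)

# S x 3 matrix: row s holds mu_alpha[s] + mu_beta[s] * (x_new - xbar)
result.sample <- param.sample[, "mu_alpha"] +
  outer(param.sample[, "mu_beta"], x_new - xbar)

result.summary <- data.frame(
  day   = x_new,
  mean  = colMeans(result.sample),
  sd    = apply(result.sample, 2, sd),
  lower = apply(result.sample, 2, quantile, probs = 0.05),
  upper = apply(result.sample, 2, quantile, probs = 0.95)
)
result.summary
```

At day 22 the prediction equals `mu_alpha` exactly, because that day coincides with `xbar`.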

Hi Ben. I think you meant to say that rstan doesn’t know what the model is and, hence, it can’t compute predicted values. But can’t one divine what the model is from the following?


In any event, it appears that others have been able to create simple posterior_predict()-type solutions from a stanfit object. In my reply to the OP above, I make reference to McElreath’s rethinking package. That package has a link() function that appears to pull whatever’s needed from the stanfit object and generate predicted values. And since rethinking is a front end for Stan, this ought to be doable natively in rstan.

pred.sample <- rethinking::link(my.model, new.df) # my.model is a fitted rethinking model; rethinking is a wrapper around Stan

I’m just not smart enough to extract exactly what’s needed from the get_stancode text string and cast it into an R formula object or equation.

A small percentage of humans can, but a computer cannot (yet), in general. The rethinking, rstanarm, brms, etc. R packages all use R syntax to specify the model and call rstan::sampling. By caching the information from the R syntax, these packages can do posterior prediction fairly easily. But they are not inferring it from the Stan program.
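To make the caching point concrete: a front end can predict on new data because it keeps the R formula it was given and can rebuild a design matrix for `newdata` from it. A stripped-down sketch of that idea in base R; `fit_like` and `predict_like` are hypothetical names, the coefficient draws are simulated rather than coming from Stan, and the real packages handle much more (factors, group-level terms, link functions).

```r
# Hypothetical "front end": store the terms of the model formula next to the draws.
# (A real package would also run the Stan program here; the draws are passed in.)
fit_like <- function(formula, draws) {
  list(terms = delete.response(terms(formula)), draws = draws)
}

# Rebuild the design matrix for newdata from the stored terms, then apply every draw.
predict_like <- function(object, newdata) {
  mf <- model.frame(object$terms, newdata)
  X_new <- model.matrix(object$terms, mf)                # same column names as at fit time
  beta <- object$draws[, colnames(X_new), drop = FALSE]  # match draws to columns by name
  beta %*% t(X_new)                                      # S x N_new: one row per draw
}

# Simulated stand-in draws for y ~ x, named like the model.matrix() columns
set.seed(3)
S <- 500
draws <- cbind("(Intercept)" = rnorm(S, 10, 1), x = rnorm(S, 2, 0.05))
obj <- fit_like(y ~ x, draws)
mu <- predict_like(obj, data.frame(x = c(21, 22, 23)))
colMeans(mu)
```

Nothing here is inferred from Stan code; the formula supplies the structure, which is why this works for formula-based wrappers but not for an arbitrary Stan program.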


Thanks. That makes complete sense.

I assume this isn’t feasible for larger datasets, but couldn’t you just do posterior predictions within generated quantities? Include the prediction dataset in data {}, then plug it into a predictive formula in generated quantities {}?

Yes, that’s what we recommend. But it’s not automatic the way it is in the other packages mentioned.
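A minimal sketch of that pattern for a linear regression, with the Stan program held in an R string so it can be passed to rstan::stan(model_code = ...). The names `N_new`, `X_new`, and `y_new` are illustrative, and the training data are simulated. The new inputs enter through the data block and `y_new` is drawn once per posterior draw in generated quantities, so the predictions do not affect the posterior; the drawback raised at the top of the thread remains, namely that the new data must be known when the model is fit.

```r
# Stan program: prediction inputs are extra data; predictions are drawn
# in generated quantities, so they do not influence the posterior.
pred_model_code <- "
data {
  int<lower=0> N;
  int<lower=1> K;
  matrix[N, K] X;
  vector[N] y;
  int<lower=0> N_new;
  matrix[N_new, K] X_new;
}
parameters {
  vector[K] beta;
  real<lower=0> sigma;
}
model {
  beta ~ normal(0, 10);
  sigma ~ cauchy(0, 2.5);
  y ~ normal(X * beta, sigma);
}
generated quantities {
  vector[N_new] y_new;
  for (n in 1:N_new)
    y_new[n] = normal_rng(X_new[n] * beta, sigma);
}
"

# Training and prediction inputs go into one data list (simulated here).
set.seed(1)
X <- cbind(intercept = 1, x = 1:50)
y <- as.numeric(X %*% c(10, 2) + rnorm(50, 0, 10))
X_new <- cbind(intercept = 1, x = 71:80)
stan_data <- list(N = nrow(X), K = ncol(X), X = X, y = y,
                  N_new = nrow(X_new), X_new = X_new)

# Then, for example:
# fit <- rstan::stan(model_code = pred_model_code, data = stan_data)
# y_new_draws <- rstan::extract(fit, pars = "y_new")$y_new  # draws x N_new
```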


I do not really know how to perform posterior predictions, and I wonder whether you have a practical example or tutorial that explains it further. What would the .stan code look like?

There are examples in the manual and in my case study on repeated binary trials (there’s a page of links to case studies for Stan).


As this post mentions, deploying a Stan model in production is much harder than deploying a glm or random forest model with tidypredict, which can convert the fitted model directly to SQL.

I wish Stan offered a better deployment experience.

If we want to do full Bayes predictively, we need to average predictions from posterior draws. It’s not just a simple function, though there’s no reason you couldn’t write DB code.
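In symbols, with $y$ the training data, $\tilde{x}$ the new inputs, and $\theta^{(s)}$ for $s = 1, \dots, S$ the posterior draws, the averaging described above is the Monte Carlo estimate of the posterior predictive distribution and of its mean:

$$
p(\tilde{y} \mid \tilde{x}, y) = \int p(\tilde{y} \mid \tilde{x}, \theta)\, p(\theta \mid y)\, d\theta \approx \frac{1}{S} \sum_{s=1}^{S} p(\tilde{y} \mid \tilde{x}, \theta^{(s)})
$$

$$
\mathbb{E}[\tilde{y} \mid \tilde{x}, y] \approx \frac{1}{S} \sum_{s=1}^{S} \mathbb{E}[\tilde{y} \mid \tilde{x}, \theta^{(s)}]
$$

For a linear model the second line reduces to applying the posterior mean coefficients, but predictive intervals still require the individual draws, so a scoring routine in a database has to carry all $S$ draws rather than one coefficient vector.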

Often in full Bayes, even that’s not enough, as we want to use the test inputs in a kind of semi-supervised way (as determined by the standard rules of probability). Stan does that if you include prediction along with the model (not in generated quantities), but this is very compute intensive.

I suspect the larger problem is that you’re looking for a prebuilt solution like that, which isn’t on our roadmap, so unlikely to happen any time soon.

I understand the challenge of creating a predict function for models fit using rstan::stan. However, it is a problem I run into constantly: I need to use a computationally intensive model to predict outcomes for lots of new data, where I only observe the independent variables, so I can’t use the new data as additional observations when fitting the model.

While it couldn’t be called a prediction method per se (since it requires the user to write the code that actually produces the predictions), would it be possible to extend algorithm = "Fixed_param" to accept, say, a list of fixed parameter values corresponding to the posterior draws from a fitted model? At the moment, my strategy is to pass each posterior draw sequentially to stan with algorithm = "Fixed_param" and the new data, and then store the posterior predictives for the new data from each run. See the reprex below.

It’s not terrible for users to write their own code for this, but the Fixed_param option seems almost there as a way to get posterior predictives for new data without having to refit the model each time. Or am I missing something, and can this already be done using Fixed_param? If not, and if there’s any interest in developing that functionality, I’d be happy to help out as best I can.

library(tidyverse)  # dplyr, tidyr, purrr, ggplot2
library(rstan)      # the reprex was run with rstan 2.19.2

# a toy model

stan_model <- "
data {
  int<lower=0> n;              // number of observations
  int<lower=1> n_betas;        // number of betas
  vector[n] y;                 // outcomes
  matrix[n, n_betas] x;        // predictors
}
parameters {
  vector[n_betas] betas;
  real<lower=0> sigma;         // residual sd must be positive
}
model {
  y ~ normal(x * betas, sigma);
  betas ~ normal(0, 10);
  sigma ~ cauchy(0, 2.5);
}
generated quantities {
  vector[n] y_pp;              // posterior predictive draws
  for (i in 1:n) {
    y_pp[i] = normal_rng(x[i] * betas, sigma);
  }
}
"

n <- 50

betas <- c(2,10)

sigma <- 10

training_data <- data.frame(x = 1:n, intercept = 1)

y = as.numeric(as.matrix(training_data) %*% betas + rnorm(n,0,sigma))

# plot(training_data$x, y)

# fit the model

stan_fit <- stan(
  model_code = stan_model,
  data = list(n = n,
              n_betas = length(betas),
              y = y,
              x = as.matrix(training_data)),
  chains = 1,
  warmup = 500,
  iter = 1000,
  cores = 1,
  refresh = 0             # no progress shown
)

# plot(stan_fit, pars = "betas")

# go through and get the individual draws for each parameter

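tidy_posts <- tidybayes::gather_draws(stan_fit, betas[variable], sigma)  # y_pp also needs sigma

nested_posts <- tidy_posts %>%
  ungroup() %>%
  arrange(.draw, .variable, variable) %>%  # keep betas[1], betas[2] in index order
  group_by(.draw) %>%
  nest()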

# create some new data partly outside of the range of the training data
testing_data <- data.frame(x = 20 + (1:n), intercept = 1)

new_data <- list(n = nrow(testing_data),
                 n_betas = ncol(testing_data),
                 y = rep(1, nrow(testing_data)),  # placeholder: y is declared in the data block but unused here
                 x = as.matrix(testing_data))

pred_foo <- function(params, stan_model, new_data) { # posterior predictives given one fixed parameter draw

  variables <- unique(params$.variable)

  # named list of initial values, e.g. list(betas = c(...), sigma = ...)
  inits <-
    purrr::map(variables, ~ params$.value[params$.variable == .x]) %>%
    purrr::set_names(variables)

  pp_samps <- stan(
    model_code = stan_model,
    data = new_data,
    chains = 1,
    warmup = 0,
    iter = 1,
    cores = 1,
    refresh = 0,
    init = list(inits),
    algorithm = "Fixed_param"
  )

  tidybayes::tidy_draws(pp_samps)

} # close function

# iterate over the posterior draws of the parameters to generate predictions for the new data
nested_posts <- nested_posts %>%
  mutate(preds = map(data, pred_foo, stan_model = stan_model, new_data = new_data))

unnested_posts <- nested_posts %>%
  rename(draw = .draw) %>%
  select(-data) %>%
  unnest(cols = preds)

y_pp <- unnested_posts %>%
  pivot_longer(
    cols = contains("_pp"),
    names_to = "observation",
    values_to = "prediction",
    names_pattern = "y_pp\\[(.*)\\]",
    names_transform = list(observation = as.integer)
  )

y_pp %>%
  mutate(x = observation + min(testing_data$x) - 1) %>%
  group_by(x) %>%
  summarise(mean_pred = mean(prediction),
            lower = quantile(prediction, 0.05),
            upper = quantile(prediction, 0.95)) %>%
  ungroup() %>%
  ggplot() +
  geom_ribbon(aes(x, ymin = lower, ymax = upper), alpha = 0.5) +
  geom_line(aes(x, mean_pred), color = "red") +
  scale_y_continuous(name = "y") +
  labs(caption = "Red line is mean posterior predictive, grey shaded area 90% credible interval")

Created on 2019-10-22 by the reprex package (v0.3.0)