Post-Hoc Prediction: Use samples from fitted model in a prediction task

Here’s code using the bridgestan package and adaptive Gauss-Hermite quadrature from the aghq package. It computes the posterior mean in 0.005s for one X given one set of parameter values (for simplicity, I used the true values, as you did when sampling X). For n=500 it takes 2.5s; multiply that by, e.g., 100 when averaging over 100 posterior draws. I guess your actual model and data are more complex. In the code I compute the CDF for one X and the posterior mean for all X’s, but you can compute moments and quantiles of arbitrary functions based on your needs.

# Additional libraries
library(posterior)
library(bridgestan)
library(aghq)
library(ggplot2)
theme_set(bayesplot::theme_default(base_family = "sans"))
library(tictoc)

# bridgestan compilation: compile_model() is run once on the Stan program
# file; its result is kept in mod2bs_lib and handed to every
# StanModel$new(lib = mod2bs_lib, ...) call below, so only the data changes
# between model instances.
mod2bs_lib <- compile_model("multi_sim_missx.stan")

#' Make a JSON literal from a Stan data list
#'
#' Validates a named list of Stan data, coerces every element to a form that
#' `jsonlite::toJSON()` writes as a number or (nested) array, and returns the
#' resulting JSON text.
#'
#' @param data A list in which every element is named, names are unique, and
#'   each element is numeric, factor, logical, data frame, or list, with no
#'   `NA` values.
#' @param always_decimal Forwarded unchanged to `jsonlite::toJSON()`.
#' @return The value returned by `jsonlite::toJSON()` for the processed list.
#' @details Raises an error (via `stop()`) when `data` is not a list, when an
#'   element is unnamed or a name is duplicated, when an element has an
#'   unsupported type or contains `NA`, or when a list-valued element holds
#'   non-numeric members or members of differing size.
to_stan_json <- function(data, always_decimal = FALSE) {
  if (!is.list(data)) {
    stop("'data' must be a list.", call. = FALSE)
  }

  data_names <- names(data)
  if (length(data) > 0 &&
      (length(data_names) == 0 || !all(nzchar(data_names)))) {
    stop("All elements in 'data' list must have names.", call. = FALSE)
  }
  if (anyDuplicated(data_names) != 0) {
    stop("Duplicate names not allowed in 'data'.", call. = FALSE)
  }

  # Stack a list of equally sized numeric vectors/matrices/arrays into one
  # array whose FIRST dimension indexes the list elements. Defined locally
  # because no `list_to_array()` helper exists in this script, so list-valued
  # elements previously failed with "could not find function".
  stack_list <- function(x, name) {
    if (length(x) == 0) {
      # Nothing to stack; keep the empty list so the key is not dropped.
      return(x)
    }
    elem_dims <- lapply(x, function(z) {
      if (is.null(dim(z))) length(z) else dim(z)
    })
    first_dim <- as.integer(elem_dims[[1]])
    same_dim <- vapply(
      elem_dims,
      function(d) identical(as.integer(d), first_dim),
      logical(1)
    )
    if (!all(same_dim)) {
      stop("All matrices/vectors in list '", name,
           "' must be the same size.", call. = FALSE)
    }
    if (!all(vapply(x, is.numeric, logical(1)))) {
      stop("All elements in list '", name, "' must be numeric.",
           call. = FALSE)
    }
    n_dim <- length(first_dim)
    stacked <- array(unlist(x, use.names = FALSE),
                     dim = c(first_dim, length(x)))
    # unlist() fills element by element, so the list index ends up as the
    # last dimension; move it to the front.
    aperm(stacked, c(n_dim + 1L, seq_len(n_dim)))
  }

  for (var_name in data_names) {
    var <- data[[var_name]]
    if (!(is.numeric(var) || is.factor(var) || is.logical(var) ||
          is.data.frame(var) || is.list(var))) {
      stop("Variable '", var_name, "' is of invalid type.", call. = FALSE)
    }
    if (anyNA(var)) {
      stop("Variable '", var_name, "' has NA values.", call. = FALSE)
    }

    if (is.table(var)) {
      var <- unclass(var)
    } else if (is.logical(var)) {
      mode(var) <- "integer"
    } else if (is.data.frame(var)) {
      # Must be tested before is.list(): a data frame is also a list.
      var <- data.matrix(var)
    } else if (is.list(var)) {
      var <- stack_list(var, var_name)
    }
    data[[var_name]] <- var
  }

  # unboxing variables (N = 10 is stored as N : 10, not N: [10])
  jsonlite::toJSON(
    data,
    auto_unbox = TRUE,
    factor = "integer",
    always_decimal = always_decimal,
    digits = NA,
    pretty = TRUE
  )
}

# Test with first observation
# Builds the Stan data for a single row (N = 1) from the i-th entries of
# y1, y2, y3 plus the item parameter vectors and correlation object, which
# are expected to exist already in the session.
i <- 1
standati <- list(
  N = 1,
  K = 3,
  y = matrix(c(y1[i], y2[i], y3[i]), nrow = 1, ncol = 3),
  a = c(a1, a2, a3),
  r = c(r1, r2, r3),
  b = c(b1, b2, b3),
  n1 = c(n1_1, n1_2, n1_3),
  n2 = c(n2_1, n2_2, n2_3),
  cors = randcors
)
# instantiate with data
mod2bs <- StanModel$new(
  lib = mod2bs_lib,
  data = to_stan_json(standati),
  seed = 67134
)
# list of density, gradient and hessian functions
# (the gradient/hessian methods return a list; only the named component is
# passed on to aghq)
ffs <- list(
  fn = mod2bs$log_density,
  gr = \(x) {
    mod2bs$log_density_gradient(x)$gradient
  },
  he = \(x) {
    mod2bs$log_density_hessian(x)$hessian
  }
)

# adaptive Gauss-Hermite quadrature in unconstrained space
aghqi <- aghq(ffs, k = 11, startingvalue = 0)
# quantiles 1%, 3%, ..., 99% (2% steps) in unconstrained space
q <- seq(0.01, 0.99, by = 0.02)
qt <- compute_quantiles(aghqi, q = q)[[1]]
# the same quantiles mapped to constrained space; vapply() enforces exactly
# one constrained value per quantile instead of sapply()'s input-dependent
# return type
qx <- vapply(qt, mod2bs$param_constrain, numeric(1))
# plot CDF for X[1]
data.frame(qx = qx, q = q) |>
  ggplot(aes(x = qx, y = q)) +
  geom_line() +
  labs(x = 'X', y = 'CDF')

# compute posterior mean for all X
# One model instance and one quadrature fit per observation: only the `y`
# row of the data list changes between iterations.
# Preallocate as numeric NA (not a logical array) so the result has the
# right type from the start.
xmean <- rep(NA_real_, n)
tic()
# seq_len() keeps the loop from running over c(1, 0) when n is 0
for (i in seq_len(n)) {
  standati$y <- matrix(c(y1[i], y2[i], y3[i]), nrow = 1, ncol = 3)
  # warnings from re-instantiating the model on the already compiled
  # library are silenced for this call only
  suppressWarnings(
    mod2bs <- StanModel$new(
      lib = mod2bs_lib,
      data = to_stan_json(standati),
      seed = 1
    )
  )
  ffs <- list(
    fn = mod2bs$log_density,
    gr = \(x) {
      mod2bs$log_density_gradient(x)$gradient
    },
    he = \(x) {
      mod2bs$log_density_hessian(x)$hessian
    }
  )
  aghqi <- aghq(ffs, k = 11, startingvalue = 2)
  # compute posterior mean in constrained space
  xmean[i] <- compute_moment(aghqi, mod2bs$param_constrain,
                             method = "correct")
}
toc()

# compare to Stan posterior mean
# Both plotted quantities live in the data frame, so aes() no longer has to
# reach outside it for x_sum$mean.
data.frame(xstan = x_sum$mean, xpred = xmean) |>
  ggplot(aes(x = xstan, y = xpred)) +
  geom_point() +
  geom_abline() +
  labs(x = "X posterior mean via Stan", y = "X posterior mean via aghq")
1 Like