Here’s code using the bridgestan package and adaptive Gauss-Hermite quadrature from the aghq package. It computes the posterior mean of X for one observation, given one set of parameter values, in about 0.005 s (for simplicity, I used the true parameter values, as you did when simulating X). For n = 500 observations it takes 2.5 s; multiply that by, e.g., 100 if you average over 100 posterior draws. I guess your actual model and data are more complex. In the code I compute the CDF for one X and the posterior mean for all Xs, but you can compute moments and quantiles of arbitrary functions as needed.
# Additional libraries
library(posterior)
library(bridgestan)
library(aghq)
library(ggplot2)
theme_set(bayesplot::theme_default(base_family = "sans"))
library(tictoc)
# Compile the Stan program into a library that BridgeStan's StanModel can load
mod2bs_lib <- compile_model(
  "multi_sim_missx.stan"
)
# Stack a list of equally sized numeric/logical vectors or arrays into one
# array whose FIRST dimension indexes the list elements, i.e. element k of
# `x` becomes out[k, ...]. An empty list becomes a zero-length numeric vector
# (serialised as `[]`) rather than NULL, so the variable is not dropped.
#
# @param x A list of numeric or logical vectors/arrays.
# @param name Variable name, used only in error messages.
# @return An array with dim c(length(x), dim(x[[1]])) (vectors count as 1-d).
list_to_array <- function(x, name = NULL) {
  n_elem <- length(x)
  if (n_elem == 0) {
    return(numeric(0))
  }
  # Treat plain vectors as 1-d arrays so vectors and matrices share one path
  elem_dims <- lapply(x, function(z) if (is.null(dim(z))) length(z) else dim(z))
  ref_dim <- as.integer(elem_dims[[1]])
  same_size <- vapply(
    elem_dims,
    function(z) identical(as.integer(z), ref_dim),
    logical(1)
  )
  if (!all(same_size)) {
    stop("All matrices/vectors in list '", name,
         "' must be the same size!", call. = FALSE)
  }
  is_num <- vapply(x, function(z) is.numeric(z) || is.logical(z), logical(1))
  if (!all(is_num)) {
    stop("All elements in list '", name, "' must be numeric!", call. = FALSE)
  }
  out <- unlist(x, use.names = FALSE)
  dim(out) <- c(ref_dim, n_elem)
  # Move the list index from the last dimension to the first
  aperm(out, c(length(ref_dim) + 1L, seq_along(ref_dim)))
}

#' Make a JSON literal from a Stan data list.
#'
#' Every element of `data` must be named (names unique) and be numeric,
#' logical, factor, data frame or list, with no NA values (checked
#' recursively). Tables are unclassed, logicals become integers, data frames
#' become matrices via `data.matrix()`, and lists of equally sized numeric
#' arrays are stacked with `list_to_array()`.
#'
#' @param data A named list of Stan data.
#' @param always_decimal Passed to `jsonlite::toJSON()`.
#' @return A pretty-printed `json` string from `jsonlite::toJSON()`, with
#'   length-1 vectors unboxed (N = 10 is written as `"N": 10`, not `[10]`)
#'   and factors written as integer codes.
to_stan_json <- function(data, always_decimal = FALSE) {
  if (!is.list(data)) {
    stop("'data' must be a list.", call. = FALSE)
  }
  data_names <- names(data)
  if (length(data) > 0 &&
      (length(data_names) == 0 ||
       length(data_names) != sum(nzchar(data_names)))) {
    stop("All elements in 'data' list must have names.", call. = FALSE)
  }
  if (anyDuplicated(data_names) != 0) {
    stop("Duplicate names not allowed in 'data'.", call. = FALSE)
  }
  for (var_name in data_names) {
    var <- data[[var_name]]
    if (!(is.numeric(var) || is.factor(var) || is.logical(var) ||
          is.data.frame(var) || is.list(var))) {
      stop("Variable '", var_name, "' is of invalid type.", call. = FALSE)
    }
    # recursive = TRUE so NAs nested inside list elements are caught as well
    if (anyNA(var, recursive = TRUE)) {
      stop("Variable '", var_name, "' has NA values.", call. = FALSE)
    }
    if (is.table(var)) {
      var <- unclass(var)
    } else if (is.logical(var)) {
      mode(var) <- "integer"
    } else if (is.data.frame(var)) {
      var <- data.matrix(var)
    } else if (is.list(var)) {
      var <- list_to_array(var, var_name)
    }
    data[[var_name]] <- var
  }
  jsonlite::toJSON(
    data,
    auto_unbox = TRUE,
    factor = "integer",
    always_decimal = always_decimal,
    digits = NA,
    pretty = TRUE
  )
}
# Single-observation check: Stan data for the first observation only ----------
i <- 1
standati <- list(
  N = 1,
  K = 3,
  y = matrix(c(y1[i], y2[i], y3[i]), nrow = 1, ncol = 3),
  a = c(a1, a2, a3),
  r = c(r1, r2, r3),
  b = c(b1, b2, b3),
  n1 = c(n1_1, n1_2, n1_3),
  n2 = c(n2_1, n2_2, n2_3),
  cors = randcors
)
# Load the compiled model together with this data
mod2bs <- StanModel$new(
  lib = mod2bs_lib,
  data = to_stan_json(standati),
  seed = 67134
)
# Log density, gradient and Hessian callbacks in the list form passed to aghq()
ffs <- list(
  fn = mod2bs$log_density,
  gr = function(x) mod2bs$log_density_gradient(x)[["gradient"]],
  he = function(x) mod2bs$log_density_hessian(x)[["hessian"]]
)
# Adaptive Gauss-Hermite quadrature (11 nodes) on the unconstrained scale,
# starting the mode search at 0.
# NOTE(review): the scalar starting value assumes X is the model's only
# parameter — confirm.
aghqi <- aghq(ffs, k = 11, startingvalue = 0)
# Probability levels 1%, 3%, ..., 99% (steps of 2%)
q <- seq(0.01, 0.99, by = 0.02)
# Quantiles of the first parameter on the unconstrained scale
qt <- compute_quantiles(aghqi, q = q)[[1]]
# Map each quantile to the constrained scale. vapply() insists on exactly one
# numeric value per quantile, so a model returning more constrained values
# fails loudly instead of silently producing a matrix.
qx <- vapply(qt, mod2bs$param_constrain, numeric(1))
# Plot the CDF of X[1]
data.frame(qx = qx, q = q) |>
  ggplot(aes(x = qx, y = q)) +
  geom_line() +
  labs(x = "X", y = "CDF")
# Posterior mean of X for every observation -----------------------------------
# Reuses `standati`, swapping in the i-th row of y on each iteration. Entries
# stay NA_real_ until computed, so a failure partway leaves them visibly unset.
xmean <- rep(NA_real_, n)
tic()
for (i in seq_len(n)) {
  standati$y <- matrix(c(y1[i], y2[i], y3[i]), nrow = 1, ncol = 3)
  # A fresh StanModel is built because the data changed.
  # NOTE(review): suppressWarnings() presumably silences a warning about
  # re-loading the same model library — confirm and narrow if possible.
  suppressWarnings(mod2bs <- StanModel$new(lib = mod2bs_lib,
                                           data = to_stan_json(standati),
                                           seed = 1))
  ffs <- list(
    fn = mod2bs$log_density,
    gr = \(x) mod2bs$log_density_gradient(x)$gradient,
    he = \(x) mod2bs$log_density_hessian(x)$hessian
  )
  aghqi <- aghq(ffs, k = 11, startingvalue = 2)
  # Posterior mean on the constrained scale: the moment of param_constrain()
  xmean[i] <- compute_moment(aghqi, mod2bs$param_constrain, method = "correct")
}
toc()
# Compare with Stan's posterior means of X ------------------------------------
# Both series are columns of the plotted data frame, so aes() maps data
# columns rather than picking up globals.
# NOTE(review): assumes x_sum$mean holds Stan's posterior means of X in the
# same order as xmean (length n) — confirm.
data.frame(x_stan = x_sum$mean, x_aghq = xmean) |>
  ggplot(aes(x = x_stan, y = x_aghq)) +
  geom_point() +
  geom_abline() + # identity line: points on it mean both methods agree
  labs(x = "X posterior mean via Stan", y = "X posterior mean via aghq")