I am trying to teach myself how to conduct a posterior predictive check in Stan. I have read that this can be done using the generated quantities block in the model statement (see here). However, I am worried I’m not doing it properly, because the HPDI of the predicted y values at each counterfactual predictor value, as derived from the generated quantities block, looks different from the HPDI derived from running the same model with Richard McElreath’s rethinking package.
Here is the toy data.
library(rethinking)
library(ggplot2)  # for the plots
x <- seq(5,14.9,.1)
y <- x*3 + rnorm(100,0,10)
df <- data.frame(x,y)
Now run a simple Bayesian linear regression on the data using the map2stan() function from the rethinking package.
mod <- map2stan(
  alist(
    y ~ dnorm(mu, sigma),
    mu <- a + b*x,
    a ~ dnorm(0,30),
    b ~ dnorm(0,20),
    sigma ~ dunif(0,50)
  ), data = df
)
This is the model output
precis(mod)
# Mean StdDev lower 0.89 upper 0.89 n_eff Rhat
# a -10.24 3.59 -15.55 -4.15 416 1.00
# b 3.87 0.35 3.32 4.40 413 1.00
# sigma 9.92 0.70 8.76 11.04 499 1.01
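As a rough cross-check of this output, and assuming the priors are weak relative to 100 observations, the posterior means of a, b and sigma should sit close to the ordinary least-squares estimates from base R's lm(). A minimal sketch; the seed is an arbitrary choice, so the numbers will not match the precis() table exactly:

```r
# Rough cross-check: with weak priors, posterior means should be near the OLS fit
set.seed(2024)                       # arbitrary seed, for reproducibility only
x <- seq(5, 14.9, 0.1)
y <- x * 3 + rnorm(100, 0, 10)       # same generating process as the toy data
fit <- lm(y ~ x)

olsA <- unname(coef(fit)[1])         # intercept, compare with a
olsB <- unname(coef(fit)[2])         # slope, compare with b
olsSigma <- summary(fit)$sigma       # residual SD, compare with sigma

# The OLS slope also has the closed form cov(x, y) / var(x)
closedFormB <- cov(x, y) / var(x)
print(c(a = olsA, b = olsB, sigma = olsSigma))
```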
Now, to perform the posterior predictive check, we need to generate counterfactual candidate values of x and compute the mean and HPDI of the prediction at each of those values (note that these candidate values extend about 5 units beyond either side of the observed range of x). We do this with the link() function from the rethinking package, which “computes the value of each linear model at each sample for each case in the data. Inverse link functions are applied, so that for example a logit link linear model produces probabilities, using the logistic transform.”
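What link() computes for this model can be sketched in base R: for every posterior draw of (a, b), evaluate mu = a + b*x at each counterfactual x, giving a draws-by-x matrix. The sketch assumes a stand-in for the posterior, namely a bivariate normal approximation built from lm()'s coefficient estimates and covariance matrix; this is not how rethinking obtains its samples. Note that only a and b enter the calculation, not sigma.

```r
set.seed(1)
x <- seq(5, 14.9, 0.1)
y <- x * 3 + rnorm(100, 0, 10)
fit <- lm(y ~ x)

# Stand-in posterior draws of (a, b): normal approximation from the OLS fit (an assumption)
nDraws <- 2000
L <- t(chol(vcov(fit)))                      # lower-triangular factor of the 2x2 covariance
z <- matrix(rnorm(2 * nDraws), nrow = 2)     # independent standard normals
ab <- t(coef(fit) + L %*% z)                 # nDraws x 2 matrix, columns a and b
a <- ab[, 1]
b <- ab[, 2]

# link()-style matrix: one row per draw, one column per counterfactual x
xSeq <- seq(0, 20, length.out = 200)
mu <- a + outer(b, xSeq)                     # mu[s, j] = a[s] + b[s] * xSeq[j]

# Width of the central 89% interval of mu at each counterfactual x
muWidth <- apply(mu, 2, function(m) diff(quantile(m, c(0.055, 0.945))))
xSeq[which.min(muWidth)]                     # narrowest near the centre of the observed x
```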
xSeq <- seq(0,20,length.out = 200) # 200 counterfactual values of predictor 'x'
mu <- link(mod, data = data.frame(x = xSeq)) # pass these values into link function
Now create a data frame with the counterfactual values and the mean and HPDI (using the rethinking::HPDI() function) of the distribution of mu values that link() returns for each counterfactual x value.
cfDF <- data.frame(xSeq = xSeq,
                   muMean = apply(mu, 2, mean),
                   muHPDI_low = apply(mu, 2, HPDI)[1,],
                   muHPDI_hi = apply(mu, 2, HPDI)[2,])
And plot the mean and HPDI at each counterfactual x value against the actual data.
ggplot(df, aes(x, y)) +
  geom_point(colour = "blue", shape = 1) +
  geom_line(data = cfDF, aes(x=xSeq, y=muMean), colour = "red", linetype = "dashed") +
  geom_line(data = cfDF, aes(x=xSeq, y=muHPDI_low), linetype = "dotted") +
  geom_line(data = cfDF, aes(x=xSeq, y=muHPDI_hi), linetype = "dotted") +
  theme_classic()
The actual data are in blue, the counterfactual predicted means are represented by the red dashed line and the counterfactual HPDI by the black dotted lines. Note how the predicted HPDI stays close to the mean where the counterfactual x values lie within the range of the observed data, but spreads out as they move beyond that range (i.e. it incorporates the extra uncertainty of predicting where there is no actual data to base the prediction on).
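The widening follows from the linear model itself. Writing $\mu(x) = a + b x$, the posterior variance of $\mu$ at a given $x$ is a quadratic in $x$ (a standard identity). For this model the covariance term is negative, because when all observed x are positive the intercept and slope are negatively correlated:

```latex
\operatorname{Var}[\mu(x)] = \operatorname{Var}(a) + 2x\,\operatorname{Cov}(a,b) + x^{2}\,\operatorname{Var}(b),
\qquad
\arg\min_{x} \operatorname{Var}[\mu(x)] = -\frac{\operatorname{Cov}(a,b)}{\operatorname{Var}(b)}
```

The minimum falls near the centre of the observed x, so the interval is narrowest there and widens in both directions.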
I then used this model as a template for the Stan model.
Step 1: Put the data in a list.
In this list we include the outcome (y), the values of the predictor (x) and the number of observations to loop over (N). We ALSO need to include the counterfactual values of the predictor (xSeq) and the number of these counterfactuals (Ncf).
dList <- list(N = nrow(df), y = df$y, x = df$x, Ncf = length(xSeq), xSeq = xSeq)
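rstan typically stops with an error when an array's length does not match the size declared in the data block, so a quick check of the list before calling stan() can save a failed compile-and-run cycle. A minimal sketch; checkStanData() is a hypothetical helper, not part of rstan:

```r
# Hypothetical helper: check each array against the size the data block declares
checkStanData <- function(d) {
  stopifnot(
    is.numeric(d$N), is.numeric(d$Ncf),
    length(d$x) == d$N,
    length(d$y) == d$N,
    length(d$xSeq) == d$Ncf,
    all(is.finite(c(d$x, d$y, d$xSeq)))   # no missing or infinite values
  )
  invisible(TRUE)
}

x <- seq(5, 14.9, 0.1)
y <- x * 3 + rnorm(100, 0, 10)
xSeq <- seq(0, 20, length.out = 200)
dList <- list(N = length(x), y = y, x = x, Ncf = length(xSeq), xSeq = xSeq)
checkStanData(dList)
```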
Step 2: Create the model.
Note the generated quantities block, where the for-loop passes each counterfactual x value in xSeq to the normal_rng() function.
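Conceptually, the generated quantities block runs once per posterior draw, so each row of y_rep is one draw of (a, b, sigma) pushed through normal_rng(). The same computation can be sketched in base R. As stand-in posterior draws, the sketch assumes the exact posterior of a normal linear regression under the flat prior p(a, b, sigma^2) ∝ 1/sigma^2, which is not this model's priors; it then compares the spread of mu = a + b*x with the spread of y_rep at each x.

```r
set.seed(1)
x <- seq(5, 14.9, 0.1)
y <- x * 3 + rnorm(100, 0, 10)
fit <- lm(y ~ x)
xSeq <- seq(0, 20, length.out = 200)

# Stand-in posterior (assumption): flat prior p(a, b, sigma^2) proportional to 1/sigma^2
nDraws <- 2000
dfRes <- fit$df.residual
s2 <- sum(residuals(fit)^2) / dfRes
sigma2 <- dfRes * s2 / rchisq(nDraws, dfRes)               # sigma^2 | y: scaled inverse chi-square
L <- t(chol(vcov(fit) / s2))                                # Cholesky factor of (X'X)^-1
z <- matrix(rnorm(2 * nDraws), nrow = 2)
ab <- t(coef(fit) + L %*% z * rep(sqrt(sigma2), each = 2))  # (a, b) | sigma^2, y
a <- ab[, 1]
b <- ab[, 2]
sigma <- sqrt(sigma2)

# What the generated quantities block computes, one row per draw
mu <- a + outer(b, xSeq)                                    # a + b * xSeq[i]
yRep <- matrix(rnorm(length(mu), mean = mu, sd = sigma),    # normal_rng(mu, sigma)
               nrow = nDraws)

# Spread at each counterfactual x
sdMu <- apply(mu, 2, sd)
sdRep <- apply(yRep, 2, sd)
round(range(sdMu), 1)
round(range(sdRep), 1)
```

In this sketch sdRep varies far less across xSeq than sdMu does, because every y_rep draw adds observation noise with standard deviation sigma on top of the uncertainty in a + b*x.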
write("
data{
  int<lower=1> N;
  int<lower=1> Ncf;        // number of counterfactual x-values
  array[N] real x;
  array[N] real y;
  array[Ncf] real xSeq;    // array of counterfactual x-values
}
parameters{
  real a;
  real b;
  real<lower=0, upper=50> sigma;
}
model{
  vector[N] mu;
  a ~ normal(0, 30);
  b ~ normal(0, 20);
  sigma ~ uniform(0, 50);
  for (i in 1:N) {
    mu[i] = a + b * x[i];
  }
  y ~ normal(mu, sigma);
}
generated quantities {
  array[Ncf] real y_rep;
  for (i in 1:Ncf) {
    y_rep[i] = normal_rng(a + b * xSeq[i], sigma); // note the xSeq instead of x
  }
}
", file = "temp.stan")
Step 3: Generate the MCMC chains.
library(rstan)
chains <- stan(file = "temp.stan",
               data = dList,
               warmup = 1e3,
               iter = 3e3,
               cores = 1,
               chains = 1)
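In rstan, iter includes the warmup iterations, so with the default thin = 1 each chain keeps iter - warmup draws. A one-line sketch of the arithmetic behind the 2000 draws used below; postWarmupDraws() is a hypothetical helper:

```r
# Hypothetical helper: retained draws when thin = 1 (rstan's iter includes warmup)
postWarmupDraws <- function(iter, warmup, chains) chains * (iter - warmup)

postWarmupDraws(iter = 3e3, warmup = 1e3, chains = 1)   # draws behind each column of y_rep
```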
Step 4: Diagnostics
print(chains, probs = c(0.025, 0.975))
The model looks good (output omitted): n_eff is high and Rhat is close to 1. In addition to the parameter estimates of the linear model, we now also have 200 individual distributions of predicted y values, one for each of the 200 counterfactual predictor values in xSeq (similar to the matrix generated by the rethinking::link() function).
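Rhat compares between-chain and within-chain variance, and values near 1 indicate that the chains agree. Below is a minimal sketch of the classic split-Rhat, in which each chain is cut in half and the halves are treated as separate chains. It is not the rank-normalized variant, so it may differ slightly from what rstan reports; splitRhat() is a hypothetical name.

```r
# Classic split-Rhat: split each chain in half, then compare between- and within-half variance
splitRhat <- function(draws) {
  draws <- as.matrix(draws)                  # iterations x chains
  n <- floor(nrow(draws) / 2)
  halves <- cbind(draws[1:n, , drop = FALSE],
                  draws[(nrow(draws) - n + 1):nrow(draws), , drop = FALSE])
  chainMeans <- colMeans(halves)
  B <- n * var(chainMeans)                   # between-chain variance
  W <- mean(apply(halves, 2, var))           # mean within-chain variance
  varPlus <- (n - 1) / n * W + B / n         # pooled estimate of the posterior variance
  sqrt(varPlus / W)
}

set.seed(1)
good <- rnorm(2000)                          # one well-mixed chain
stuck <- c(rnorm(1000, 0), rnorm(1000, 5))   # a chain whose mean shifts halfway through
splitRhat(good)
splitRhat(stuck)
```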
Let’s extract the counterfactual estimates only. There are 2000 predicted y values for each of the 200 counterfactual x values.
y_rep <- as.matrix(chains, pars = "y_rep")
dim(y_rep)
# [1] 2000 200
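Before summarising, it can be worth confirming that column j of y_rep lines up with xSeq[j]. Assuming the columns carry Stan's usual element names (y_rep[1], y_rep[2] and so on), the index can be parsed back out. The small matrix below is a hypothetical stand-in for the real draws, and yRepIndex() is a hypothetical helper:

```r
# Hypothetical stand-in for as.matrix(chains, pars = "y_rep"): 3 draws x 5 values
ySim <- matrix(rnorm(15), nrow = 3,
               dimnames = list(NULL, paste0("y_rep[", 1:5, "]")))

# Parse the integer index out of names such as "y_rep[12]"
yRepIndex <- function(m) as.integer(sub("^y_rep\\[([0-9]+)\\]$", "\\1", colnames(m)))

yRepIndex(ySim)                                # 1 2 3 4 5: column j matches xSeq[j]
```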
Let’s create a function to get the 95% HPDI (note that rethinking::HPDI() uses prob = 0.89 by default).
HPDIFunct <- function (vector) {
  sortVec <- sort(vector)
  ninetyFiveVec <- ceiling(.95*length(sortVec))   # number of draws spanned by the interval
  fiveVec <- length(sortVec) - ninetyFiveVec      # number of candidate lower ends
  # width of every candidate interval that spans ninetyFiveVec draws
  diffVec <- sapply(1:fiveVec, function (i) sortVec[i + ninetyFiveVec] - sortVec[i])
  minVal <- sortVec[which.min(diffVec)]
  maxVal <- sortVec[which.min(diffVec) + ninetyFiveVec]
  return(list(minVal, maxVal))
}
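A vectorised version of the same idea, sketched for comparison: sort the draws, slide a window covering ceiling(prob * n) of them, and keep the narrowest window. hpdiSketch() is a hypothetical name; its prob argument makes it easy to match the probability used by rethinking::HPDI() when comparing the two plots.

```r
# Empirical HPDI: the narrowest interval that contains ceiling(prob * n) of the draws
hpdiSketch <- function(draws, prob = 0.95) {
  sorted <- sort(draws)
  n <- length(sorted)
  k <- ceiling(prob * n)                        # number of draws inside the interval
  starts <- seq_len(n - k + 1)                  # every possible lower end
  widths <- sorted[starts + k - 1] - sorted[starts]
  best <- which.min(widths)
  c(lower = sorted[best], upper = sorted[best + k - 1])
}

set.seed(1)
hpdiSketch(rnorm(1e5))                          # roughly -1.96 and 1.96
hpdiSketch(rexp(1e5))                           # lower end near 0, unlike an equal-tailed interval
```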
Now let’s get the mean and HPDI of the distributions of y for each of the 200 counterfactual values of x
hpdiMat <- sapply(seq_len(ncol(y_rep)), function (i) unlist(HPDIFunct(y_rep[, i])))  # 2 x 200: lower, upper
cfDF_stan <- data.frame(xSeq = xSeq,
                        muMean = colMeans(y_rep),
                        muHPDI_low = hpdiMat[1, ],
                        muHPDI_hi = hpdiMat[2, ])
And, once again, plot these means and HPDIs against the actual data.
ggplot(df, aes(x, y)) +
  geom_point(colour = "slategrey", shape = 1) +
  geom_line(data = cfDF_stan, aes(x=xSeq, y=muMean), colour = "red", linetype = "dashed") +
  geom_line(data = cfDF_stan, aes(x=xSeq, y=muHPDI_low), linetype = "dotted") +
  geom_line(data = cfDF_stan, aes(x=xSeq, y=muHPDI_hi), linetype = "dotted") +
  theme_classic()
Once again the actual data are plotted as circles (slate grey this time), the counterfactual means are represented by the red dashed line, and the HPDIs by the black dotted lines. These counterfactual HPDIs look very different from those generated by the rethinking::link() function: they remain equidistant from the mean line the entire way and do not change with the presence or absence of data.
What am I doing wrong?