Hi,
I am trying to understand the working of ‘predict’ function by manually calculating predictions and comparing them with those obtained by using the ‘brms predict’ function. However, even after setting the seed, the results differ. Unlike the predicted values, I was successful in getting the exact same fitted values.
A simple example is shown below
library(brms)
library(tidyverse)
set.seed(123)
# generate data
n = 100
x = sort(rnorm(n, 0, 1))
b0 = 10
b1 = 2
sd = 0.5
y = b0 + b1*x + rnorm(n = n, mean = 0, sd = sd)
dat <- data.frame(cbind(x,y))
# model fitting
bfit <- brm(bf(y ~ x), backend = "rstan", data = dat,
seed = 123, chains = 2, iter = 2000, cores = 2)
# set seed for reproducibility (predicted values)
set.seed(123)
#####################
# get fitted and predicted values by using brms functions
#####################
# fitted estimates
bfitted_values <- fitted(bfit)
# head(bfitted_values[,1])
# predicted estimates
bpredict_values <- predict(bfit)
# head(bpredict_values[,1])
#####################
# manually calculate fitted and predicted values
#####################
# parnames(bfit)
nsamples <- nsamples(bfit)
psb <- posterior_samples(bfit)
Xmat <- cbind(rep(1, nrow(dat)) , x)
fitted_values <- matrix(0, nrow = nsamples , ncol = nrow(dat) )
predicted_values <- matrix(0, nrow = nsamples , ncol = nrow(dat) )
for(j in 1:nsamples) {
eb0 <- psb[j,"b_Intercept"] %>% unlist() %>% as.matrix()
eb1 <- psb[j,"b_x"] %>% unlist() %>% as.matrix()
mu <- Xmat[,1] %*% eb0 + Xmat[,2] %*% eb1
sd <- psb[j,"sigma"]
fitted_values[j,] <- mu
predicted_values[j,] <- rnorm(nrow(dat) , mu, sd)
}
mfitted_values <- apply(fitted_values, 2, mean, na.rm = TRUE)
# head(mfitted_values)
mpredicted_values <- apply(predicted_values, 2, mean, na.rm = TRUE)
# head(mpredicted_values)
.fitted <- cbind(bfitted_values[,1], mfitted_values) %>% data.frame()
colnames(.fitted ) <- c("brms", "manual")
.predicted <- cbind(bpredict_values[,1], mpredicted_values) %>% data.frame()
colnames(.predicted ) <- c("brms", "manual")
# compare fitted values
head(.fitted)
brms manual
5.229839 5.229839
5.928913 5.928913
6.500177 6.500177
6.781684 6.781684
7.359954 7.359954
7.360638 7.360638
# compare predicted values
head(.predicted)
brms manual
5.243650 5.222802
5.922818 5.925199
6.500232 6.502756
6.775731 6.775499
7.352880 7.362450
7.372583 7.344495
- Operating System:
- brms Version: 2.15.0