This seems like an edge case, but I’ve been running into it, so I wanted to post to see if there is a simple fix. From Gelman et al. (2019), the Bayesian R^2 can only vary from 0 to 1. In other models it is possible to get a negative R^2, specifically when it is calculated as R^2 = 1 - \frac{SSE}{SST}, where SST = \sum_i (y_i - \bar{y})^2 and SSE = \sum_i (y_i - \hat{y}_i)^2: a model that fits worse than the grand mean of y gives SSE > SST and hence R^2 < 0. I was playing around with how shifting the prior affects the Bayesian R^2, and found some strange outcomes where a worse model gets a higher R^2, specifically because of the prior (see example below).
I wonder if there is a way to rework the Gelman et al. (2019) R^2 so that it can become negative, which might address this edge case?
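As a toy illustration of the classical case (my own example, separate from the Stan models below): any predictions that fit worse than the grand mean of y give SSE > SST and hence R^2 < 0.
y     <- c(1.7, 2.6, 2.5, 4.4, 3.8)
y_hat <- rep(10, length(y))   # deliberately terrible predictions
sse <- sum((y - y_hat)^2)     # SSE blows up
sst <- sum((y - mean(y))^2)
1 - sse / sst                 # far below 0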
Stan Model 1: Strong (Bad) Prior
functions {
  real linear_predictor(real alpha, real beta, real x) {
    return alpha + beta * x;
  }
}
data {
  int<lower=0> N;       // number of observations
  vector[N] x;          // predictor variable
  vector[N] y;          // response variable
}
parameters {
  real alpha;           // intercept
  real beta;            // slope
  real<lower=0> sigma;  // error standard deviation
}
model {
  // Prior distributions
  alpha ~ normal(0, 0.2);   // prior for intercept
  beta ~ normal(-3, 0.01);  // strong prior for slope, far from the data
  // Likelihood
  y ~ normal(alpha + beta * x, sigma);
}
generated quantities {
  vector[N] y_pred;             // linear predictor (no sampling noise)
  vector[N] y_pred_w_variance;  // posterior predictive draws (with sigma)
  for (i in 1:N) {
    y_pred[i] = linear_predictor(alpha, beta, x[i]);
    y_pred_w_variance[i] = normal_rng(alpha + beta * x[i], sigma);
  }
}
Stan Model 2: Strong Prior around 0 (identical to Model 1 except for the slope prior)
functions {
  real linear_predictor(real alpha, real beta, real x) {
    return alpha + beta * x;
  }
}
data {
  int<lower=0> N;       // number of observations
  vector[N] x;          // predictor variable
  vector[N] y;          // response variable
}
parameters {
  real alpha;           // intercept
  real beta;            // slope
  real<lower=0> sigma;  // error standard deviation
}
model {
  // Prior distributions
  alpha ~ normal(0, 0.2);  // prior for intercept
  beta ~ normal(0, 0.01);  // strong prior for slope, pinned at 0
  // Likelihood
  y ~ normal(alpha + beta * x, sigma);
}
generated quantities {
  vector[N] y_pred;             // linear predictor (no sampling noise)
  vector[N] y_pred_w_variance;  // posterior predictive draws (with sigma)
  for (i in 1:N) {
    y_pred[i] = linear_predictor(alpha, beta, x[i]);
    y_pred_w_variance[i] = normal_rng(alpha + beta * x[i], sigma);
  }
}
Comparing the two, you can see that the median R^2 for the worse model (slope prior pinned at -3, the wrong sign entirely) is ~0.4, while the other, still-bad model (slope prior pinned at 0) gives ~0.
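For reference, the quantity computed below is my reading of the Gelman et al. (2019) definition, evaluated per posterior draw s: R^2_s = \frac{\mathrm{Var}(\hat{y}_s)}{\mathrm{Var}(\hat{y}_s) + \mathrm{Var}(e_s)}, where \hat{y}_s is the linear predictor and e_s are the residual draws. Both variances are non-negative, so every draw lands in [0, 1], and a prior that pins the slope far from zero inflates \mathrm{Var}(\hat{y}_s), which is exactly why the worse model scores higher.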
library(rstan)
library(tidybayes)
library(magrittr)
library(ggplot2)
library(ggpubr)  # for ggarrange()
## generate a small dataset
x <- 1:5 - 3
y <- c(1.7, 2.6, 2.5, 4.4, 3.8) - 3
xy <- data.frame(x,y)
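## for reference, the least-squares fit to these data is roughly y = 0.6 * x,
## so a slope prior pinned at -3 is badly wrong and one pinned at 0 is still wrong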
## compile the two stan models
stan_model_strong <- stan_model("scripts/stan_model_r_sq_check_strong_prior.stan")
stan_model_weak <- stan_model("scripts/stan_model_r_sq_check_weak_prior.stan")
## fit the model with the strong (bad) prior
fit_strong <- sampling(stan_model_strong,
                       data = list(N = 5,
                                   x = xy$x,
                                   y = xy$y),
                       iter = 4000,
                       chains = 4)
## fit the model with the strong prior around 0
fit_weak <- sampling(stan_model_weak,
                     data = list(N = 5,
                                 x = xy$x,
                                 y = xy$y),
                     iter = 4000,
                     chains = 4)
## calculate R^2 per posterior draw using the Gelman et al. (2019)
## ratio-of-variances formula above, and return the median
calculate_r_sq_gelman <- function(fit_out){
  y_pred_solved <- rstan::extract(fit_out)$y_pred  # linear predictor, draws x N
  residuals_y <- y_pred_solved - rstan::extract(fit_out)$y_pred_w_variance
  var_residuals <- apply(residuals_y, 1, var)      # residual variance per draw
  var_pred_out <- apply(y_pred_solved, 1, var)     # predictor variance per draw
  hist(var_pred_out / (var_pred_out + var_residuals))
  return(median(var_pred_out / (var_pred_out + var_residuals)))
}
## compute the median R^2 for each fit (the function also plots the
## distribution of per-draw R^2 values)
calculate_r_sq_gelman(fit_strong)
calculate_r_sq_gelman(fit_weak)
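## optional sanity check, assuming rstanarm is installed: its bayes_R2()
## implements the Gelman et al. (2019) R^2, and stan_glm() takes roughly
## comparable priors, so it should show the same pattern
library(rstanarm)
fit_rstanarm_strong <- stan_glm(y ~ x, data = xy,
                                prior = normal(-3, 0.01),
                                prior_intercept = normal(0, 0.2))
median(bayes_R2(fit_rstanarm_strong))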
## plot the posterior predictions for each fit
post_pred_strong <- fit_strong %>%
  spread_draws(alpha, beta, sigma) %>%
  sample_n(100) %>%
  ggplot() +
  theme_bw() +
  ggtitle(paste("R2 = ", round(calculate_r_sq_gelman(fit_strong), 2))) +
  geom_point(data = xy, aes(x = x, y = y)) +
  geom_abline(aes(slope = beta, intercept = alpha), alpha = 0.2)
post_pred_weak <- fit_weak %>%
  spread_draws(alpha, beta, sigma) %>%
  sample_n(100) %>%
  ggplot() +
  theme_bw() +
  ggtitle(paste("R2 = ", round(calculate_r_sq_gelman(fit_weak), 2))) +
  geom_point(data = xy, aes(x = x, y = y)) +
  geom_abline(aes(slope = beta, intercept = alpha), alpha = 0.2)
ggarrange(post_pred_strong, post_pred_weak)
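One rework I have been toying with (my own sketch, not something from the paper): compute a classical-style R^2 against the observed y for each posterior draw s, i.e. R^2_s = 1 - \frac{\sum_i (y_i - \hat{y}_{is})^2}{\sum_i (y_i - \bar{y})^2}. Because the numerator can exceed the denominator, this version does go negative when the prior-dominated fit is worse than the mean of y:
## classical-style R^2 per posterior draw; unlike the ratio-of-variances
## version this can go negative when the fit is worse than mean(y)
calculate_r_sq_classical <- function(fit_out, y_obs){
  y_pred_draws <- rstan::extract(fit_out)$y_pred         # draws x N
  sse_draws <- rowSums(sweep(y_pred_draws, 2, y_obs)^2)  # SSE per draw
  sst <- sum((y_obs - mean(y_obs))^2)                    # SST from observed y
  return(median(1 - sse_draws / sst))
}
calculate_r_sq_classical(fit_strong, xy$y)  # strongly negative
calculate_r_sq_classical(fit_weak, xy$y)    # near 0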