Hi,
I have found for my model that, whether I use 2, 4, 8, 16, or 32 chains (I have 64 cores), the maximum nested R-hat (see [2110.13017] Nested $\hat R$: Assessing the convergence of Markov chain Monte Carlo when running many short chains ) is consistently smaller than the maximum R-hat.
To be more specific, I am computing nested R-hat (nR) using the posterior R package (the posterior::rhat_nested function), and computing R-hat (R) using the same package's posterior::rhat function.
For nR, I am using each chain as its own superchain (see code below). So 1 chain per superchain.
NOTE: If I instead use e.g. 2 superchains of 16 chains each, then nR decreases even further!
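To make the grouping concrete, here is a minimal sketch (toy draws, not my actual model) of how I'm passing superchain_ids to posterior::rhat_nested:

library(posterior)

set.seed(123)
n_iter <- 500
n_chains <- 32
x <- matrix(rnorm(n_iter * n_chains), nrow = n_iter)  # toy iterations-x-chains draws

# each chain as its own superchain (what I do below): IDs 1, 2, ..., n_chains
rhat_nested(x, superchain_ids = 1:n_chains)

# alternative grouping: 2 superchains of 16 chains each
rhat_nested(x, superchain_ids = rep(1:2, each = n_chains / 2))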
This isn't the model I'm using (I'm not even using Stan for my model), but I made a minimal working example (MWE) with a simpler model where the same thing occurs. The model is a standard multivariate probit model with the same number of coefficients per outcome but different coefficient values per outcome. Specifically, in this example the number of covariates is 1 and the dimension (number of outcomes) is 5.
Also, see the comments at the bottom of the R code for examples (with 500 post-burn-in iterations per chain) of the worst-case (maximum) nR and R. For example, with 2 chains I get R = 1.02 and nR = 1.004, and with 32 chains I get R = 1.004 and nR = 1.001. In all cases nR is notably less than R!
What's going on? When using each chain as its own superchain, shouldn't nR = R?
Which one should I trust/report? (To be honest, I'd prefer nR, since then I wouldn't have to run the models anywhere near as long to get R <= 1.01 …)
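(For anyone who wants to poke at the diagnostics without running the model, here is a minimal sketch comparing the two functions on iid standard-normal "draws", where both should be approximately 1:)

library(posterior)

set.seed(42)
x <- matrix(rnorm(500 * 2), nrow = 500, ncol = 2)  # 2 "chains" of iid N(0, 1) draws

posterior::rhat(x)                               # classic (rank-normalized, split) R-hat
posterior::rhat_nested(x, superchain_ids = 1:2)  # nested R-hat, 1 chain per superchain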
Plot: (red = max R, blue = max nR)
edit: added max ESS values to the bottom of the R code
Thanks!
Stan model code:
functions {
  real Phi_approx_2(real b, real boundary_for_rough_approx) {
    real a;
    if (abs(b) < boundary_for_rough_approx) { // i.e. NOT in the tails of N(0, 1)
      a = Phi(b);
    } else { // logistic approximation to Phi in the tails
      a = inv_logit(1.702 * b);
    }
    return a;
  }
}
data {
  int<lower=1> K;
  int<lower=1> D;
  int<lower=0> N;
  array[N, D] int<lower=0, upper=1> y;
  array[N] matrix[D, K] X;
  real boundary_for_rough_approx;
}
parameters {
  matrix[D, K] beta;
  cholesky_factor_corr[D] L_Omega;
  array[N, D] real<lower=0, upper=1> u; // nuisance that absorbs inequality constraints
}
model {
  L_Omega ~ lkj_corr_cholesky(2);
  to_vector(beta) ~ normal(0, 1);
  // implicit: u is iid standard uniform a priori
  {
    // likelihood
    for (n in 1:N) {
      vector[D] Xbeta_n;
      vector[D] z;
      real prev;
      for (d in 1:D) {
        Xbeta_n[d] = X[n, d] * to_vector(beta[d]);
      }
      prev = 0;
      for (d in 1:D) {
        real bound; // threshold at which utility = 0
        real stuff = (0 - (Xbeta_n[d] + prev)) / L_Omega[d, d];
        bound = Phi_approx_2(stuff, boundary_for_rough_approx);
        if (y[n, d] == 1) {
          real t = bound + (1 - bound) * u[n, d];
          // z[d] implies utility is positive
          if (abs(stuff) < boundary_for_rough_approx) z[d] = inv_Phi(t);
          else z[d] = logit(t) / 1.702;
          target += log1m(bound); // Jacobian adjustment
        } else {
          real t = bound * u[n, d];
          // z[d] implies utility is negative
          if (abs(stuff) < boundary_for_rough_approx) z[d] = inv_Phi(t);
          else z[d] = logit(t) / 1.702;
          target += log(bound); // Jacobian adjustment
        }
        if (d < D) {
          prev = L_Omega[d + 1, 1:d] * head(z, d);
        }
        // Jacobian adjustments imply z is truncated standard normal,
        // thus utility --- Xbeta_n + L_Omega * z --- is truncated multivariate normal
      }
    }
  }
}
generated quantities {
  corr_matrix[D] Omega;
  Omega = multiply_lower_tri_self_transpose(L_Omega);
}
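(Aside: a quick R sanity check of the 1.702 logistic tail approximation used in Phi_approx_2 above:)

# worst-case absolute error of the logistic approximation to the normal CDF
x <- seq(-10, 10, by = 0.01)
max(abs(pnorm(x) - plogis(1.702 * x)))  # known to be < 0.01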
R code:
set.seed(1)
N <- 500
dim <- 5
n_coeffs_per_outcome <- 1
n_coeffs_total <- dim * n_coeffs_per_outcome
x1 <- runif(N, -1, 1)
x2 <- runif(N, -1, 1)
x3 <- runif(N, -1, 1)
x4 <- runif(N, -1, 1)
x5 <- runif(N, -1, 1)
x_cov_1 <- array(c(x1, x2, x3, x4, x5), dim = c(N, dim))
# x_intercept <- array(1, dim = c(N, dim))
X <- array(dim = c(N, dim, n_coeffs_per_outcome))
X[,,1] <- x_cov_1
# X[,,2] <- x_cov_1
Omega <- matrix(c( 1, 0, 0, 0, 0,
0, 1, 0.50, 0.25, 0,
0, 0.50, 1, 0.40, 0.40,
0, 0.25, 0.40, 1, 0.70,
0, 0, 0.40, 0.70, 1),
dim, dim)
errors <- mvtnorm::rmvnorm(N, rep(0, dim), Omega)
# plot(errors)
# cor(errors) # realized correlation
# one covariate per outcome, with a different coefficient for each outcome
y1 <- 0 + x1 * -1 + errors[,1]
y2 <- 0 + x2 * -0.5 + errors[,2]
y3 <- 0 + x3 * 0 + errors[,3]
y4 <- 0 + x4 * 0.5 + errors[,4]
y5 <- 0 + x5 * 1 + errors[,5]
latent_results <- array(c(y1, y2, y3, y4, y5), dim = c(N, dim))
# plot(x, y1)
# plot(x, y2)
# plot(y1, y2)
binary_results <- ifelse(latent_results > 0, 1, 0)
y_MVP <- binary_results
library(cmdstanr) # the model is compiled and run with cmdstanr (not rstan)
options(mc.cores = parallel::detectCores())
file <- "MVP_example_model_1.stan"
mod <- cmdstan_model(file)
data <- list(K = n_coeffs_per_outcome,
D = dim,
N = length(y1),
X = X,
y = y_MVP,
boundary_for_rough_approx = 5)
seed <- 1
n_chains <- 32
iter_warmup <- 500
iter_sampling <- 500
adapt_delta <- 0.80
max_treedepth <- 8
metric_type <- "diag_e"
u_initial <- array(0.01, dim = c(N, dim))
L_Omega <- t(chol(Omega))
beta <- array(0.01, dim = c(dim, n_coeffs_per_outcome))
beta[1,1] <- -1
beta[2,1] <- -0.5
beta[3,1] <- 0
beta[4,1] <- 0.5
beta[5,1] <- 1
init <- list(u = u_initial,
L_Omega = L_Omega,
beta = beta)
model <- mod$sample(
# data = stan_data_list[[df_i]],
data = data,
seed = seed,
chains = n_chains,
parallel_chains = n_chains,
iter_warmup = iter_warmup,
iter_sampling = iter_sampling,
refresh = round((iter_warmup + iter_sampling) / 100),
init = rep(list(init), n_chains),
save_warmup = TRUE, # for some efficiency stats (can turn off for sim. study)
metric = metric_type,
adapt_delta = adapt_delta,
max_treedepth = max_treedepth)
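(Not shown in the results below, but it's worth confirming the runs themselves are clean, with no divergences or treedepth saturation, especially since max_treedepth is only 8. With cmdstanr this is one call:)

# per-chain counts of divergences / max-treedepth hits, plus E-BFMI
model$diagnostic_summary()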
cmdstanr_model_out <- model$summary(
  variables = c("beta", "Omega"),
  "mean", "median", "sd", "mad",
  ~quantile(.x, probs = c(0.025, 0.975)),
  "rhat", "ess_bulk", "ess_tail")
print(cmdstanr_model_out, n = 100)
# min_ess <- round(min(cmdstanr_model_out$ess_bulk, na.rm=TRUE), 0)
# print(paste("min ESS = ", min_ess))
n_chains_for_rhat_comp <- 2
stan_draws_array <- model$draws()[, 1:n_chains_for_rhat_comp, ] # keep only the first n_chains_for_rhat_comp chains
# draw columns are ordered: lp__, beta (D*K), L_Omega (D*D), u (N*D), then GQ Omega (D*D)
n_us_nuisance <- N * dim
n_elements_Omegas <- dim * dim
n_main_params <- n_coeffs_total + n_elements_Omegas # beta plus L_Omega (also D*D stored elements)
index_coeffs <- 2:(n_coeffs_total + 1) # beta comes right after lp__
index_Omega <- (n_us_nuisance + n_main_params + 2):(n_us_nuisance + n_main_params + 2 + n_elements_Omegas - 1)
index_main_params <- c(index_coeffs, index_Omega)
rhats_nested <- rhats <- ess <- c()
for (i in seq_along(index_main_params)) {
  draws_i <- array(c(stan_draws_array[, , index_main_params[i]]),
                   dim = c(iter_sampling, n_chains_for_rhat_comp))
  rhats_nested[i] <- posterior::rhat_nested(draws_i, superchain_ids = 1:n_chains_for_rhat_comp)
  rhats[i] <- posterior::rhat(draws_i)
  ess[i] <- posterior::ess_basic(draws_i)
}
print(round(max(rhats, na.rm = TRUE), 3))        # max R across the main parameters
print(round(max(rhats_nested, na.rm = TRUE), 3)) # max nR across the main parameters
print(round(max(ess, na.rm = TRUE), 0))          # max ESS across the main parameters
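(Equivalently, and less error-prone than the positional indexing above, the same summaries can be pulled by variable name; a sketch using posterior, assuming the same fit object:)

draws_main <- posterior::subset_draws(model$draws(), variable = c("beta", "Omega"))
draws_main <- posterior::subset_draws(draws_main, chain = 1:n_chains_for_rhat_comp)
summ <- posterior::summarise_draws(
  draws_main,
  rhat  = posterior::rhat,
  nrhat = function(x) posterior::rhat_nested(x, superchain_ids = 1:n_chains_for_rhat_comp),
  ess   = posterior::ess_basic
)
max(summ$rhat, na.rm = TRUE); max(summ$nrhat, na.rm = TRUE); max(summ$ess, na.rm = TRUE)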
### with 2 chains:
# max R = 1.02
# max nR = 1.004
# max ESS = 2794
### with 4 chains:
# max R = 1.009
# max nR = 1.003
# max ESS = 5221
### with 8 chains:
# max R = 1.008
# max nR = 1.002
# max ESS = 10563
### with 16 chains:
# max R = 1.005
# max nR = 1.002
# max ESS = 21625
### with 32 chains:
# max R = 1.004
# max nR = 1.001
# max ESS = 43485
n_chains_vec <- c(2, 4, 8, 16, 32)
max_rhats <- c(1.02, 1.009, 1.008, 1.005, 1.004)
max_nested_rhats <- c(1.004, 1.003, 1.002, 1.002, 1.001)
max_ess <- c(2794, 5221, 10563, 21625, 43485)
par(mfrow = c(1, 2))
plot(n_chains_vec, max_rhats, ylim = c(1, 1.025), col = "red", lwd = 3, cex = 3, pch = 19,
     xlab = "number of chains", ylab = "max R-hat / max nested R-hat")
points(n_chains_vec, max_nested_rhats, col = "blue", lwd = 3, cex = 3, pch = 19)
plot(n_chains_vec, max_ess, ylim = c(1000, 50000), col = "red", lwd = 3, cex = 3, pch = 19,
     xlab = "number of chains", ylab = "max ESS")