Nested R-hat < R-hat regardless of # of chains?!?

Hi,

I have found for my model that, whether I use 2, 4, 8, 16, 32 or more chains (I have 64 cores), the minimum nested R-hat (see [2110.13017] Nested $\hat R$: Assessing the convergence of Markov chain Monte Carlo when running many short chains) is lower than the minimum R-hat.

Specifically, I compute nested R-hat (nR) with the posterior::rhat_nested function and R-hat (R) with the posterior::rhat function, both from the posterior R package.

For nR, I use each chain as its own superchain (see code below), i.e. 1 chain per superchain.
NOTE: If I instead use e.g. 2 superchains of 16 chains each, then nR decreases even further!

This isn’t the model I’m using (I’m not even using Stan for my model), but I made a minimal working example (MWE) with a simpler model where the same thing happens. It is a standard multivariate probit model with the same number of coefficients per outcome, but different coefficient values for each outcome. In this example there is 1 covariate and the dimension (number of outcomes) is 5.

The comments at the bottom of the R code list the (minimum) nR and R values (with 500 post-burn-in iterations per chain). For example, with 2 chains I get R = 1.02 and nR = 1.004, and with 32 chains I get R = 1.004 and nR = 1.001. In all cases nR is notably lower than R!!

What’s going on? When each chain is its own superchain, shouldn’t nR = R?

Which one should I trust and report? (To be honest, I’d prefer nR, since then I wouldn’t have to run the models anywhere near as long to get R <= 1.01 …)

Plot: (red = min R, blue = min nR)

edit: added min ESS values to bottom of R code

Thanks!

@charlesm93
@avehtari

Stan model code:



functions {
  real Phi_approx_2(real b, real boundary_for_rough_approx) {
    real a;
    if (abs(b) < boundary_for_rough_approx) { // i.e. NOT in the tails of N(0, 1)
      a = Phi(b);
    } else {
      a = inv_logit(1.702 * b);
    }
    return a;
  }
}
 
 
 
data {
  
  int<lower=1> K;
  int<lower=1> D;
  int<lower=0> N;
  array[N, D] int<lower=0, upper=1> y;
  array[N] matrix[D, K] X;
  real boundary_for_rough_approx;
  
}

parameters {
  
  matrix[D, K] beta;
  cholesky_factor_corr[D] L_Omega;
  array[N, D] real<lower=0, upper=1> u; // nuisance that absorbs inequality constraints
  
}

model {
  
  L_Omega ~ lkj_corr_cholesky(2);
  to_vector(beta) ~ normal(0, 1);
  // implicit: u is iid standard uniform a priori
  
  {
    // likelihood
    for (n in 1 : N) {
      vector[D] Xbeta_n;
      vector[D] z;
      real prev;

      for (d in 1 : D) {
        Xbeta_n[d] = X[n, d, ] * to_vector(beta[d, ]);
      }

      prev = 0;
      for (d in 1 : D) {
        real bound; // threshold at which utility = 0
        real stuff = (0 - (Xbeta_n[d] + prev)) / L_Omega[d, d];
        bound = Phi_approx_2(stuff, boundary_for_rough_approx);
        if (y[n, d] == 1) {
          real t;
          t = bound + (1 - bound) * u[n, d];
          // z[d] = inv_Phi(t); // implies utility is positive
          if (abs(stuff) < boundary_for_rough_approx) z[d] = inv_Phi(t);
          else z[d] = logit(t) / 1.702;
          target += log1m(bound); // Jacobian adjustment
        } else {
          real t;
          t = bound * u[n, d];
          // z[d] = inv_Phi(t); // implies utility is negative
          if (abs(stuff) < boundary_for_rough_approx) z[d] = inv_Phi(t);
          else z[d] = logit(t) / 1.702;
          target += log(bound); // Jacobian adjustment
        }
        if (d < D) {
          prev = L_Omega[d + 1, 1 : d] * head(z, d);
        }
        // Jacobian adjustments imply z is truncated standard normal
        // thus utility --- Xbeta_n + L_Omega * z --- is truncated multivariate normal
      }
    }
  }
}

generated quantities {
  
  corr_matrix[D] Omega;
  Omega = multiply_lower_tri_self_transpose(L_Omega);
  
}
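The Phi_approx_2() helper above switches from Phi() to inv_logit(1.702 * b) once |b| reaches boundary_for_rough_approx (set to 5 in the R code below). A quick base-R check of that logistic approximation may be useful; it is only a sketch, with pnorm() and plogis() standing in for Stan's Phi() and inv_logit(). It shows the overall maximum absolute error and how the two CDFs compare in the tails, where the model actually uses the approximation:

```r
# Compare the standard normal CDF (Stan's Phi(), R's pnorm()) with the
# logistic approximation inv_logit(1.702 * b) (R's plogis()) used by Phi_approx_2().
b <- seq(-8, 8, by = 0.001)
abs_err <- abs(pnorm(b) - plogis(1.702 * b))
max_abs_err <- max(abs_err)
print(signif(max_abs_err, 3))

# In the tails, where Phi_approx_2() switches to the logistic form
# (|b| >= 5 with boundary_for_rough_approx = 5), absolute errors are tiny,
# but the two CDFs differ by orders of magnitude in relative terms.
tail_b <- c(-8, -6, -5)
print(data.frame(b = tail_b, Phi = pnorm(tail_b), logistic = plogis(1.702 * tail_b)))
```

The absolute error stays below 0.01 everywhere, so the switch mainly matters for the relative size of the tail probabilities that enter log(bound) and log1m(bound).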


R code:


set.seed(1)
N <- 500

dim <- 5
n_coeffs_per_outcome <- 1
n_coeffs_total <- dim * n_coeffs_per_outcome

x1 <- runif(N, -1, 1)
x2 <- runif(N, -1, 1)
x3 <- runif(N, -1, 1)
x4 <- runif(N, -1, 1)
x5 <- runif(N, -1, 1)

x_cov_1 <- array(c(x1, x2, x3, x4, x5), dim = c(N, dim))
# x_intercept <- array(1, dim = c(N, dim))

X <- array(dim = c(N, dim, n_coeffs_per_outcome))

X[,,1] <- x_cov_1
# X[,,2] <- x_cov_1

Omega <- matrix(c(  1,  0,     0,        0,        0,
                    0,  1,     0.50,     0.25,     0,
                    0,  0.50,  1,        0.40,     0.40,
                    0,  0.25,  0.40,     1,        0.70,
                    0,  0,     0.40,     0.70,     1), 
                dim, dim)

 
errors <- mvtnorm::rmvnorm(N, c(0,0, 0, 0, 0), Omega)
# plot(errors)
# cor(errors) # realized correlation



# one covariate per outcome (x1, ..., x5), each with its own coefficient
y1 <- 0 + x1 * -1   + errors[, 1]
y2 <- 0 + x2 * -0.5 + errors[, 2]
y3 <- 0 + x3 * 0    + errors[, 3]
y4 <- 0 + x4 * 0.5  + errors[, 4]
y5 <- 0 + x5 * 1    + errors[, 5]

latent_results <- array(c(y1, y2, y3, y4, y5), dim = c(N, dim))
# plot(x, y1)
# plot(x, y2)
# plot(y1, y2)


 
binary_results  <- ifelse(latent_results > 0, 1, 0)
y_MVP <- binary_results



library(cmdstanr) # cmdstan_model() and $sample() come from cmdstanr, not rstan

options(mc.cores = parallel::detectCores())



 

file <- file.path("MVP_example_model_1.stan")
mod <- cmdstan_model(file)

data = list(K = n_coeffs_per_outcome, 
            D = dim,
            N = length(y1),
            X = X, 
            y = y_MVP, 
            boundary_for_rough_approx = 5)

seed <- 1 
n_chains <- 32
iter_warmup <- 500
iter_sampling <- 500
adapt_delta <- 0.80
max_treedepth <- 8
metric_type <- "diag_e" 


u_initial <- array(0.01, dim = c(N, dim))
L_Omega <- t(chol(Omega))
beta <- array(0.01, dim = c(dim, n_coeffs_per_outcome))
beta[1,1] <-  -1
beta[2,1] <-  -0.5
beta[3,1] <-   0
beta[4,1] <-   0.5
beta[5,1] <-   1

init <- list(u = u_initial, 
             L_Omega = L_Omega, 
             beta = beta)

model <- mod$sample(
  # data = stan_data_list[[df_i]],
  data = data,
  seed = seed,
  chains = n_chains,
  parallel_chains = n_chains,
  iter_warmup = iter_warmup,
  iter_sampling = iter_sampling, 
  refresh = round((iter_warmup + iter_sampling) / 100),
  init = rep(list(init), n_chains),
  save_warmup = TRUE, # for some efficiency stats (can turn off for sim. study)
  metric = metric_type,
  adapt_delta = adapt_delta, 
  max_treedepth = max_treedepth)



 


cmdstanr_model_out <- model$summary(variables = c( "beta", "Omega"), "mean", "median", "sd", "mad",  ~quantile(.x, probs = c(0.025,  0.975) ), "rhat" , "ess_bulk", "ess_tail")
print(cmdstanr_model_out, n = 100)
# min_ess <- round(min(cmdstanr_model_out$ess_bulk, na.rm=TRUE), 0)
# print(paste("min ESS = ", min_ess))

n_chains_for_rhat_comp <- 2

stan_draws_array <- model$draws()[,,][,1:n_chains_for_rhat_comp,]


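# Variable layout of model$draws() (Stan declaration order):
# lp__, beta (D*K), L_Omega (D*D), u (N*D), then Omega (D*D) from generated quantities
n_us_nuisance <- N * dim
n_elements_Omegas <- dim * dim
n_main_params <- n_coeffs_total + n_elements_Omegas # beta + L_Omega (same size as Omega)
index_coeffs <- 2:(n_coeffs_total + 1) # skip lp__
index_Omega <- (n_us_nuisance + n_main_params + 2):(n_us_nuisance + n_main_params + 2 + n_elements_Omegas - 1)
index_main_params <- c(index_coeffs, index_Omega)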

 

rhats_nested <- rhats <- ess <- c()
for (i in seq_along(index_main_params)) {
  # iterations x chains matrix for one variable
  draws_i <- array(c(stan_draws_array[, , index_main_params[i]]),
                   dim = c(iter_sampling, n_chains_for_rhat_comp))
  # each chain is its own superchain
  rhats_nested[i] <- posterior::rhat_nested(draws_i, superchain_ids = seq_len(n_chains_for_rhat_comp))
  rhats[i] <- posterior::rhat(draws_i)
  ess[i] <- posterior::ess_basic(draws_i)
}


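# these print the maximum over the selected parameters (beta and Omega);
# na.rm = TRUE drops NAs, e.g. from the constant diagonal of Omega
print(round(max(rhats, na.rm = TRUE), 3))
print(round(max(rhats_nested, na.rm = TRUE), 3))
print(round(max(ess, na.rm = TRUE), 0))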

### with 2 chains: 
# min R = 1.02
# min nR = 1.004
# min ESS = 2794

### with 4 chains: 
# min R = 1.009
# min nR = 1.003
# min ESS = 5221

### with 8 chains: 
# min R = 1.008
# min nR = 1.002
# min ESS = 10563

### with 16 chains: 
# min R = 1.005
# min nR = 1.002
# min ESS = 21625

### with 32 chains: 
# min R = 1.004
# min nR = 1.001
# min ESS = 43485



n_chains_vec <- c(2, 4, 8, 16, 32)
min_rhats <- c(1.02, 1.009, 1.008, 1.005, 1.004)
min_nested_rhats <- c(1.004, 1.003, 1.002, 1.002, 1.001)
min_ess <- c(2794, 5221, 10563, 21625, 43485)

par(mfrow  = c(1, 2))
plot(n_chains_vec, min_rhats, ylim = c(1, 1.025), col = "red", lwd = 3, cex = 3, pch = 19)
points(n_chains_vec, min_nested_rhats, col = "blue", lwd = 3, cex = 3, pch = 19)

plot(n_chains_vec, min_ess,  ylim = c(1000, 50000), col = "red", lwd = 3, cex = 3, pch = 19)
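The hard-coded offsets used to build index_main_params above depend on the exact variable layout of the draws and would silently pick the wrong columns if the model changed. A sketch of a more robust alternative selects columns by name, assuming the draws array carries variable names in its third dimension (as posterior's draws_array does). The helper index_by_prefix() and the toy array are hypothetical, for illustration only:

```r
# Hypothetical helper: positions of all elements of the named variables,
# e.g. "beta[1,1]" or "Omega[2,3]", within a vector of variable names.
index_by_prefix <- function(variable_names, prefixes) {
  pattern <- paste0("^(", paste(prefixes, collapse = "|"), ")\\[")
  grep(pattern, variable_names)
}

# Toy iterations x chains x variables array mimicking the layout of model$draws()
vars <- c("lp__", "beta[1,1]", "beta[2,1]", "L_Omega[1,1]", "Omega[1,1]", "Omega[2,1]")
toy_draws <- array(rnorm(10 * 2 * length(vars)), dim = c(10, 2, length(vars)),
                   dimnames = list(NULL, NULL, vars))

idx_main <- index_by_prefix(dimnames(toy_draws)[[3]], c("beta", "Omega"))
print(idx_main)
```

With the real fit this would be index_by_prefix(dimnames(stan_draws_array)[[3]], c("beta", "Omega")); the ^ anchor keeps "L_Omega" from matching "Omega". posterior's subset_draws() with its variable argument is another way to do the same selection.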




Hi Enzo,
Thanks for sharing the example. I don’t have a good answer and I’ll need to take a closer look. I have some guesses as to what might be happening.

When you use a single chain per superchain, nested Rhat ($\widehat R_\mathfrak{n}$) should behave like $\widehat R$, although there may be some small differences. The most obvious difference is that $\widehat R_\mathfrak{n}$ uses a slightly different estimate of the sample variance (scaling by $1/(N-1)$ instead of $1/N$). This is to ensure $\widehat R_\mathfrak{n} \ge 1$, but it usually means that $\widehat R_\mathfrak{n} > \widehat R$; see footnote 3 in [2110.13017] Nested $\hat R$: Assessing the convergence of Markov chain Monte Carlo when running many short chains. But that wouldn’t explain what you’re observing.
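Under the assumption that neither diagnostic splits chains or transforms the draws, the size of that difference can be worked out directly. Take $K$ chains of $N$ draws, chain means $\bar\theta_k$, grand mean $\bar\theta$ and within-chain sample variances $s_k^2$ (scaled by $1/(N-1)$), and read the preprint's definition with one chain per superchain, so the between-subchain term inside each superchain is zero:

$$
W = \frac{1}{K}\sum_{k=1}^{K} s_k^2,
\qquad
\frac{B}{N} = \frac{1}{K-1}\sum_{k=1}^{K}\left(\bar\theta_k - \bar\theta\right)^2
$$

$$
\widehat R^2 = \frac{\frac{N-1}{N}W + \frac{B}{N}}{W} = 1 - \frac{1}{N} + \frac{B}{NW},
\qquad
\widehat R_\mathfrak{n}^2 = 1 + \frac{B/N}{W} = \widehat R^2 + \frac{1}{N}
$$

So with $N = 500$ draws per chain, unsplit $\widehat R_\mathfrak{n}$ should sit slightly above unsplit $\widehat R$, by roughly $1/(2N) \approx 0.001$ when both are near 1. A gap in the opposite direction would have to come from a difference in how the draws are processed (splitting, rank-normalization or folding).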

There may be some differences in the implementation of the diagnostics in posterior. The two I can think of are:

  • $\widehat R$ uses rank-normalization.
  • $\widehat R$ uses folding.

See [1903.08008] Rank-normalization, folding, and localization: An improved $\widehat{R}$ for assessing convergence of MCMC. I don’t know whether rank-normalization and folding are used for $\widehat R_\mathfrak{n}$, but that could explain the difference. One way to check this would be to run a longer warmup phase and see whether the difference between the diagnostics persists.

Another remark is that $\widehat R$ and $\widehat R_\mathfrak{n}$ become less noisy as you increase the number of Markov chains, but neither should decrease systematically. So something odd is going on here.

NOTE: If I instead use e.g. 2 superchains of 16 chains each, then nR decreases even further!

Yes, that makes sense. The higher the number of subchains in each superchain, the smaller the persistent variance. However, the nonstationary variance remains unchanged. That’s why you can use $\widehat R_\mathfrak{n}$ when running a large number of short chains.
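A small base-R simulation illustrates this. The function below is a sketch that follows the preprint's definition of nested $\widehat R$ without splitting or rank-normalization; its name is hypothetical and it is not posterior's implementation. It is applied to independent, already-stationary draws, so only the persistent (Monte Carlo) variance is present, and the same 48 chains are grouped into superchains of increasing size $M$:

```r
# Sketch of nested R-hat (no splitting, no rank-normalization).
# draws: iterations x chains matrix; superchain_ids: superchain of each chain.
nested_rhat_sketch <- function(draws, superchain_ids) {
  chain_means <- colMeans(draws)
  chain_vars <- if (nrow(draws) > 1) apply(draws, 2, var) else rep(0, ncol(draws))
  groups <- split(seq_len(ncol(draws)), superchain_ids)
  super_means <- vapply(groups, function(g) mean(chain_means[g]), numeric(1))
  # between-subchain variance inside each superchain (0 when M = 1)
  b_tilde <- vapply(groups, function(g) if (length(g) > 1) var(chain_means[g]) else 0,
                    numeric(1))
  # mean within-subchain variance inside each superchain
  w_tilde <- vapply(groups, function(g) mean(chain_vars[g]), numeric(1))
  b_hat <- var(super_means)         # between-superchain variance
  w_hat <- mean(b_tilde + w_tilde)  # within-superchain variance
  sqrt(1 + b_hat / w_hat)
}

set.seed(1)
n_iter <- 100
n_chains <- 48
draws <- matrix(rnorm(n_iter * n_chains), n_iter, n_chains)  # stationary iid draws

M_values <- c(1, 4, 12)
rhat_by_M <- sapply(M_values, function(M) {
  ids <- rep(seq_len(n_chains / M), each = M)  # n_chains / M superchains of M chains
  nested_rhat_sketch(draws, ids)
})
print(data.frame(M = M_values, nested_rhat = round(rhat_by_M, 4)))
```

Because each superchain mean averages $M$ chains, its Monte Carlo variance shrinks roughly like $1/M$, which pulls the value toward 1 even though the draws themselves are identical across the three groupings.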


Can you try rhat_basic() instead of rhat()? The latter uses rank-normalization and rhat_nested() doesn’t.


Hey,

Thanks, I tried this on a couple of examples and get no difference between posterior::rhat and posterior::rhat_basic.
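If rank-normalization makes no difference, one remaining candidate is chain splitting: posterior::rhat_basic() splits each chain in half by default (split = TRUE), which makes it sensitive to drift within a chain, whereas whether rhat_nested() splits depends on the installed posterior version (an assumption worth checking in its documentation). The base-R sketch below implements the classic formula with and without splitting on simulated, slowly mixing chains (a hypothetical AR(1) process started away from its mean; all names are illustrative), and checks the unsplit identity nested R-hat² = R-hat² + 1/N for one chain per superchain:

```r
# Classic (basic) R-hat; split = TRUE halves each chain first.
basic_rhat <- function(draws, split = FALSE) {
  if (split) {
    half <- floor(nrow(draws) / 2)
    draws <- cbind(draws[seq_len(half), , drop = FALSE],
                   draws[nrow(draws) - half + seq_len(half), , drop = FALSE])
  }
  n <- nrow(draws)
  w <- mean(apply(draws, 2, var))   # mean within-chain variance
  b_over_n <- var(colMeans(draws))  # variance of chain means (B / N)
  sqrt(((n - 1) / n * w + b_over_n) / w)
}

# Nested R-hat with one chain per superchain (between-subchain term is zero)
nested_rhat_one_per_super <- function(draws) {
  sqrt(1 + var(colMeans(draws)) / mean(apply(draws, 2, var)))
}

# Hypothetical slowly mixing chains: AR(1) with rho = 0.99, overdispersed starts
set.seed(2)
n_iter <- 500
n_chains <- 4
rho <- 0.99
draws <- matrix(NA_real_, n_iter, n_chains)
for (m in seq_len(n_chains)) {
  x <- rnorm(1, sd = 3)
  for (t in seq_len(n_iter)) {
    x <- rho * x + sqrt(1 - rho^2) * rnorm(1)
    draws[t, m] <- x
  }
}

r_unsplit <- basic_rhat(draws, split = FALSE)
r_split <- basic_rhat(draws, split = TRUE)
r_nested <- nested_rhat_one_per_super(draws)
print(round(c(unsplit = r_unsplit, split = r_split, nested = r_nested), 4))
```

On the real draws, comparing posterior::rhat_basic(x, split = TRUE) with split = FALSE would show whether splitting alone accounts for much of the gap.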

I also sent this by email, but in hindsight should have just posted it here:

I also experimented with multiple chains per superchain (results not included in the Stan post), e.g. 48 chains with 4 superchains (so 12 chains per superchain), and nR-hat goes down a lot, for example from 1.007 (with 1 chain per superchain) to 1.001.

To be honest, doing this would make some of my research A LOT easier, as I wouldn’t need to run the chains anywhere near as long (if I use at least 32 chains with at least 2 superchains)!

So do you think it’s OK if I use nR-hat with 48-64 chains and 4 superchains? Can I justify this for a paper?

edit:

Also, this is very ad hoc and not particularly scientific, but my intuition tells me that the nR-hat estimates make more sense than the R-hat ones. For example, if I use 48 chains each with 500 iterations (so 24,000 total post-burn-in samples), on one example I get ESS = 4368. That’s pretty big, yet R-hat is 1.011 (> 1.01), whereas nR = 1.007 (1 chain per superchain) and 1.001 with 12 chains per superchain (4 superchains total)!

The regular R-hat of 1.011 seems at odds with the ESS and with how stable the estimates are across runs …

1.01 is not a magic threshold such that < 1.01 is perfect and > 1.01 is useless. Your observed 1.011 is just 0.1% away, which is so little that you can’t see any meaningful effect on ESS or MCSE (which use R-hat internally).

Having ESS = 4368 with S = 24000 indicates that high autocorrelation within each chain dominates the loss of efficiency. Did you get any other diagnostic warnings?

We don’t yet have a nested-Rhat-based ESS computation, as accounting for the autocorrelation in very short chains is challenging. Your 500 iterations are not short, and the ESS estimate (which uses rhat_basic() internally) is likely to be good.


No, I don’t get any other warnings in these examples.

Indeed, this model is hard to sample from, so there can be significant autocorrelation, even with strong priors.

I appreciate that the rhat_basic() estimate is good, but I don’t understand why I can use it but not nR-hat with (e.g.) 4 superchains / 48-64 chains / 500 iterations (or rather, why would rhat_basic() be better than rhat_nested() in this case?). Would things change if I kept everything the same but used only 100 iterations per chain?

You can use it.

I don’t think anyone claimed it would be better.

Compared to 500 iterations, your ESS would probably go down by 80%, and due to the very high autocorrelation rhat_basic() would start to be less stable, though not necessarily by much. The ESS estimate would probably suffer more, as you would have only ESS = 18 per chain and estimating the autocorrelation time gets difficult.
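These back-of-the-envelope numbers can be reproduced from the figures quoted earlier in the thread (48 chains × 500 iterations, ESS = 4368), under the simplifying assumption that ESS scales linearly with chain length at a fixed integrated autocorrelation time:

```r
n_chains <- 48
iter_per_chain <- 500
ess_total <- 4368

S <- n_chains * iter_per_chain         # total post-warmup draws
tau <- S / ess_total                   # implied integrated autocorrelation time
ess_per_chain <- ess_total / n_chains  # ESS contributed by each chain

# Same autocorrelation, but 100 instead of 500 iterations per chain
ess_per_chain_100 <- 100 / tau
ess_drop <- 1 - 100 / iter_per_chain   # fractional drop in total ESS

print(round(c(S = S, tau = tau, ess_per_chain = ess_per_chain,
              ess_per_chain_100 = ess_per_chain_100, ess_drop = ess_drop), 2))
```

An autocorrelation time of about 5.5 iterations means roughly 91 effective draws per chain at 500 iterations, and about 18 at 100 iterations, which is where estimating the autocorrelation itself gets hard.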


Thanks! Sorry for the misunderstanding.

In that case, perhaps I’ll report both measures but apply the 1.01 threshold only to nR-hat. If I use this threshold for regular R-hat, the study will take ~5-10x longer!

You can just report the actual Rhat values; there is no need to apply an explicit threshold to either Rhat. The values tell you how close you are to 1, and any threshold is arbitrary (we try to say this in both Rhat papers, but it seems to be easily missed).


Yes, I agree that thresholds are totally arbitrary. However, at some point (at least for publishing papers) you probably do need to set some sort of “rule” to follow? E.g. you probably wouldn’t want to report outputs with max(R) = 1.09. But between 1.01 and 1.05 things are a bit less clear. AFAIK 1.05 used to be the “default” threshold, but now people seem to use 1.01 more. I think Stan still uses 1.05 for its warnings.

I responded by email, so let me post here so that the conversation is public.

This makes sense. nRhat is designed to decrease as you increase the number of chains per superchain. However, you need to make sure that every chain within a superchain is initialized at the same point, otherwise the diagnostic is not reliable. I recommend rereading Section 1 of our preprint (https://arxiv.org/pdf/2110.13017.pdf), and I can clarify points if needed.
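For the MWE above, that would mean replacing init = rep(list(init), n_chains), which starts every chain at the same point, with one shared starting point per superchain and different starting points across superchains. A minimal base-R sketch follows; make_init() and its random draws are hypothetical choices, and the identity matrix is used because it is a valid Cholesky factor of a correlation matrix:

```r
# Hypothetical helper: one random starting point for the MVP model's parameters
make_init <- function(N, D, K) {
  list(u = array(runif(N * D, 0.01, 0.99), dim = c(N, D)),  # must lie in (0, 1)
       L_Omega = diag(D),                                    # Cholesky factor of the identity correlation
       beta = array(rnorm(D * K, 0, 1), dim = c(D, K)))
}

n_superchains <- 4
chains_per_superchain <- 12  # M in the preprint
set.seed(3)
super_inits <- lapply(seq_len(n_superchains), function(k) make_init(N = 500, D = 5, K = 1))

# chains 1..M share superchain 1's init, chains M+1..2M share superchain 2's, etc.
inits <- rep(super_inits, each = chains_per_superchain)
superchain_ids <- rep(seq_len(n_superchains), each = chains_per_superchain)
```

These would then be passed as init = inits with chains = length(inits) in mod$sample(), and superchain_ids as the superchain_ids argument of posterior::rhat_nested(). This assumes cmdstanr gives the i-th list of inits to chain i, so the chain order in the draws matches superchain_ids.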

We do recommend adjusting the threshold based on the number of subchains per superchain, so, as Aki pointed out, the 1.01 threshold is not universal. For example, you could try $\sqrt{1 + 1/M}$, where $M$ is the number of subchains per superchain. In your case, that threshold is 1.04. (This is the rule I’d consider, although other heuristics can be used.)
nRhat is designed for cases where we want to trade off the length of the sampling phase against the number of chains. This seems relevant to your problem! So yes, I think nested Rhat is well suited for your application (if properly used).

I’ll echo Aki’s point that no threshold is magic, and what matters is to report the diagnostic. If you look at max(nested Rhat) across all parameters of interest, you’re already being conservative.


Thank you for all this info! Very helpful.

BTW, perhaps a silly question, but how did you get 1.008? I have 48 chains in total with 12 chains per superchain (so 4 superchains), so in my case M = 12? Or does “subchains” mean something else?

Yes, subchains and chains are used as synonyms. Subchains specifically refers to a chain inside a superchain. M is the number of chains per superchain. But you need to make sure all chains within a superchain start at the same point.

If I have 12 chains per superchain, M = 12, so 1 + 1/M = 1 + 1/12 = 1.08, not 1.008? Also, looking at your paper (e.g. Section 4.2), it looks like it says $\sqrt{1 + 1/M}$?

Yes, indeed, I meant $\sqrt{1 + 1/M}$.
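For reference, here is the suggested threshold $\sqrt{1 + 1/M}$ for a few values of $M$, as a trivial base-R sketch (the function name is hypothetical). With $M = 12$ it gives about 1.041, whereas the unsquared $1 + 1/M$ would give about 1.083:

```r
# Hypothetical helper: nested R-hat threshold sqrt(1 + 1/M) suggested in the thread
nested_rhat_threshold <- function(M) sqrt(1 + 1 / M)

M <- c(1, 2, 4, 12, 16, 32, 64)
print(data.frame(M = M, threshold = round(nested_rhat_threshold(M), 3)))
```

The threshold tightens toward 1 as more subchains are pooled into each superchain, in line with the smaller persistent variance discussed above.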


Thanks!

Is there any suggestion as to what the minimum number of chains should be before one should consider looking at nested $\hat{R}$? For example, 16 chains?