Summarising Rhat values over multiple variables/fits

In the SBC package under development, I use the maximum Rhat value over all parameters in a fit as one of the diagnostics of whether the fit was OK overall. When exploring the set of all fits from the SBC runs I also look at the maximum of those maxima, as with 100+ fits one just cannot really look at single values.

In this setting it turns out that the default 1.01 threshold is quite strict, and with a lot of variables and lots of fits there will almost always be some Rhats that are a bit larger (something like 1.015 or 1.023). My current understanding is that this is to be expected even for well-behaved models: Rhat is itself stochastic, so some proportion of slightly high Rhats will be seen.

So the questions are:

  1. Is there a better way to summarise multiple Rhat values to get a good diagnostic than just taking the max?
  2. If taking the maximum is sensible, should I adjust the 1.01 threshold? And how? My current hacky idea is to assume that Rhats in a well-behaved model are approximately i.i.d. N(1, 0.005) distributed, then use the extreme value distribution to estimate the 0.99 quantile of the distribution of the maxima and use this as a threshold. This gives me the following thresholds:
   n_vars rhat_thresh
1       1    1.011632
2       2    1.013574
3       3    1.014711
4       4    1.015517
5       5    1.016142
6       6    1.016653
7       7    1.017085
8       8    1.017460
9       9    1.017790
10     10    1.018085
11    100    1.019771
12   1000    1.022022
13  10000    1.024239
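
A minimal sketch of an exact variant of this idea, assuming as above that Rhats are i.i.d. N(1, 0.005), with 0.005 read as the standard deviation. For the maximum of n i.i.d. variables, P(max <= t) = F(t)^n, so the 0.99 quantile of the maximum is the 0.99^(1/n) quantile of a single Rhat. The function name `rhat_max_threshold` is hypothetical. Because this uses the exact distribution of the maximum rather than an extreme value approximation, the values for n > 1 need not match the table above.

```r
# Hypothetical helper: `prob` quantile of the maximum of `n_vars` i.i.d.
# normal Rhats. The N(1, 0.005) default is an assumption, not a known property.
rhat_max_threshold <- function(n_vars, prob = 0.99, mu = 1, sigma = 0.005) {
  # P(max <= t) = pnorm(t)^n_vars, so invert at prob^(1 / n_vars)
  qnorm(prob^(1 / n_vars), mean = mu, sd = sigma)
}

n_vars <- c(1:10, 100, 1000, 10000)
data.frame(n_vars = n_vars, rhat_thresh = rhat_max_threshold(n_vars))
```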

Does that make sense? Or maybe I should just avoid summarising and show histograms of Rhats or similar? Or report the percentage of Rhats that exceed the 1.01 threshold and assume some low percentage is not a concern (and thus would not be reported as a warning)?

Thanks for any ideas.


What if you use Rstar?


What about a user-customisable metric: the proportion of SBC iterations with $\leq N_p$ parameters with an Rhat $\geq \rho$.

The user can define $N_p$ and $\rho$ to be suitable for their context, e.g., depending on the total number of parameters of the user’s model.

You might even consider plotting this metric (i.e. the proportion value) in a heat map with $N_p$ and $\rho$ on the x and y axes. Such a plot might share some typical characteristics across different models that have or have not converged.
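
A rough sketch of how this metric could be computed, assuming the Rhats are available as a numeric matrix with one row per SBC fit and one column per parameter. The function name `prop_fits_ok`, the grid values, and the simulated `rhats` matrix are all placeholders for illustration, not part of the SBC package.

```r
# Hypothetical helper: proportion of fits in which at most `n_p` parameters
# have Rhat >= `rho`. `rhats` is a fits x parameters numeric matrix.
prop_fits_ok <- function(rhats, n_p, rho) {
  n_high <- rowSums(rhats >= rho)  # number of high Rhats in each fit
  mean(n_high <= n_p)
}

# Placeholder input: simulated Rhats for 100 fits x 20 parameters
set.seed(1)
rhats <- matrix(rnorm(100 * 20, mean = 1, sd = 0.005), nrow = 100)

# Evaluate the metric on a grid of (n_p, rho) values for a heat map
grid <- expand.grid(n_p = 0:3, rho = c(1.01, 1.015, 1.02))
grid$prop <- mapply(function(n_p, rho) prop_fits_ok(rhats, n_p, rho),
                    grid$n_p, grid$rho)
xtabs(prop ~ n_p + rho, data = grid)
```

The resulting table has one cell per (n_p, rho) pair and could be passed to any heat map function.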



While it’s new and thereby a little less time-tested, I really like r-star. A couple notes on it: (1) It should probably include lp__ and possibly also energy__; (2) I’ve wondered if it might work best when run with all parameters on their unconstrained-scale representations.


Update: the problem seems to be almost exclusively the stochasticity of Rhat if the number of post-warmup samples is low.

Taking a very simple model:

data {
   int<lower=0> N;
   array[N] real y;
}

parameters {
   real mu;
}

model {
   mu ~ normal(0, 2);
   y ~ normal(mu, 1);
}

I simulate data exactly from this model.

Running 1000 fits with 1000 warmup iterations but 200 sampling iterations: 499 fits had Rhat > 1.01. Largest Rhat was 1.042.

Running 1000 fits with 200 warmup iterations but 1000 sampling iterations: all Rhats were < 1.01.

So it seems that the model is able to adapt and warm up quickly, so all the chains actually explore the same distribution in both cases; it is just that the Rhat estimate has larger variance with fewer samples (which sounds unsurprising).

I did check that 200 i.i.d. samples have unproblematic Rhat while autocorrelated samples can have problems, so it probably is the combination of autocorrelation + low number of samples that is tripping Rhat to give false positives.

Code for simulation

I.i.d. samples:

n_iter <- 200
for(i in 1:10) {
   var <- posterior::rvar(array(rnorm((n_iter)  * 4), dim = c(n_iter, 4)), with_chains = TRUE, nchains = 4)
   print(posterior::rhat(var))
}

A representative result:

[1] 1.000692
[1] 1.001277
[1] 0.9993861
[1] 0.9969532
[1] 1.000892
[1] 1.001532
[1] 1.006309
[1] 1.001748
[1] 0.9996341
[1] 1.004682

Autocorrelated samples (a sliding-window linear combination of i.i.d. normals - hope that’s not particularly stupid):

n_iter_high <- 1000
n_iter_low <- 200
lag_coeffs <- rev(c(1, 0.8,0.5))

N_sims <- 10
res <- array(NA_real_, dim = c(N_sims, 2), dimnames = list(NULL, paste0(c(n_iter_low, n_iter_high), "_iter")))
n_lags <- length(lag_coeffs)
for(i in 1:N_sims) {
   draws_latent <- array(rnorm((n_iter_high + n_lags) * 4), dim = c(n_iter_high + n_lags, 4))
   draws_observed <- array(NA_real_, dim = c(n_iter_high, 4))

   for(n in 1:n_iter_high) {
      for(c in 1:4) {
         draws_observed[n, c] <- sum(draws_latent[n:(n + n_lags - 1), c] * lag_coeffs)
      }
   }
   var_high <- posterior::rvar(draws_observed, with_chains = TRUE, nchains = 4)
   # Last n_iter_low rows only (the + 1 avoids selecting n_iter_low + 1 rows)
   var_low <- posterior::rvar(draws_observed[(n_iter_high - n_iter_low + 1):n_iter_high, ], with_chains = TRUE, nchains = 4)
   res[i, 1] <- posterior::rhat(var_low)
   res[i, 2] <- posterior::rhat(var_high)
}
res

Gives something like:

      200_iter 1000_iter
 [1,] 1.005455  1.001189
 [2,] 1.002999  1.000354
 [3,] 1.020892  1.003734
 [4,] 1.009019  1.003877
 [5,] 1.011578  1.002240
 [6,] 1.004152  1.001302
 [7,] 1.006214  1.001396
 [8,] 1.006232  1.001696
 [9,] 1.010586  1.002576
[10,] 1.010942  1.000831

I.e. the last 200 iterations from the same array have substantially more high Rhats than the whole thing.

I forgot about this - yes, that could help partially. But given that I have many fits with an R* value for each, I still need to determine whether to raise a warning overall. Do you have an idea how one could pick a sensible threshold? Also, since R* is even more stochastic than Rhat, I would expect to see similar problems with short chains (but I haven’t checked yet).

  • In the end you should care about MCSE. As a quick scale-free diagnostic Rhat is useful, but any Rhat threshold not derived from MCSE is ad hoc.
  • 1.01 was chosen assuming one or a few Rhats are examined and the chains are run long enough to be able to infer autocorrelations well, too.
  • In the new Rhat paper we didn’t explicitly discuss multiple comparisons, but what you write is the natural way to think about it.
  • Multiple comparison correction as you describe is one way. When there are many variables, a fancier approach would be to use a model to learn the variation.
  • Looking at just the percentage exceeding is not enough as those exceeding might exceed a lot, so it’s better to assume some distribution for the Rhats and compare to that.
  • For repeated automated testing a single binary decision can be useful, but in case of triggering the threshold, there should be more information available. I usually just eyeball the Rhats, but plotting a histogram of Rhats with an assumed distribution overlaid could be a useful way to see whether the highest Rhats are suspiciously high.
  • If in doubt, run more iterations.
  • If more iterations would be very costly, look at the other diagnostics such as ESS, MCSE, R*, etc.
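
The histogram idea from the list above could be sketched roughly as follows. This is only an illustration: the function name `plot_rhat_hist` is hypothetical, the N(1, 0.005) reference distribution is the assumption from the original question rather than a known property of Rhat, and the `rhats` vector is simulated placeholder input.

```r
# Hypothetical helper: histogram of Rhats with an assumed normal density
# overlaid and the `prob` quantile of the maximum marked as a dashed line.
plot_rhat_hist <- function(rhats, mu = 1, sigma = 0.005, prob = 0.99) {
  # Quantile of the maximum of length(rhats) i.i.d. normals (assumed model)
  thresh <- qnorm(prob^(1 / length(rhats)), mean = mu, sd = sigma)
  hist(rhats, breaks = 30, freq = FALSE, main = "Rhat values", xlab = "Rhat",
       xlim = range(c(rhats, thresh)))
  curve(dnorm(x, mean = mu, sd = sigma), add = TRUE, lwd = 2)
  abline(v = thresh, lty = 2)
  invisible(list(threshold = thresh, n_above = sum(rhats > thresh)))
}

# Placeholder input: simulated Rhats for 500 variables
set.seed(1)
rhats <- rnorm(500, mean = 1, sd = 0.005)
res <- plot_rhat_hist(rhats)
res$n_above  # number of Rhats beyond the dashed line
```

Rhats far to the right of the dashed line are the ones worth inspecting individually.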