That’s an interesting example, thanks for sharing.
Cases like this, where both samplers have parameters with low ESS, are common, but I could never really figure out how to deal with them in benchmarks. In the paper we simply exclude all models where none of the samplers got a decent minimum ESS. That’s not because I don’t care about those models, but because I really don’t know how to compare samplers on them. ESS/second is a measure of efficiency, and efficiency only makes sense once you’ve established that you’re getting to the right place. If you’re not, faster isn’t better; it just produces wrong draws faster. If all samplers fail to some degree, the question has to shift: instead of “how efficient are we?” it becomes “how close to the true posterior do we get?”. And I just don’t know how to answer that in general.
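A minimal sketch of the “efficiency only after correctness” rule, assuming you already have the minimum ESS over all parameters and the wall-clock time of a run. The function name `ess_per_second` and the 400 cut-off are hypothetical choices for illustration, not the criterion used in the paper:

```python
import math


def ess_per_second(min_ess: float, wall_time_s: float, ess_threshold: float = 400.0):
    """Return min-ESS per second, or None if the run is not usable.

    ess_threshold is an assumed cut-off (e.g. 100 per chain with 4 chains),
    not a value taken from the paper.
    """
    if wall_time_s <= 0:
        raise ValueError("wall_time_s must be positive")
    if not math.isfinite(min_ess) or min_ess < ess_threshold:
        # Efficiency is only meaningful once the draws look trustworthy,
        # so a run that fails the gate gets no efficiency score at all.
        return None
    return min_ess / wall_time_s
```

If every sampler returns `None` for a model, that model simply drops out of the comparison, which is roughly the exclusion described above.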
A very natural approach is to count how many parameters look bad, or to look at the median ESS or something similar. But it isn’t hard to come up with counterexamples to this.
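For concreteness, the kind of one-number summary meant here might look like the hypothetical helper below. It takes per-parameter ESS and R-hat values (for example from a diagnostics table) and uses commonly quoted cut-offs of 400 for ESS and 1.01 for R-hat; both thresholds are assumptions for illustration:

```python
import numpy as np


def naive_summary(ess, rhat, ess_min=400.0, rhat_max=1.01):
    """Summarize per-parameter diagnostics the 'natural' way (hypothetical helper)."""
    ess = np.asarray(ess, dtype=float)
    rhat = np.asarray(rhat, dtype=float)
    # A parameter "looks bad" if its ESS is too low or its R-hat is too high.
    bad = (ess < ess_min) | (rhat > rhat_max)
    return {
        "n_bad": int(bad.sum()),
        "frac_bad": float(bad.mean()),
        "median_ess": float(np.median(ess)),
    }


# One bad parameter out of four: this run would look "mostly fine".
print(naive_summary([50, 900, 950, 1000], [1.07, 1.0, 1.0, 1.0]))
```

Both numbers reward a run where only one parameter looks bad, and that is exactly the situation the following example turns against them.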
Let’s take this simple model: a set of mild funnels, all coupled through one shared parameter, where we do know the true posterior. Both nutpie and Stan struggle with it quite a bit, and in essentially the same way.
```stan
data {
  int N;
  vector[N] scale;
}
parameters {
  real log_std;
  vector[N] y;
}
model {
  log_std ~ normal(0, 1);
  y ~ normal(0, exp(scale * log_std));
}
```
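Since the model has no likelihood term, the posterior is just the prior, which is why the true answer is known here. Writing λ for `log_std` and s_i for `scale[i]`, the exact posterior standard deviation of each y_i follows from the law of total variance (the conditional mean is 0, so only the first term survives) and the moment generating function of a standard normal, E[e^{tλ}] = e^{t²/2}:

```latex
\begin{aligned}
\lambda &\sim \mathcal{N}(0, 1), \qquad y_i \mid \lambda \sim \mathcal{N}\bigl(0,\, e^{s_i \lambda}\bigr) \\
\operatorname{Var}(y_i) &= \mathbb{E}\bigl[\operatorname{Var}(y_i \mid \lambda)\bigr] + \operatorname{Var}\bigl(\mathbb{E}[y_i \mid \lambda]\bigr)
  = \mathbb{E}\bigl[e^{2 s_i \lambda}\bigr] = e^{2 s_i^2} \\
\operatorname{sd}(y_i) &= e^{s_i^2}
\end{aligned}
```

With `scale` drawn as 0.2 times a standard normal, these true standard deviations all stay fairly close to 1 (s_i = 0.4, for instance, gives about 1.17).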
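With these settings:

```python
import numpy as np
import nutpie

# 1000 funnel scales; model.stan contains the model above
rng = np.random.default_rng(1234)
scale = rng.normal(size=1000) * 0.2
compiled = nutpie.compile_stan_model(filename="model.stan")
```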
If we sample with seed 1234, we get something that looks superficially OK according to most diagnostics:

```python
tr = nutpie.sample(compiled.with_data(N=len(scale), scale=scale), chains=4, seed=1234)
```
There are no divergences, all the y variables have decent ESS (even tail ESS doesn’t look too bad), and R-hat is at most 1.02 for y. Only log_std stands out, with an R-hat of 1.07 and a low ESS.
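To see why clean-looking diagnostics on y say little about correctness, it helps to recall what R-hat computes. Below is a sketch of the classic split-R-hat described in Bayesian Data Analysis (3rd ed.), not the rank-normalized version that Stan and ArviZ actually report, for one parameter with draws of shape `(n_chains, n_draws)`. It only compares chains with one another, so if every chain misses the same part of the posterior, it can still sit close to 1:

```python
import numpy as np


def split_rhat(draws):
    """Classic (non-rank-normalized) split-R-hat for one scalar parameter.

    draws: array of shape (n_chains, n_draws).
    """
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split each chain in half so drift within a chain also shows up
    # as disagreement between (half-)chains.
    split = np.concatenate([draws[:, :half], draws[:, half:2 * half]], axis=0)
    m, n = split.shape
    chain_means = split.mean(axis=1)
    chain_vars = split.var(axis=1, ddof=1)
    w = chain_vars.mean()                # within-chain variance
    b = n * chain_means.var(ddof=1)      # between-chain variance
    var_plus = (n - 1) / n * w + b / n   # pooled estimate of the posterior variance
    return float(np.sqrt(var_plus / w))


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))   # four chains that agree with each other
shifted = good.copy()
shifted[0] += 3.0                   # one chain stuck somewhere else
print(split_rhat(good), split_rhat(shifted))
```

Nothing in this computation knows what the true posterior is; it can only flag chains that disagree.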
But the posterior for y is completely incorrect. Here is a plot of the true standard deviations versus the standard deviations estimated from the draws, for each y value:
Suffice it to say that the dots are not on the diagonal…
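Because the model has no data, its posterior is just the prior, and the exact standard deviation of each y_i is exp(scale_i²). That makes “how close do we get?” answerable here. A minimal sketch of a comparison like the one in the plot, with hypothetical helper names, checked against exact independent draws rather than MCMC output:

```python
import numpy as np


def true_sd(scale):
    """Exact posterior SD of each y_i for this model: exp(scale_i**2)."""
    return np.exp(np.asarray(scale, dtype=float) ** 2)


def sd_log_error(y_draws, scale):
    """Per-coordinate log ratio of estimated to true posterior SD.

    y_draws: array of shape (n_draws, N), e.g. all chains stacked.
    0 means the SD of y_i was recovered exactly; points off the diagonal
    in the plot correspond to values far from 0.
    """
    est = np.std(y_draws, axis=0, ddof=1)
    return np.log(est / true_sd(scale))


def exact_draws(scale, n_draws, rng):
    """Independent draws from the known posterior (no MCMC involved)."""
    scale = np.asarray(scale, dtype=float)
    log_std = rng.normal(size=(n_draws, 1))
    return rng.normal(size=(n_draws, scale.size)) * np.exp(scale * log_std)


rng = np.random.default_rng(0)
scale = rng.normal(size=20) * 0.2
err = sd_log_error(exact_draws(scale, 100_000, rng), scale)
print(np.abs(err).max())  # small, since these draws are exact
```

For a nutpie run you could pass the stacked draws, e.g. `tr.posterior["y"].values.reshape(-1, len(scale))`, assuming `tr` is the ArviZ `InferenceData` object that `nutpie.sample` returns by default.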
If we just change the seed to 123456, the diagnostics get much worse: R-hat for log_std jumps to 1.23, and many y values have poor ESS and R-hat. And yet the standard deviations of y look much better:
This is of course a somewhat synthetic example, and I’m not trying to say that this is what’s going on in your model. But comparing non-convergent or only partially convergent runs is a hard problem, and I wish I knew how to solve it.