HMC for Bayesian neural networks?

Bayesian neural networks, even relatively simple ones like two-layer multilayer perceptrons, seem on the face of it likely to be plagued by multimodality and lack of identifiability. The discussion in the thread Why are Bayesian Neural Networks multi-modal? seems to confirm that impression.

I was surprised, then, to read the paper “What Are Bayesian Neural Network Posteriors Really Like?” by Izmailov et al., which uses HMC for large-scale Bayesian neural networks. The size of the models and simulations is impressive. For example, they use 20-layer networks applied to 60,000 images, with millions of parameters, and parallelize the HMC computation over 512 TPUs. They report (see Section 5) that \hat{R} values are mostly acceptably low. Overall, they do not seem concerned by multimodality or lack of identifiability. In fact, there is not much mention of multimodality problems in the paper, and no mention of lack of identifiability.

Overall, this paper seems to imply that Bayesian neural networks using HMC are practically feasible, albeit with impressively powerful hardware. This seems somewhat at odds with the impression I get from the Stan community about Bayesian neural networks. For example, see the above-mentioned Stan forum thread, or the post Neural Networks in Stan: Or how I was utterly surprised that it worked at all by Stephen R. Martin, PhD, where the author writes: “For sampling neural nets, I wouldn’t recommend it. The Stan team wouldn’t recommend it either. It’s just an intractable posterior.”

I was wondering if anyone here has any comments to make about this. Are the problems with Bayesian neural nets greater than they seem in the Izmailov paper?

I think it’s extremely questionable whether our current MCMC diagnostic tools, be that ESS or \widehat{R}, are actually reliable at neural network scale. It’s possible to say that HMC is effective at exploring a healthy portion of the posterior, but a good \widehat{R} is no guarantee that it is exploring all of it.

I think this is partly explained by the authors using a relatively loose definition of “acceptably low”, \hat{R} \leq 1.1, which is quite a lot less stringent than both the recommendation in Vehtari et al. (2022) of \hat{R} \leq 1.01 and the rstan documentation’s \hat{R} \leq 1.05. Even with this relatively coarse threshold, a non-negligible fraction of the ‘function-space’ test functions have \hat{R} exceeding this value, and an even larger fraction of the ‘weight-space’ test functions do. As the numbers of test functions considered in both the function and weight spaces are high (> 10^5, judging from the histogram scale) and the chains are relatively short, I think we would expect some larger \hat{R} values even for well-converged chains because of the variance in the \hat{R} estimator. Still, the distribution of \hat{R} values here seems indicative of non-convergence.
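To get a feel for how much \hat{R} varies from estimator noise alone, here is a minimal Python sketch (stdlib only). It assumes perfectly mixed chains (independent standard normal draws), uses the classic split-\hat{R} rather than the rank-normalized version from Vehtari et al., and the chain count and length (4 chains of 100 draws) are hypothetical choices, not the paper’s settings.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split-R-hat (no rank normalization) for equal-length chains."""
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.append(chain[:half])
        halves.append(chain[half:2 * half])
    n = len(halves[0])
    # Within-chain variance W: average of the per-half sample variances.
    w = statistics.fmean([statistics.variance(h) for h in halves])
    # Between-chain variance B: n times the variance of the half-chain means.
    b = n * statistics.variance([statistics.fmean(h) for h in halves])
    var_plus = (n - 1) / n * w + b / n
    return (var_plus / w) ** 0.5


def simulate_rhats(num_functions=2000, num_chains=4, chain_length=100, seed=0):
    """R-hat for many test functions, each computed from i.i.d. N(0, 1) 'chains'."""
    rng = random.Random(seed)
    rhats = []
    for _ in range(num_functions):
        chains = [[rng.gauss(0.0, 1.0) for _ in range(chain_length)]
                  for _ in range(num_chains)]
        rhats.append(split_rhat(chains))
    return rhats


rhats = simulate_rhats()
frac_above_101 = sum(r > 1.01 for r in rhats) / len(rhats)
frac_above_110 = sum(r > 1.10 for r in rhats) / len(rhats)
print(f"max R-hat:       {max(rhats):.4f}")
print(f"fraction > 1.01: {frac_above_101:.3f}")
print(f"fraction > 1.10: {frac_above_110:.3f}")
```

With settings like these, a small percentage of perfectly mixed test functions exceed 1.01, while values above 1.1 essentially never appear. That is consistent with the point above: with many test functions, noise explains some elevated values, but not a substantial tail beyond 1.1.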

There are, I believe, some oblique references to non-identifiability / multimodality in the paper. I think the decision to concentrate on convergence diagnostics in the prediction / function space rather than the parameter space is inherently due to the non-identifiability issues. Invariance of the predictions to permutations of the hidden units, and scale ambiguities between successive layers, will show up as multimodality / manifold-like structure in the posterior geometry in parameter space, but not in the pushforward distribution in prediction space. The references to ‘connected regions’ and ‘mode connectivity’ in the Implications for the Posterior Geometry section are, I think, also implicitly recognising that there will be non-identifiabilities leading to the posterior concentrating around submanifold(s) of the parameter space.
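A minimal Python sketch of the two symmetries mentioned, for a hypothetical one-hidden-layer ReLU network with a scalar output (the architecture, sizes, permutation, and scale factor c are illustrative assumptions). Permuting the hidden units, or multiplying a unit’s incoming weights and bias by c > 0 while dividing its outgoing weight by c, changes the parameter vector but not the predictions.

```python
import random


def relu(z):
    return max(0.0, z)


def mlp(x, w1, b1, w2, b2):
    """f(x) = sum_j w2[j] * relu(w1[j] . x + b1[j]) + b2 (one hidden layer)."""
    hidden = [relu(sum(wi * xi for wi, xi in zip(row, x)) + bj)
              for row, bj in zip(w1, b1)]
    return sum(v * h for v, h in zip(w2, hidden)) + b2


rng = random.Random(1)
d_in, n_hidden = 3, 5
w1 = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(n_hidden)]
b1 = [rng.gauss(0, 1) for _ in range(n_hidden)]
w2 = [rng.gauss(0, 1) for _ in range(n_hidden)]
b2 = rng.gauss(0, 1)

# Permutation symmetry: reorder hidden units consistently in both layers.
perm = [2, 0, 4, 1, 3]
w1_p = [w1[j] for j in perm]
b1_p = [b1[j] for j in perm]
w2_p = [w2[j] for j in perm]

# Scale symmetry: relu(c * z) == c * relu(z) for c > 0, so scaling a unit's
# incoming weights and bias by c and its outgoing weight by 1 / c cancels out.
c = 3.7
w1_s = [[c * w for w in row] for row in w1]
b1_s = [c * b for b in b1]
w2_s = [v / c for v in w2]

xs = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(10)]
max_perm_diff = max(abs(mlp(x, w1, b1, w2, b2) - mlp(x, w1_p, b1_p, w2_p, b2))
                    for x in xs)
max_scale_diff = max(abs(mlp(x, w1, b1, w2, b2) - mlp(x, w1_s, b1_s, w2_s, b2))
                     for x in xs)
print(max_perm_diff, max_scale_diff)  # both zero up to floating-point rounding
```

For an odd activation such as tanh, the positive rescaling no longer holds, but flipping the sign of a unit’s incoming and outgoing weights gives an analogous discrete symmetry.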

I agree that the \hat{R} threshold they use is higher than the one used in the Stan community. I am still surprised that it wasn’t much worse, i.e. \hat{R} values indicating no mixing at all, with each chain exploring one of a very large number of different posterior modes.

Overall, I am surprised it worked at all and wasn’t an abject failure, similar to how Stephen Martin was surprised that his Stan neural network worked (see Neural Networks in Stan: Or how I was utterly surprised that it worked at all). I would have assumed that label switching alone would produce a very large number of disconnected posterior modes.
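To put a number on “very large”: for generic weights, permuting hidden units within each hidden layer of width n gives n! weight settings with identical predictions, and an odd activation such as tanh adds a further factor of 2^n from sign flips. A minimal Python sketch (the layer widths are hypothetical); note that this counts symmetric copies, and whether they are separated by low-probability barriers or connected is exactly what is at issue.

```python
import math


def log10_equivalent_copies(hidden_widths, odd_activation=False):
    """log10 of the number of weight settings with identical predictions
    obtained by within-layer permutations of hidden units (generic weights).

    Each hidden layer of width n contributes n! permutations; with an odd
    activation such as tanh, flipping the sign of a unit's incoming and
    outgoing weights contributes a further factor of 2**n.
    """
    total = 0.0
    for n in hidden_widths:
        total += math.lgamma(n + 1) / math.log(10)  # log10(n!)
        if odd_activation:
            total += n * math.log10(2)
    return total


print(log10_equivalent_copies([100]))  # ≈ 157.97, i.e. 100! ≈ 9.3e157
print(log10_equivalent_copies([100], odd_activation=True))
print(log10_equivalent_copies([64, 64, 64]))
```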

Here’s their plot.

First, note that the y-axis is on a log scale, which means most values really are near 1.

Things mix much better for predictions (“function space”) than parameters (“weight space”).

They don’t indicate any special tricks for alignment under label switching, so we know this shouldn’t work. But if you have a bunch of weights that don’t do anything (they just look like the prior), then any two of them will look the same under R-hat.
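A toy Python illustration of why this can happen, assuming (hypothetically) that chains get stuck in label-switched basins. Rather than running HMC, it draws directly from two made-up basins: in basin A, hidden unit 1 does the work (outgoing weight near 2) while unit 2 looks like its N(0, 1) prior; basin B is the same with the units swapped. A third “dead” unit looks like the prior everywhere. Classic split-R-hat flags the used weight, but the dead weight and the permutation-invariant “function” v1 + v2 both look fine.

```python
import random
import statistics


def split_rhat(chains):
    """Classic split-R-hat (no rank normalization) for equal-length chains."""
    halves = []
    for chain in chains:
        half = len(chain) // 2
        halves.extend([chain[:half], chain[half:2 * half]])
    n = len(halves[0])
    w = statistics.fmean([statistics.variance(h) for h in halves])
    b = n * statistics.variance([statistics.fmean(h) for h in halves])
    return (((n - 1) / n * w + b / n) / w) ** 0.5


def stuck_chain(basin, length, rng):
    """Draws (v1, v2, v_dead) from one label-switched basin (a toy, not HMC)."""
    draws = []
    for _ in range(length):
        used = rng.gauss(2.0, 0.1)  # the unit that fits the data
        idle = rng.gauss(0.0, 1.0)  # the other unit, prior-like in this basin
        dead = rng.gauss(0.0, 1.0)  # a unit that is prior-like in every basin
        v1, v2 = (used, idle) if basin == "A" else (idle, used)
        draws.append((v1, v2, dead))
    return draws


rng = random.Random(0)
chains = [stuck_chain(basin, 500, rng) for basin in ("A", "A", "B", "B")]

rhat_used = split_rhat([[d[0] for d in c] for c in chains])
rhat_dead = split_rhat([[d[2] for d in c] for c in chains])
rhat_function = split_rhat([[d[0] + d[1] for d in c] for c in chains])
print(f"R-hat, weight v1:        {rhat_used:.3f}")
print(f"R-hat, dead weight:      {rhat_dead:.3f}")
print(f"R-hat, function v1 + v2: {rhat_function:.3f}")
```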

I don’t know if @matthewdhoffman is still getting messages from our forum, but maybe he can explain what’s going on as an author of the paper in question.

Hi everyone! Turns out I do still get emails about @ mentions :)

I think the summaries thus far have been pretty on point. There’s still a lot we don’t understand about the geometry of Bayesian neural net (hereafter BNN) posteriors, but I definitely count myself among those who are “utterly surprised that it work[s] at all.”

A few points that might be worth emphasizing:
• The deep learning theory community has been converging on a consensus that heavily overparameterized neural networks have cost landscapes (i.e., log-likelihoods) that are much easier to move around in than those of small neural networks. So posteriors from small BNNs seem likely to be qualitatively harder to sample from than those from large BNNs. (But of course large BNNs are also more expensive to compute gradients for.)
• My mental picture of overparameterized BNN posteriors with Normal(0, σ^2 I) priors (and σ reasonably large and a moderately large dataset like CIFAR-10) resembles a shell made of finely veined marble. As long as the network is large enough that we can get almost all of the examples right, the posterior will concentrate on the intersection of the typical set of the prior (which gives us the sphere) and the set of parameter values that get (almost) all training examples right with high confidence (which gives us the marbling). The marble “veins” are likely to be pretty evenly distributed throughout the prior shell, due to both the obvious nonidentifiability (from label switching and arbitrary rotations) and the overparameterization of the model (which allows for multiple parameter settings that get 100% agreement on the training set but make substantively different predictions on held-out data).
• Under the “marble sphere” model above, we’d expect the marginal mean and variance of each parameter to match the mean and variance of the prior if the marbling is fine enough. Near as I can tell this is pretty true for the middle layers, but IIRC things are a bit heavier tailed near the top and bottom of the network. In any case, we know that all of the weights in a layer should have the same marginal distribution under the posterior, since we can permute things arbitrarily.
• If we were really mixing perfectly (hope springs eternal!), we’d expect R-hat to be happy in weight space. R-hat is just a screen, of course—it can’t tell us that we’re actually mixing, and we could easily trick it by applying random likelihood-preserving permutations (which would be a valid MCMC move).
• I take the plot above as meaning that R-hat is not fully happy in weight space. That is, HMC is not managing to find paths between “basins of attraction” or whatever you want to call them (I don’t love the term “mode” in this context since it feels like it suggests a certain amount of convexity, and I think there’s decent evidence that these basins of attraction are extremely nonconvex).
• The R-hat story in function space is more mixed. Samples from different chains agree very well on almost all examples, but “almost all” is not the same thing as “all”. One might have hoped that all of the challenges in mixing were due to label switching, in which case weight-space R-hat would have been unhappy but function-space R-hat would have been fine. That doesn’t seem to quite be the case. Nonetheless, the within-chain variance is (to me) surprisingly close to the between-chain variance.
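The “shell” part of the marble picture is concentration of measure for the Normal(0, σ^2 I) prior: in d dimensions, almost all prior mass sits near radius σ√d. A minimal stdlib Python check, where the dimensions, σ, and number of draws are arbitrary illustrative choices (real BNNs have far more parameters), and which says nothing about the likelihood “marbling”:

```python
import math
import random


def relative_norms(d, sigma=1.0, num_draws=100, seed=0):
    """Norms of draws from Normal(0, sigma^2 I_d), divided by sigma * sqrt(d)."""
    rng = random.Random(seed)
    ratios = []
    for _ in range(num_draws):
        squared_norm = sum(rng.gauss(0.0, sigma) ** 2 for _ in range(d))
        ratios.append(math.sqrt(squared_norm) / (sigma * math.sqrt(d)))
    return ratios


for d in (10, 1_000, 10_000):
    r = relative_norms(d, sigma=0.5)
    print(f"d={d:>6}: min ratio {min(r):.3f}, max ratio {max(r):.3f}")
```

The spread of the ratio shrinks roughly like 1/√(2d), so in high dimensions the prior’s typical set looks like a thin spherical shell.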

I’m still not sure that anyone (including us) has actually gotten an unbiased sample from the posterior of a (finite-width) BNN, but I guess my takeaway from this work is that HMC for BNNs is a lot less hopeless than I would have thought a few years ago.

Also, note that I am emphatically not claiming that this will typically be a practical recipe! This was mostly meant as a science experiment—if you just want good results on CIFAR-10 or IMDB reviews, you can get them with far fewer TPU hours. I guess there might be some small datasets where it’s worth considering, though.

@matthewdhoffman Thank you for that excellent commentary and for your, as you call it, “science experiment”. It was extremely thought-provoking, to say the least.

Interesting question. I don’t have anything substantive to contribute to the discussion, except to suggest going back to Radford Neal’s papers from the ’90s, including his dissertation (1992, I think), in which he introduced the idea of applying HMC to neural networks (I’m pretty sure it was his idea; if not, he’s pretty close to its origin). Neal’s insight is pretty strong, at least I thought so, so if he had any comment on “how it works at all”, it would be worth considering.

I suppose what counts as a large or small network has changed a lot in 30 years, although I would guess the underlying issues remain the same.

Sorry that I can’t be more helpful.

Robert
