Using the PSIS diagnostic to check variational inference for BNNs

I found the paper “Yes, but Did It Work?: Evaluating Variational Inference” very insightful and wondered whether the two diagnostics it proposes could be applied to the following problem:

Currently, people try to scale BNNs with various variational inference approaches, for instance “Weight Uncertainty in Neural Networks”, also known as “Bayes by Backprop”.

Couldn’t we use at least the PSIS diagnostic to check whether the variational approximation used in that paper is flawed? If I’m not mistaken, the joint density $p(\theta, y)$ induced by the NN is tractable, since forward passes through the network are cheap. One could start, for instance, with the MNIST or regression example from the paper.
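A minimal sketch of that computation, assuming a hypothetical one-hidden-layer network with tanh units, a softmax output and an independent $N(0, 1)$ prior on every weight and bias (none of these choices are taken from the Bayes by Backprop paper). The hypothetical function `log_joint` unpacks a flat parameter vector, runs one forward pass over all $N$ inputs and returns $\log p(\theta) + \sum_{i=1}^{N} \log p(y_i \mid x_i, \theta)$, i.e. $\log p(\theta, y)$ with the inputs $X$ treated as fixed.

```r
# Hypothetical one-hidden-layer BNN: tanh hidden units, softmax output,
# independent N(0, prior_sd^2) prior on all weights and biases.
log_joint <- function(theta, X, y, h, k, prior_sd = 1) {
  p <- ncol(X)
  n1 <- p * h; n2 <- n1 + h; n3 <- n2 + h * k
  stopifnot(length(theta) == n3 + k)
  # Unpack the flat parameter vector into weight matrices and bias vectors
  W1 <- matrix(theta[1:n1], p, h)
  b1 <- theta[(n1 + 1):n2]
  W2 <- matrix(theta[(n2 + 1):n3], h, k)
  b2 <- theta[(n3 + 1):(n3 + k)]
  # One forward pass over all N inputs: N x h activations, N x k logits
  H <- tanh(sweep(X %*% W1, 2, b1, "+"))
  logits <- sweep(H %*% W2, 2, b2, "+")
  # Row-wise log-sum-exp, stabilized by subtracting each row's maximum
  m <- apply(logits, 1, max)
  log_norm <- m + log(rowSums(exp(logits - m)))
  # Observations are independent, so the log-likelihood is a sum over rows
  loglik <- sum(logits[cbind(seq_len(nrow(X)), y)] - log_norm)
  logprior <- sum(dnorm(theta, mean = 0, sd = prior_sd, log = TRUE))
  loglik + logprior
}

# Usage on simulated data (labels y are integers in 1..k)
set.seed(1)
N <- 50; p <- 3; h <- 4; k <- 2
X <- matrix(rnorm(N * p), N, p)
y <- sample(1:k, N, replace = TRUE)
D <- p * h + h + h * k + k
theta_s <- rnorm(D)
log_joint(theta_s, X, y, h, k)
```

The cost per draw is one forward pass over the full data set plus a sum over the prior, so it scales like evaluating the training loss once.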

I just wanted to check whether I am overlooking some complexity here.


Yes, you could use it for what you propose. I’m afraid it would just give high $\hat{k}$ values, because the posterior is highly multimodal and high-dimensional. That simply reflects that variational approximations for neural networks are “flawed” in the sense that they don’t represent the whole posterior well (although they can capture something useful).


Thanks, out of curiosity I will try it on an example that already has code available. This raises a practical question:

Suppose we have $N$ training data points with a corresponding $N$-dimensional vector $y$ (in the notation of “Yes, but Did It Work? Evaluating Variational Inference”). In the case I want to explore (MNIST) we have 60000 points. For each draw $\theta_s$ (the weights and biases of the NN) from the VI approximation, would you really compute $p(\theta_s, y)$ on the complete training set (the model assumes the points are independent, so the likelihood factorises), or would evaluating $p(\theta_s, y)$ on randomly chosen batches of, say, 128 points suffice?
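A small simulation can hint at what minibatching would do to the log ratios. This is a sketch with made-up per-observation log-likelihoods (the vector `loglik_i` is a stand-in, not output from a real BNN). The scaled estimate $(N/B)\sum_{i \in \text{batch}} \log p(y_i \mid \theta_s)$ is unbiased for the full log-likelihood, but with a fresh random batch for every draw its error enters each log ratio directly, and with $N = 60000$ and $B = 128$ that error has a standard deviation in the thousands on the log scale in this toy setting.

```r
set.seed(2)
N <- 60000; B <- 128
# Stand-in per-observation log-likelihoods log p(y_i | theta_s) for ONE draw
# (made-up values, not from a real network)
loglik_i <- rnorm(N, mean = -0.3, sd = 0.5)
full <- sum(loglik_i)
# Scaled minibatch estimates, with a fresh random batch each time
batch_est <- replicate(2000, N / B * sum(loglik_i[sample.int(N, B)]))
err <- batch_est - full
c(mean_error = mean(err), sd_error = sd(err))
```

Since the importance ratios are exponentiated, noise of that size would swamp any real differences between draws. Reusing one fixed batch for all draws removes the draw-to-draw noise, but then the ratios target $p(\theta)\, p(y_{\text{batch}} \mid \theta)^{N/B}$ rather than the full posterior.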

Provided I did not make a mistake in computing the log weights passed to `psislw` from the loo package, I get $\hat{k} = 176.5059$. A value that large surprised me, but since I am new to PSIS, my surprise may not mean much.

For reference, here is the code:

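```r
library(loo)  # psislw() is deprecated since loo 2.0 in favour of psis()
# One column X1: log p(theta_s, y) - log q(theta_s) for each draw
df_lw <- readr::read_csv("lws_bnn.txt", col_names = FALSE)
psis <- psislw(df_lw$X1)
psis_k <- psis$pareto_k
```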

The file with the log ratios $\log p(\theta_s, y) - \log q(\theta_s)$ for the draws is attached. A snapshot:


These values seem far too small to me, given that they are logs.

lws_bnn.txt (25.4 KB)

It seems equations don’t display well in quotes. I don’t understand the question. ADVI in Stan uses all observations and doesn’t have a minibatch variant. After ADVI has stopped, you can choose how many draws to sample from the approximation. For these draws, compute $\log p(\theta \mid y)$ (unnormalized) and $\log q(\theta)$; the Pareto fit is then computed for the largest ratios $\exp(\log p(\theta \mid y) - \log q(\theta))$.
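A toy version of this procedure in R, assuming a known target so that it runs without Stan: the unnormalized log posterior is taken to be that of $N(0, I_{20})$ and the approximation is $q = N(0, 0.8^2 I_{20})$, deliberately lighter-tailed than the target. The Pareto fit is delegated to `loo::psis()` (with `r_eff = 1`, since draws from $q$ are independent) and only runs if the loo package is installed.

```r
set.seed(3)
S <- 4000; d <- 20; sigma_q <- 0.8
# S independent draws from the approximation q = N(0, sigma_q^2 I_d)
theta <- matrix(rnorm(S * d, sd = sigma_q), S, d)
# Unnormalized log target: log N(0, I_d) without its normalizing constant
log_p <- -0.5 * rowSums(theta^2)
# Normalized log density of q at each draw
log_q <- -0.5 * rowSums((theta / sigma_q)^2) - d * log(sigma_q) - 0.5 * d * log(2 * pi)
log_ratios <- log_p - log_q
# Pareto fit to the largest ratios, if loo is available
if (requireNamespace("loo", quietly = TRUE)) {
  ps <- loo::psis(log_ratios, r_eff = 1)
  print(loo::pareto_k_values(ps))
}
```

For a BNN, `log_p` would be the network's unnormalized log posterior evaluated on the full data at each draw, and `log_q` the log density of the variational distribution at the same draw.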

Any $\hat{k} > 1$ is very large, indicating that you would need an infinite number of draws. We have also observed $\hat{k} > 100$ before.
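$\hat{k}$ is the estimated shape parameter of a generalized Pareto distribution fitted to the largest ratios, and a shape above 1 means that distribution has no finite mean. The base-R simulation below (the helpers `rgpd` and `max_share` are hypothetical, written only for illustration) draws from generalized Pareto distributions with shape 0.3 and 1.5 and reports what fraction of the total sum the single largest draw contributes.

```r
set.seed(4)
S <- 1e5
# Generalized Pareto draws with scale 1 and shape k,
# via the inverse CDF F^{-1}(u) = ((1 - u)^(-k) - 1) / k
rgpd <- function(n, k) ((1 - runif(n))^(-k) - 1) / k
# Fraction of the total contributed by the single largest draw
max_share <- function(r) max(r) / sum(r)
r_light <- rgpd(S, 0.3)  # finite mean and variance
r_heavy <- rgpd(S, 1.5)  # infinite mean
c(light = max_share(r_light), heavy = max_share(r_heavy))
```

With shape 1.5 a single draw typically carries a large part of the sum even with $10^5$ draws, so an importance sampling estimate built on such ratios does not settle down as more draws are added.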

What is too small?

I am not very surprised to see such extremely large log ratios $\log(q/p)$ when the approximation is far from the true density.
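One way to see why log ratios of order $-10^6$ are plausible: when both densities are normalized, the average log ratio under $q$ is $E_q[\log p(\theta \mid y) - \log q(\theta)] = -\mathrm{KL}(q \,\|\, p)$, and for independent coordinates the KL divergence is a sum over dimensions. The toy below assumes a target $N(0, I_d)$ and an approximation $N(0, 0.8^2 I_d)$, only a 20% scale error per coordinate; it checks the closed form $\frac{d}{2}(\sigma_q^2 - 1 - \log \sigma_q^2)$ against a Monte Carlo average and then evaluates it for $d = 10^5$. The helper names are hypothetical.

```r
# Closed-form KL( N(0, sigma_q^2 I_d) || N(0, I_d) )
kl_gauss <- function(d, sigma_q) d / 2 * (sigma_q^2 - 1 - log(sigma_q^2))
# Monte Carlo estimate of E_q[log p - log q] with both densities normalized
mc_mean_log_ratio <- function(d, sigma_q, S = 1000) {
  theta <- matrix(rnorm(S * d, sd = sigma_q), S, d)
  log_p <- -0.5 * rowSums(theta^2) - 0.5 * d * log(2 * pi)
  log_q <- -0.5 * rowSums((theta / sigma_q)^2) - d * log(sigma_q) - 0.5 * d * log(2 * pi)
  mean(log_p - log_q)
}
set.seed(5)
sigma_q <- 0.8
d_values <- c(10, 100, 1000)
mc <- sapply(d_values, mc_mean_log_ratio, sigma_q = sigma_q)
exact <- -kl_gauss(d_values, sigma_q)
rbind(d = d_values, mc = mc, exact = exact)
-kl_gauss(1e5, sigma_q)  # approximately -4314
```

A BNN posterior evaluated on 60000 observations, approximated by a much cruder $q$, can plausibly push the typical log ratio far beyond this toy value.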

I was referring to the log weights themselves. Off the top of my head, I expected large weights, not weights whose logs are around $-10^6$. I guess this expectation comes from the following quote from the “Yes, but Did It Work?” paper:

Second, VI approximation $q(\theta)$ is not designed for an optimal IS proposal, for it has a lighter tail than $p(\theta \mid y)$ as result of entropy penalization, which lead to heavy right tail of $r_s$. A few large-valued $r_s$ dominates the summation, bringing large uncertainty.

I therefore expected much larger log weights, at least around -10 and possibly even positive.

If you normalize the weights to sum to 1, then with a heavy right tail only one or a few weights will be non-zero within floating-point accuracy. The largest possible normalized weight is 1, and its log is 0. Before normalization, log weights can be far from 0 (negative or positive) if, for example, you include the normalizing term for q but not for p.
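A sketch of this normalization on the log scale, using the log-sum-exp trick so that log weights around $-10^6$ do not underflow when exponentiated. The raw values in `lw` are made up for illustration, and `normalize_log_weights` is a hypothetical helper. It also shows that adding the same constant to every log weight, as happens when a normalizing term is kept for q but dropped for p, leaves the normalized weights unchanged; $\hat{k}$ is likewise unaffected by such a shift, since multiplying all ratios by a constant only rescales the fitted generalized Pareto distribution.

```r
# Hypothetical helper: normalize log weights with the log-sum-exp trick.
# Subtracting the maximum first keeps exp() from underflowing to 0.
normalize_log_weights <- function(lw) {
  shifted <- lw - max(lw)
  shifted - log(sum(exp(shifted)))
}
lw <- c(-1000012, -1000000, -1000030, -1000001)  # made-up raw log weights
lw_norm <- normalize_log_weights(lw)
exp(lw_norm)   # normalized weights; they sum to 1
max(lw_norm)   # log of the largest normalized weight, never above 0
# A constant offset (e.g. a normalizing term kept for q but dropped for p)
# cancels out in the normalization
all.equal(lw_norm, normalize_log_weights(lw + 123456))
```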