@aseyboldt just dropped the following for PyMC and provides an example of how to use it with Stan models:
I’m curious how this will work with Stan models, which cannot be efficiently parallelized or JIT-ed in JAX. Adrian’s asking for feedback if you have any.
Thanks for linking it, I am indeed quite curious how well this works for other people.
Please keep in mind that the choice of normalizing flow and details of the optimization can be quite important, and I don’t think I have good defaults for those yet. If someone here has experience with those and wants to join the fun, I’d be glad. :-)
Just to clarify a bit how jax and stan interact:
In the new algorithm we repeat two steps:
1. Sampling in a transformed space
2. Optimizing the posterior transformation
Compared with standard nuts (at least if it works well…), we tend to spend a lot more time in the optimization and less in sampling: the number of required gradient evaluations is often much smaller than with standard nuts, but the optimization can be quite expensive.
But we only need the stan model during sampling, not during optimization. The input to the optimization is simply a set of posterior draws, together with their scores (logp gradients), and so we can run the optimization itself in jax and on a GPU.
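A minimal sketch of that data flow, in plain NumPy rather than JAX so it runs anywhere. Everything here is an illustrative assumption: `logp_and_score` is a hypothetical stand-in for the compiled model, the transformation is a diagonal affine map x = mu + sigma * y instead of a normalizing flow, and the fit uses a closed form that minimizes a Fisher divergence with a plain Euclidean norm for that simple family. The point is only that the fit step sees arrays of draws and scores and never calls the model.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical black-box model (stand-in for a compiled Stan model):
# returns logp and its gradient (the "score") at x. Toy independent normal.
m_true = np.array([1.0, -2.0, 0.5])
tau_true = np.array([0.5, 1.0, 5.0])

def logp_and_score(x):
    z = (x - m_true) / tau_true
    return -0.5 * np.sum(z**2), -z / tau_true

def hmc_in_transformed_space(x_init, mu, sigma, n_draws, eps=0.1, n_leapfrog=10):
    """Plain HMC on y with identity mass matrix, where x = mu + sigma * y."""
    def logp_y(y):
        lp, score = logp_and_score(mu + sigma * y)
        return lp, sigma * score  # chain rule; log|det| is constant in y
    y = (x_init - mu) / sigma
    lp, g = logp_y(y)
    draws, scores = [], []
    for _ in range(n_draws):
        v = rng.standard_normal(y.shape)
        y_new, g_new, v_new = y, g, v
        for _ in range(n_leapfrog):
            v_new = v_new + 0.5 * eps * g_new
            y_new = y_new + eps * v_new
            lp_new, g_new = logp_y(y_new)
            v_new = v_new + 0.5 * eps * g_new
        log_accept = (lp_new - 0.5 * v_new @ v_new) - (lp - 0.5 * v @ v)
        if np.log(rng.uniform()) < log_accept:
            y, lp, g = y_new, lp_new, g_new
        draws.append(mu + sigma * y)
        scores.append(g / sigma)  # score in the original x space
    return np.array(draws), np.array(scores)

def fit_diag_affine(draws, scores):
    """Model-free fit: only uses the arrays of draws and scores."""
    sigma = (draws.var(axis=0) / scores.var(axis=0)) ** 0.25
    mu = draws.mean(axis=0) + sigma**2 * scores.mean(axis=0)
    return mu, sigma

# Alternate the two steps, starting from the identity transformation.
mu, sigma = np.zeros(3), np.ones(3)
x = np.zeros(3)
for _ in range(3):
    draws, scores = hmc_in_transformed_space(x, mu, sigma, n_draws=50)
    mu, sigma = fit_diag_affine(draws, scores)
    x = draws[-1]
```

In a real setup the body of `fit_diag_affine` would be replaced by the (much more expensive) flow optimization, which is the part that can run in jax on a GPU.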
The use-case right now is models that are slow with nuts or don’t converge at all.
At least with some parameter tuning it can sample funnels, correlations, or weird non-normal posteriors. Surprisingly, it was even fine with a low-dimensional multimodal posterior.
I’m currently working my way through posteriordb, trying to find good normalizing flows for lots of non-converging models in there, and hopefully I can then generalize good defaults.
The framework itself is also quite a bit more general: it allows, for instance, automatically choosing between centered and non-centered parametrizations (if it knows where they are; I don’t think that’s easy to know in a stan model).
Edit: To be fair, I should mention that I often still have a few divergences left. Combining this with the variable step size nuts might be really cool.
Is there doc on what it’s doing for that somewhere? This is something we’ve been trying to deal with for years. We usually do know where funnels might arise—whether they do is often an empirical question about the quality of the data.
I meant to ask what this means. The traditional mass matrix M is used as a parameter in a multivariate normal defining the kinetic energy function for momentum \rho:
\rho \sim \textrm{normal}(0, M)
Then it’s used in the leapfrog algorithm updates for position \theta with step size \epsilon > 0,
\theta = \theta + \epsilon \cdot \dfrac{\rho}{M},
where \dfrac{\rho}{M} = M^{-1} \cdot \rho.
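For reference, the standard leapfrog update with a mass matrix can be sketched as below. The toy Gaussian target, the choice M^{-1} = Sigma, and all function names are assumptions made for illustration; this is textbook HMC, not code from Stan or nutpie.

```python
import numpy as np

def leapfrog(theta, rho, grad_logp, eps, M_inv, n_steps=1):
    """Leapfrog for H(theta, rho) = -logp(theta) + 0.5 * rho^T M^{-1} rho."""
    theta = np.array(theta, dtype=float)
    rho = np.array(rho, dtype=float)
    for _ in range(n_steps):
        rho = rho + 0.5 * eps * grad_logp(theta)  # half step for momentum
        theta = theta + eps * (M_inv @ rho)       # full step for position
        rho = rho + 0.5 * eps * grad_logp(theta)  # half step for momentum
    return theta, rho

# Toy target (assumption): zero-mean normal with covariance Sigma.
Sigma = np.array([[4.0, 1.2], [1.2, 1.0]])
Sigma_inv = np.linalg.inv(Sigma)
grad_logp = lambda th: -Sigma_inv @ th

# For a Gaussian target, M^{-1} = Sigma makes every direction equally stiff.
M = Sigma_inv
M_inv = Sigma

rng = np.random.default_rng(0)
theta0 = np.array([1.0, -0.5])
rho0 = rng.multivariate_normal(np.zeros(2), M)  # rho ~ normal(0, M)

def energy(theta, rho):
    return 0.5 * theta @ Sigma_inv @ theta + 0.5 * rho @ M_inv @ rho

theta1, rho1 = leapfrog(theta0, rho0, grad_logp, eps=0.1, M_inv=M_inv, n_steps=20)
energy_error = abs(energy(theta1, rho1) - energy(theta0, rho0))

# Reversibility check: flip the momentum and integrate again.
theta_back, rho_back = leapfrog(theta1, -rho1, grad_logp, eps=0.1, M_inv=M_inv, n_steps=20)
```

The mass matrix only enters through the momentum refresh and the `M_inv @ rho` term in the position update, which is why it is natural to ask how a nonlinear flow could take its place there.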
I can see replacing the whole normal distribution with the normalizing flow under refresh, but how do you use it like a preconditioner in the \theta update? Matt Hoffman et al. did something similar in this paper, but told me they could never get the normalizing flows to fit automatically enough:
The idea is for sure related to the paper you linked.
Just as in the paper, we learn a transformation, and then sample with HMC in the transformed space.
Were we to modify their approach in steps until we end up with Fisher HMC, it might look like this:
In the neutralizing-bad-geometry paper they split finding the transformation and sampling into two completely separate steps:
1. Find a transformation that minimizes KL(q, p) where q is a standard normal.
2. Sample in the transformed space.
First, we use KL(p, q) instead of KL(q, p). So we want to integrate with respect to the posterior, not the standard normal. But to compute KL(p, q) we don’t need draws from q, but draws from p. We can get those with HMC. That leads to an iteration like
1. Start with some dumb transformation.
2. Run HMC for a while to get “draws” from the posterior.
3. Minimize KL(p, q) a bit by using those draws to find a new transformation.
4. Repeat from 2.
This is very similar to mass matrix adaptation, only there we used to just choose the transformation A = Cov(p)^{1/2} instead of minimizing a KL divergence. (This is because we can think of mass matrix adaptation as a linear transformation f_A(x) = Ax and then running HMC with an identity mass matrix. If you go through the math of HMC, this turns out to be equivalent to using the mass matrix M^{-1}=AA^T. (That was mentioned in the Neal HMC paper: Section 4 here: https://arxiv.org/pdf/1206.1901))
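The equivalence in the parenthetical can be checked numerically. The sketch below is an illustration under assumptions of my own (a toy non-Gaussian target, an arbitrary invertible A, hypothetical function names): it runs leapfrog on y with an identity mass matrix for the pulled-back density, runs leapfrog on x = Ay with M^{-1} = AA^T, and compares the trajectories.

```python
import numpy as np

def leapfrog(q, p, grad, eps, M_inv, n_steps):
    for _ in range(n_steps):
        p = p + 0.5 * eps * grad(q)
        q = q + eps * (M_inv @ p)
        p = p + 0.5 * eps * grad(q)
    return q, p

# Toy non-Gaussian target (assumption): logp(x) = -0.5 x^T P x - sum(x^4) / 4.
P = np.array([[2.0, 0.3], [0.3, 0.5]])
grad_x = lambda x: -P @ x - x**3

# Linear transformation x = A y; any invertible A works.
A = np.array([[1.5, 0.0], [0.4, 0.7]])
grad_y = lambda y: A.T @ grad_x(A @ y)  # gradient of logp(A y) with respect to y

# Leapfrog in the transformed space with identity mass matrix.
y0 = np.array([0.3, -1.0])
v0 = np.array([0.8, 0.2])
y1, v1 = leapfrog(y0, v0, grad_y, 0.05, np.eye(2), 50)

# Leapfrog in the original space with M^{-1} = A A^T.
x0 = A @ y0
rho0 = np.linalg.solve(A.T, v0)  # momenta are related by v = A^T rho
x1, rho1 = leapfrog(x0, rho0, grad_x, 0.05, A @ A.T, 50)

same_position = np.allclose(A @ y1, x1)
same_momentum = np.allclose(A.T @ rho1, v1)
```

Both flags come out true because each leapfrog sub-step maps one-to-one between the two descriptions. A nonlinear transformation has no such constant M, so there it is easier to think in terms of sampling in the transformed space.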
In the next step to get to Fisher HMC, we swap the KL divergence for a Fisher divergence. This is very similar to the approach here, only that they don’t switch around the order of q and p, and they use a different norm for the gradients in the Fisher divergence (it’s related though).
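To make the Fisher-divergence step concrete, here is a worked special case. It rests on assumptions of my own and is not the normalizing-flow objective itself: a one-dimensional (or per-coordinate) affine map x = \mu + \sigma y, a standard normal reference on y, the plain Euclidean norm, and posterior draws x_n with scores s_n = \nabla \log p(x_n).

```latex
% Score of the pulled-back posterior in y-space is \sigma s_n;
% score of normal(0, 1) at y_n = (x_n - \mu) / \sigma is -y_n.
\hat{F}(\mu, \sigma)
  = \frac{1}{N} \sum_{n=1}^{N} \left( \sigma s_n + \frac{x_n - \mu}{\sigma} \right)^2

% Setting \partial \hat{F} / \partial \mu = 0:
\mu = \bar{x} + \sigma^2 \bar{s}

% Substituting \mu back in:
\hat{F}(\sigma)
  = \sigma^2 \, \widehat{\mathrm{Var}}(s)
  + 2 \, \widehat{\mathrm{Cov}}(s, x)
  + \frac{\widehat{\mathrm{Var}}(x)}{\sigma^2}

% Setting \partial \hat{F} / \partial \sigma^2 = 0:
\sigma^2 = \sqrt{ \widehat{\mathrm{Var}}(x) \, / \, \widehat{\mathrm{Var}}(s) }
```

For a normal posterior with mean m and standard deviation \tau we have s_n = -(x_n - m) / \tau^2, so \widehat{\mathrm{Var}}(s) = \widehat{\mathrm{Var}}(x) / \tau^4 and the formula returns \sigma^2 = \tau^2 and \mu = m exactly for any set of distinct draws, without waiting for the sample variance to converge.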
And as a final step, I use different normalizing flows than they do. I’m not so sure about those though :-)
In pymc I’m trying to get something like this to work:
If we have a
pm.Normal("x", mu=some_mu, sigma=some_sigma)
We can associate a transformation object with that distribution, just as we do to unconstrain parameters. So far those transformations never have parameters of their own, but they could. In this case, we could add a “parametrization parameter” \lambda and then use the “constrain transformation” x = \text{some_mu} + \text{some_sigma}^{1 - \lambda} \text{x_unconstrained}^\lambda. We can easily compute the logdet of that transformation, and compute forward and inverse transformations, so we can learn \lambda by minimizing the Fisher divergence just as with the normalizing flows. (Or we just think of this transformation as a layer in the normalizing flow).
In stan, I don’t know where the information that we want to do something with a variable like this would come from. You could just add it manually I guess.
(cc @avehtari has more transformations like this, this is just the one I knew before we talked).