Education-oriented, functional, interpreted Stan/HMC implementation

I would like to propose an education- and teaching-oriented project to make Bayesian inference and Hamiltonian Monte Carlo more accessible to newcomers.

MCMC can be baffling, and it is often easier to learn by doing than to understand how everything works beforehand – many users will then go on to peek under the hood, but it is also common practice to just trust the black box and the general diagnostics.

While Stan is very intuitive from a user/modeling perspective and based on straightforward algorithms (which in some of the publications are described in a page of pseudocode), its real-world implementation is difficult to understand for people without a programming background, particularly C++/OOP experience – which is the case for many students and people from other fields who may have gotten started in programming with Python/Julia/R.

The goal here is not to make a full-featured implementation in an “easier” language, but a near-minimal version of the essential building blocks, possibly using a functional paradigm that mimics the mathematical functions and facilitates teaching and step-by-step assembly and understanding of how the whole method comes together.

I have previously implemented bare-bones MH and HMC samplers in Python, Julia, and Haskell, but I think this is a task that with some help could be achieved more easily, more efficiently, and with a better overall outcome, especially if any developers can lend a hand.

If bad puns are permitted, I suggest this implementation be called UnderStan.

12 Likes

While it might not be what you wanted, I previously used the sampyl package in Python for this purpose. Because it’s in Python and not verbose, it’s easier to dive into the sampling and trace the leapfrog steps themselves.

1 Like

Thanks. I didn’t know about this but will take a look. Other packages like Turing.jl also have some of the features I’m hoping for, like different kinds of samplers, but I think there tends to be quite a leap between the theory and the implementations, and that it’s mostly a matter of development choices – not that they are unfounded, just that they don’t lend themselves to teaching and education.

Yep, the Python version is readable, with a 40-line method implementing the NUTS step and the tree recursion done in a separate 30-line function. If you have the NUTS paper at hand, I think it’s succinct enough to count as executable pseudocode.

I can imagine the production Julia implementation will enjoy a similar lack of readability to the templated C++ version, since the mechanisms for writing general code with dispatch are similar.

2 Likes

Yes, that’s definitely close to the pseudocode in the NUTS paper, although for teaching purposes I’d still avoid any classes/OO implementation.

1 Like

Missing from sampyl, though, are important aspects of the warmup phase, like adaptation of the covariance. It’d be helpful to include those as well, to understand why the Stan implementation of HMC works so well.

2 Likes

Yes, I’d say that is something that should go into a basic but fully functional implementation. I’d be interested to find out how many lines would be necessary to turn the pseudocode into functioning code.
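For a rough sense of scale, the core of the variance adaptation itself is tiny; here’s a minimal Python sketch (not Stan’s actual windowed warmup, which also adapts the step size – and treat the regularization constants as part of the sketch, though I believe they match Stan’s):

import numpy as np

def adapt_diag_metric(warmup_draws):
    # warmup_draws: (num_draws, num_params) array collected during warmup
    n = warmup_draws.shape[0]
    sample_var = np.var(warmup_draws, axis=0, ddof=1)
    # shrink the per-parameter variances toward a small constant
    # when there are few draws, so early estimates stay stable
    return (n / (n + 5.0)) * sample_var + 1e-3 * (5.0 / (n + 5.0))

The sampler then rescales the momenta (equivalently, the gradient steps) by these variances, which is a lot of why warmup matters so much for badly scaled posteriors.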

You definitely want to check out the rethinking package in R and the associated book Statistical Rethinking by McElreath.

2 Likes

Johnathon Auerbach suggested to me that a version of Stan be built around the approach that ggplot uses to create plots. The idea might be something like this (ggplot syntax can be mapped to Python as well; see plotnine, for example):

my_data = dataframe(y=[0,3...], pred_1=[...], pred_2=[...])
model <- UnderStan(data=my_data, parameters=[alpha, beta])

The above is enough to sample:

draws <- sample(model)

This would be unconstrained uniform over alpha and beta, and would be interpreted much as R and Python are, with single-line incremental execution possible (a REPL loop).

The model could be elaborated further:

my_data = dataframe(y=[0,3...], pred_1=[...])
model <- UnderStan(data=my_data, parameters=[alpha, beta])
model <- model + UnderStan([alpha ~ normal(0,1), beta ~ exponential(1)]) # added priors
model <- model + UnderStan([y = alpha + pred_1 * beta]) # likelihood
draws <- sample(model)

The above model could be developed/evaluated line by line, which is super helpful for incrementally creating ggplots, so I think the advantages would carry over. What is not clear to me is how to make the process interpreted instead of having a compile step. Do the RStanArm trick of pre-compiling models, with a fallback to compilation for uncached models?

Generated quantities could be handled as:

predictions <- UnderStan(model=model, draws=draws, [pred_y = normal_rng(alpha + beta * pred_1)])

This is less syntax-leveraged than RstanArm and brms, and if done with a REPL loop it would be easier to build up models incrementally. It also maps pretty cleanly to the underlying Stan programs.

I didn’t try to get the syntax past the R/Python interpreters, but hopefully the general idea is clear.

Breck

1 Like

Thanks. I think there are at least two different aspects here. One is streamlining or clarifying the process of putting together MCMC sampling, which is what this ggplot-like syntax does: it may help you understand each step that needs to be there for things to happen. The other is breaking down each step, which won’t help you set everything up and get it going, but may help you understand why each step is there. The latter is kind of what I’m thinking about, e.g. what is the sampler doing? If it’s MH, it’s just randomly proposing a jump from a Gaussian distribution with one tuning parameter and accepting/rejecting it; if it’s HMC, it’s flowing along a gradient using an integrator with a step size and number of steps (static or dynamic).
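To make the MH case concrete, here is a minimal sketch of one random-walk Metropolis step in Python – the proposal scale is the single tuning parameter I mentioned:

import numpy as np

def metropolis_step(theta, log_p, scale, rng):
    # theta: current parameter vector (numpy array)
    # propose a Gaussian jump centered at the current state
    proposal = theta + scale * rng.standard_normal(theta.shape)
    # accept with probability min(1, p(proposal) / p(theta))
    if np.log(rng.uniform()) < log_p(proposal) - log_p(theta):
        return proposal
    return theta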

I think even the simple samplers are sophisticated/non-intuitive enough that the general algorithm needs a good explanation (I guess that’s partly true of MLE and least squares, but those don’t have many moving parts). But I think the two aspects can be combined to explain both what MCMC is and how it is put together.

The model ought to be declarative and independent of the inference algorithm. Unfortunately, the reality is that knowledge of inference is necessary for models that work, so what you suggest makes sense. I suspect that any sort of interpreted REPL model will need to be implemented in R/Python, but given that warmup is distinct from sampling and is the computationally expensive part of HMC/NUTS, I think this will be very hard to do as an interpreted process that has any sort of liveness.

I agree. This proposal is really orthogonal to what Stan accomplishes.
Also, I guess the cost of warmup in static and dynamic HMC versus other samplers is also of interest in explaining why we need costly, sophisticated algorithms instead of lighter ones. There are probably many ramifications, but I’d have trouble making them clear without thinking more about it.

Maybe the REPL-loop version doesn’t attempt to draw accurate samples from the posterior at all. Just do a single draw, or 10 draws, from the priors and push them through however much of the model exists. That would help with:

  • Getting the data structures working right
  • Validating the information flow in the model
  • Exposing model authors to literally running a sampling process, just not one that is going to get to the appropriate posterior

Once the model ‘flows’ properly, then bring out the fancy algorithms. This would also help learners understand what the fancy stuff brings to the game.
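As a sketch of what that flow check might look like in Python (all names and numbers hypothetical): a handful of prior draws pushed through the linear predictor, with no posterior sampling at all:

import numpy as np

rng = np.random.default_rng(1)
pred_1 = np.array([0.2, 1.5, 3.1])  # toy predictor values

# 10 draws from the priors alpha ~ normal(0, 1), beta ~ exponential(1)
alpha = rng.normal(0.0, 1.0, size=10)
beta = rng.exponential(1.0, size=10)

# push each prior draw through the model's linear predictor
y_sim = alpha[:, None] + beta[:, None] * pred_1[None, :]
print(y_sim.shape)  # (10, 3): one simulated dataset per prior draw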

1 Like

I agree that pedagogical coding demonstrations can generally be useful, but statistical algorithms are a particularly complicated beast. The problem is that, unlike many other algorithms, statistical algorithms don’t always produce statistically meaningful outputs. Most of the difficulty in statistical computation isn’t how to implement an algorithm but rather understanding when the algorithmic output is meaningful. It’s straightforward to code up a Markov chain Monte Carlo sampler in one’s favorite language, but it’s much harder to show that a given Markov chain is converging to something meaningful without a whole lot of theoretical understanding (I can say this from extensive personal experience). Consequently, the key to demonstrations like this is conveying what they don’t show, so that the audience does not become over-confident in their understanding.

Hamiltonian Monte Carlo is even more subtle to demonstrate because many of its core features seem trivial in a simple algorithmic implementation without an understanding of the mathematical motivation (I’m looking at symplectic integration in particular). Knowing that motivation, one can come up with informative failure examples, but they still only scratch the surface of the mathematics.

The key to a useful demonstration is to identify the target audience and what they should be learning. What does a better understanding of Markov chain Monte Carlo generally, and Hamiltonian Monte Carlo in particular, provide to a user that improves their Stan practice? Or is the desired audience people interested in implementing their own algorithms, in which case what else do they need to know to not only build an implementation but also employ it responsibly? There are many statistical computation tutorials in the wild, and I honestly believe that most of them do more harm than good because they present the material as being simpler than it actually is.

My best attempt to integrate code demonstrations for simple Markov chain Monte Carlo implementations with meaningful mathematical context can be found in Markov Chain Monte Carlo in Practice. I have further code demonstrating some Hamiltonian Monte Carlo basics that so far I have used only for courses, one example of which is publicly accessible at tutorials/mcmc at master · mlss-2019/tutorials · GitHub. Coincidentally, a Hamiltonian Monte Carlo chapter that would include these exercises and more is on my short-term to-do list, but that shouldn’t preclude anyone else from working on their own demonstrations.

I’m happy to comment on particular demonstrations that anyone puts together, with the caveat that I’ll be coming from the above perspective.

Finally, let me mention that I’m always a bit hesitant about fully functional programming implementations for pedagogical demonstrations like these. In my experience people don’t intuit pure functional processing all that well, and many Markov chain Monte Carlo algorithms aren’t well-adapted to the functional paradigm. For example, Metropolis-Hastings samplers either require state (to keep the initial point around in case of a rejection) or need to be implemented on a lifted space (which has some nice mathematical benefits but also complicates the implementation). But this is a nitpick.

3 Likes

I also agree that theoretical understanding is much harder than coding up any one sampler, but I also think Bayesian statistics is often easier learned by doing than by mastering all the theory before ever running a chain. That is, if you set up your model, likelihood, and priors, run a chain, look at the marginals, compute some quantities, and realize you are integrating over the posterior, you’ll understand faster what the point is – that’s step 1. Step 2 is understanding how MCMC is even able to get you that distribution, going from practice to theory and vice versa (why does it work? what do I need to tune?).

Hamiltonian Monte Carlo is more subtle, because the theory is more complex and the algorithms are more sophisticated (and also require more tuning), but the idea here is getting past the implementation hurdles and into what the main differences are (e.g. not using random jumps, but flowing along the derivatives), with some mathematical context to illustrate why it generally works better.
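For the “flowing along the derivatives” part, here is a minimal leapfrog sketch in Python, assuming grad_log_p returns the gradient of the log density:

def leapfrog(theta, rho, eps, grad_log_p):
    # one step of the leapfrog integrator for H = -log p(theta) + rho'rho/2
    rho = rho + 0.5 * eps * grad_log_p(theta)  # half step for the momentum
    theta = theta + eps * rho                  # full step for the position
    rho = rho + 0.5 * eps * grad_log_p(theta)  # half step for the momentum
    return theta, rho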

I think improving MCMC practice and exploring the infinite esoteric mathematical details of Bayesian statistics are orthogonal, and neither is really the goal here; the goal is some combination where people understand a bit more of the theory, and that understanding maybe helps them identify issues in practice. As I mentioned above, you don’t have to know this – you can know what least squares does without understanding it – but if you do, it may give you intuition about what it can and cannot do for you.

I’m not sure what you mean by this – it doesn’t require any state. The current “state” is always what you are working with, and the traces can be kept in a list/vector without there being any actual program state, making it fully functional. The main point of a fully functional implementation, though, is just programming each function as the mathematical function, with the same arguments and outputs, to make a better correspondence with the math; I think it actually makes for an easier paradigm rather than a more complicated implementation.

First, if you don’t know it, check out Chi Feng’s animation app, which covers a bunch of simple densities with a bunch of samplers.

It’s not just OOP, it’s also our heavy use of template metaprogramming (largely traits and expression templates). If you want simple-to-understand HMC code, I’d recommend MacKay’s book on information theory (the code is in Octave, the open-source version of MATLAB, but it’s easy to read even for a non-MATLAB coder like me).

We’ve found it very hard to disentangle the theory from the practice. Mathematical details like stationarity of Markov chains are intimately tied up with issues of warmup/burn-in and diagnosing “convergence”. Concentration of measure is important because not every model you write down is going to converge. I agree that you can get started without front-loading a ton of theory, but it’s hard to reach a serious level as a practitioner without biting the bullet and understanding at least the basics of MCMC theory.
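(As an aside, even the standard convergence diagnostic is short to write down. Here’s a sketch of split R-hat for a single parameter, assuming an even number of draws per chain – a value near 1 is consistent with, but does not prove, stationarity:)

import numpy as np

def split_rhat(chains):
    # chains: (num_chains, num_draws) array for one parameter; num_draws even
    m, n = chains.shape
    halves = chains.reshape(2 * m, n // 2)          # split each chain in half
    w = halves.var(axis=1, ddof=1).mean()           # within-chain variance
    b = (n // 2) * halves.mean(axis=1).var(ddof=1)  # between-chain variance
    var_plus = ((n // 2 - 1) * w + b) / (n // 2)
    return np.sqrt(var_plus / w)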

Like deciding on mathematical notation, it will depend heavily on the audience and the problem. Like a lot of programmers, I find functional code more intuitive for most recursive data structures. It’s very much like mathematicians finding it easier to work with linear algebra than a bunch of indexes.

That’s not the sense of “functional” that’s meant with functional programming. As @betanalpha points out, it’s really about statelessness and not using local variables/assignment statements. Having said that, you can hack almost anything into a functional form pretty easily using continuation-passing style. But it sounds like you’re just arguing for reasonable programming practices, not a purely functional style, which is fine.

For example, here’s a purely functional Metropolis implementation:

metropolis_sample(theta0, proposal_rng, target_density, num_steps):
  return complete([theta0], proposal_rng, target_density, num_steps)

complete(chain, _, _, 0):
  return chain

complete([theta | chain], q, p, N):
  return complete([step(theta, q, p), theta | chain], q, p, N-1)

step(theta, q, p):
  return accept_reject(q(theta), uniform_rng(0, 1), theta, p)
  
accept_reject(theta*, u, theta, p):
  return ifelse(u < p(theta*)/p(theta), theta*, theta)

Note, in particular, that it’s done without any explicit assignment or local variables.

There’s an implicit pseudo-RNG that would need to be threaded through if you wanted to do this more thoroughly.
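Here’s a sketch of the same structure in runnable Python, with the RNG threaded through explicitly as suggested (recursion in place of the pattern matching, so Python’s recursion limit makes this a teaching toy rather than a practical sampler):

import numpy as np

def metropolis_sample(theta0, q, p, num_steps, rng):
    return complete([theta0], q, p, num_steps, rng)

def complete(chain, q, p, n, rng):
    # base case: no steps left, return the chain (newest draw first)
    if n == 0:
        return chain
    return complete([step(chain[0], q, p, rng)] + chain, q, p, n - 1, rng)

def step(theta, q, p, rng):
    return accept_reject(q(theta, rng), rng.uniform(), theta, p)

def accept_reject(theta_star, u, theta, p):
    return theta_star if u < p(theta_star) / p(theta) else theta

# usage: standard-normal target with a Gaussian random-walk proposal
rng = np.random.default_rng(0)
p = lambda t: np.exp(-0.5 * t * t)           # unnormalized density
q = lambda t, rng: t + rng.normal(0.0, 1.0)  # proposal
chain = metropolis_sample(0.0, q, p, 500, rng)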

Here’s my best shot at writing out pseudocode for the naive version of NUTS (without biasing toward final doubling, which adds more complexity). You can see it uses assignments and is not purely functional.

Z(theta, rho)      # a point in phase space: position and momentum

TREE(z-, z+, C)    # a trajectory: backward-most point, forward-most point, candidate set

NUTS(logp, eps, theta0, M)   # M = number of draws
  H = function(z) -logp(z.theta) - log normal(z.rho | 0, 1)
  theta(0) = theta0
  for m in 1:M
    theta(m) = NUTS-STEP(H, eps, theta(m - 1))
  return theta

NUTS-STEP(H, eps, theta)
  rho ~ normal(0, 1)
  z = theta, rho
  minp ~ uniform(0, exp(-H(z)))   # slice variable (u in the NUTS paper)
  tree = z, z, { z }
  for depth = 0; depth <= MAX_DEPTH; ++depth
    fwd ~ bernoulli(1/2)
    tree' = NUTS-TREE(H, minp, eps, fwd, depth, EDGE(tree, fwd))
    break if tree' == FAIL || U-TURN(tree')
    tree = JOIN-TREES(fwd, tree, tree')
    break if U-TURN(tree)
  return uniform(tree.C)

NUTS-TREE(H, minp, eps, fwd, depth, z)
  dir_eps = fwd ? eps : -eps
  if depth == 0
    z' = LEAPFROG(H, dir_eps, z)
    if DIVERGE(H, z, z') return FAIL
    C' = (exp(-H(z')) >= minp) ? { z' } : { }
    return z', z', C'
  else
    tree1 = NUTS-TREE(H, minp, eps, fwd, depth - 1, z)
    if tree1 == FAIL || U-TURN(tree1) return FAIL
    tree2 = NUTS-TREE(H, minp, eps, fwd, depth - 1, EDGE(tree1, fwd))
    if tree2 == FAIL || U-TURN(tree2) return FAIL
    return JOIN-TREES(fwd, tree1, tree2)

LEAPFROG(H, eps, z)
  rho = z.rho + eps / 2 * grad(H.logp)(z.theta)
  theta = z.theta + eps * rho
  rho += eps / 2 * grad(H.logp)(theta)   # gradient at the updated position
  return theta, rho

JOIN-TREES(fwd, tree1, tree2)
  C = tree1.C UNION tree2.C
  return fwd ? tree1.z-, tree2.z+, C
             : tree2.z-, tree1.z+, C

EDGE(tree, fwd)
  return fwd ? tree.z+ : tree.z-

U-TURN(tree)
  return (tree.z+.theta - tree.z-.theta)' * tree.z-.rho < 0
         || (tree.z+.theta - tree.z-.theta)' * tree.z+.rho < 0

DIVERGE(H, z, z')
  return abs(H(z) - H(z')) > DIVERGENCE_THRESHOLD

However you slice it, it’s a complicated algorithm.

7 Likes

Yes, @mitzimorris pointed me to that recently, it’s a very nice visual example.

What I mean is not necessarily making easy-to-understand code, but rather code that is easy to associate with how the theory is (normally) explained. The main point is exactly to go back and forth between the math and the implementation, because it is indeed difficult to separate theory from practice, and the algorithms are indeed complicated.

I understand the definition of pure functional programming, and it’s true that being completely stateless is not required for the implementation to be “reasonable” in the sense of being more readable. It’s just easier to explain it that way than to say “let’s code this in a way that mimics the math and makes each mathematical function have a close correspondent in the implementation”, since with FP you are kind of forced to do that by design. My point was that there is no real hurdle in making it functional: the “state” is just the parameter values at the current iteration, which get passed to the next one and stored in a list of states (but otherwise don’t have to be kept around as program state that could be changed externally).
Maybe it’s not as convenient for some parts of the algorithm, and maybe I’m overlooking some important implementation detail, but in principle it shouldn’t be a problem.

I completely agree, but again, that’s the point. When I started (and popular samplers were a lot simpler) I was baffled by how MCMC algorithms worked, and it took me a while to go from running a program to understanding what it actually did and why it worked or failed. Since things are not getting simpler, I think the education part needs to keep up.

w/r/t MCMC samplers, things haven’t really gotten more complicated. I think that there are a few good papers with nice visualizations - e.g., this one: https://besjournals.onlinelibrary.wiley.com/doi/full/10.1111/2041-210X.12681

3 Likes

The difficulty is that there are many actions/algorithms that one can naively apply to the states of a Markov chain that don’t correspond to any meaningful mathematical operations; indeed, a nontrivial number of introductions to Bayesian inference advocate for actions that aren’t all that meaningful. Teaching “doing” without any understanding of why just gives people tools that they don’t know how to apply appropriately. Even worse, there ends up being strong heterogeneity in what people believe they understand – depending on exactly which heuristics they encountered early on – which makes more careful education later all the more difficult.

Integration/expectation as the main probabilistic operation is a great example. All probabilistic computational methods approximate expectation values with respect to a given distribution in one way or another, but most introductions to probability theory fall back on one-dimensional visuals of density functions that obscure this fact. People want to see density functions at the end of their analysis even though probabilistic computational methods cannot produce them in a well-defined way. An emphasis on “doing” motivates code/packages/examples that rely on implicit density estimators to satisfy that desire, reinforcing misunderstandings of what’s actually going on.

Even when a more careful perspective is taken – so that the audience understands exactly what is being computed – it is easy to misinterpret the outputs as always being meaningful when the examples implicitly assume conditions nice enough for the action/algorithm outputs to correspond to something meaningful, but never show counterexamples where the outputs are less meaningful. This is an even more common sin of introductory treatments that show a Markov chain Monte Carlo algorithm working on a low-dimensional, simple target distribution and leave the reader to presume that the algorithm will always work.

To be clear, there is much that can be done from a “doing” perspective that doesn’t compromise the theory, but it requires a lot more steps: for example, starting with distributions and expectation values, then analytic integration of density functions to evaluate expectation values, then moving on to Monte Carlo methods as a way of approximating/estimating expectation values with “samples”, and what can go wrong when the necessary conditions for those approximations/estimators to be well-behaved fail. That said, the concept of “samples” is a huge rabbit hole of its own, especially when one moves to correlated samples, where the approximation/estimator behavior is a mess.
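For example, the well-behaved end of that progression fits in a few lines of Python; a sketch of a Monte Carlo estimator and its standard error, which is the object these methods actually produce (the naive standard error below is valid only for independent draws – correlated samples require an effective-sample-size correction):

import numpy as np

rng = np.random.default_rng(0)
draws = rng.normal(1.0, 2.0, size=4000)  # stand-in for posterior draws

f = lambda x: x ** 2                      # function whose expectation we want
fx = f(draws)
estimate = fx.mean()                      # Monte Carlo estimate of E[f(X)]
mcse = fx.std(ddof=1) / np.sqrt(len(fx))  # standard error, independent draws only
print(estimate, "+/-", mcse)              # exact answer here is 1^2 + 2^2 = 5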

What’s the end goal? Is it to motivate Hamiltonian Monte Carlo methods over random-walk Metropolis methods in general? Assuming ideal conditions (where all of the Markov chain Monte Carlo estimators are well-behaved), even the relatively simple static-integration-time, Metropolis-corrected implementations can be useful; indeed there are many examples that demonstrate exactly this. The optimal Hamiltonian Monte Carlo implementation even in these ideal conditions, however, is a much more difficult discussion, as both the implementations themselves and the mathematics defining “optimality” become much more complicated. At the same time, motivating which method is most robust, and relatedly which method is more diagnostic about its failures, requires going beneath those “ideal conditions”.

Can one always know what least squares does, though? Sure, one can always apply a least squares estimator/algorithm to arbitrary data and know what the format of the output will be. But is that an actual understanding of least squares?

We can never expect practitioners to be experts in all aspects of the underlying theory, and so we will always have to rely on abstractions. The question, however, is what makes a useful and robust abstraction. At the very least, a useful abstraction has to help a practitioner understand what is being estimated. Ideally it would also help them understand how the output of an algorithm/method relates to that target. Even better, it would help clarify under what circumstances that relationship is a useful one.

In other words, does the abstraction help inform not only how to implement/evaluate a method, but also what the meaningful outputs of the method are and when those outputs are well-behaved? I’m advising so much caution here only because the common abstractions presented in introductory statistics material are so broken that it’s often impossible to even frame the right questions within them, let alone answer them properly.

Another complication is that there are so many possible abstractions here, for example one targeted at people who want to use Markov chain Monte Carlo algorithms for probabilistic computing in general, at people who want to use Markov chain Monte Carlo for Bayesian inference in particular, at people who want to design new Markov chain Monte Carlo methods for general problems, at people who want to design new Markov chain Monte Carlo methods for particular problems, at people who want to tweak existing Markov chain Monte Carlo methods, etc, etc, etc.

With the caveat that all of those examples necessarily visualize only one- or two-dimensional target distributions, and hence can’t demonstrate behaviors common in higher-dimensional spaces; and none of them show how the Markov chain states are related to Markov chain Monte Carlo estimators, let alone to the expectation values those estimators are approximating.

The heart of Stan’s Hamiltonian Monte Carlo sampler isn’t even all that object-oriented; the object-oriented structure is all for organizing the Markov chain logistics. The complexity of the sampler implementation is in the template metaprogramming and all of the machinery needed to implement a dynamic integration time efficiently.

Most of the code examples of Hamiltonian Monte Carlo, including in MacKay, are of implementations with a static integration time and a Metropolis correction using just the final point. More efficient implementations consider the entire trajectory, even more efficient implementations consider trajectories that dynamically expand, and the most efficient implementations perform the expansion in a way that minimizes memory usage.

In the code examples I linked to above I demonstrate static-integration-time methods that utilize the entire trajectory. These are pretty straightforward when the entire trajectory is kept in memory, and still reasonable when only the transition state is kept in memory. Dynamic implementations with simple termination criteria can be made more straightforward by keeping the entire trajectory in memory (and avoiding the recursion), but the performance suffers strongly. The more sophisticated termination criterion we use in Stan, however, complicates this quite a bit.

No disagreement there, so long as one is careful to identify for whom any given tutorial is targeted. My experience is with less professional programmers, who often get caught up in the more functional concepts.

One complication here is that, beyond the natural Markov chain abstraction (where a transition function generates sequential states to form a Markov chain, and those states are used to construct Markov chain Monte Carlo estimators), there isn’t necessarily a unique translation between math and implementation. Depending on the details, a Hamiltonian Monte Carlo transition can be decomposed into intermediate steps in different ways, for example generating an entire trajectory first and then sampling a point, or sampling a point sequentially as the trajectory is being generated. These can all be shown to be equivalent mathematically, but they can have very different performance behaviors.

Again the question comes down to the desired abstraction/audience – a superficial understanding of how gradients are used, or how precisely integrator error is corrected, or how the most efficient implementation is designed, etc, etc, etc?

Just to emphasize some of my previous points: the Markov chain Monte Carlo method has not changed in many decades; rather, it is the design and theoretical understanding of particular Markov transitions that has largely evolved. The latter isn’t relevant for someone who doesn’t understand what Markov chain Monte Carlo does. It’s not even necessarily relevant for someone who wants to understand why basic techniques like Metropolis transitions or Gibbs transitions preserve a given target distribution, as needed for a Markov chain Monte Carlo algorithm.

All of that said, because the education is uniformly a mess, there is much to be done to improve pedagogy at all levels. We just have to carefully identify and communicate the target abstraction/audience so as not to further contribute to the existing mess.

1 Like

I have a Julia implementation of the NUTS sampler on \mathbb{R}^n that is reasonably well-documented internally and written in a functional style (e.g. tree expansions are abstracted out):

Incidentally, this also made unit testing much easier, allowing me to decouple a lot of the implementation from the RNG.

3 Likes