An MLIR-based IR target for probabilistic programming languages

Hello everyone,

I am currently “on the hunt” for a research topic at the intersection of compilers and mathematics. While I have not contributed to Stan before, I’ve been aware of it and similar languages/libraries for quite some time and have always wanted to take on a larger project involving probabilistic programming.

Recently a lot of work, both in academia and in industry, has gone into creating optimizing tensor-program compilers for machine learning applications, most prominently TVM and a whole host of projects based on MLIR. That got me wondering whether the latter might offer some novel advantages over the approach Stan currently takes to compilation.

As I understand it, there already exists an IR (MIR) that is optimized and transpiled to C++, with the possibility of executing code on the GPU. The main benefit of MLIR, as I see it, is that there already exists a wide variety of dialects for compiling to LLVM or C and for running on GPU-like hardware accelerators. By creating an MLIR-based IR for Stan, support for these targets could be established, and eventually maintained, more easily.

Another major advantage that could materialize would be establishing a “shared dialect” for probabilistic programming languages in general, onto which Stan as well as Pyro, TF Probability, etc. could be lowered, similar to what is now happening with popular ML frameworks and TOSA. If this were to draw effort from multiple projects, progress on middle-end optimization could be accelerated.

So before I dive into this I wanted to collect some feedback from the community to make sure this is actually a worthwhile idea. Some immediate questions:

  • Do you see this delivering actual value over the current compilation flow?
  • Are there any hardware targets and/or compiler optimizations that you would like to see that have so far not been implemented in the Stan compiler that could be part of this effort?
  • Do you think popular existing languages/libraries similar to Stan are sufficiently similar that a shared intermediate representation is feasible?

I would be really happy about any feedback, even if my idea turns out to be nonsense.

Best,
Timo


Hi @Time0o, welcome to the Stan forum. This particular topic is outside my wheelhouse, but @WardBrian and/or @Bob_Carpenter can probably answer your questions (or point you to the right person who can).

Hi @Time0o

Thanks for your interest! I think there is a lot more the Stan compiler could be doing along these lines, but also some concerns.

The big barrier to previous attempts to compile Stan to anything other than the current C++ architecture is the dependence on a math/matrix/autodiff library. ML models tend to use comparatively fewer kinds of operations, which makes them a deceptive comparison. If you really wanted to be able to compile arbitrary existing Stan models, you would need implementations for a lot of functions in your intermediate or target language, from the various distributions to ODE solvers to heavily overloaded basic math ops, alongside high-quality autodiff for all of those. The sheer size of this is difficult to estimate, really, and I’m not sure how much of the existing templated C++ could be re-used by MLIR.
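To give a sense of what “implementations plus autodiff” means, here is a rough sketch of what a single density call bottoms out in with the current Stan Math library (reverse mode via stan::math::var; build details elided):

```cpp
#include <stan/math.hpp>
#include <iostream>

int main() {
  using stan::math::var;

  // A single "y ~ normal(mu, sigma)" statement ends up calling a heavily
  // overloaded, templated C++ function that already supports autodiff.
  var mu = 0.5;
  var sigma = 1.2;
  var lp = stan::math::normal_lpdf(2.3, mu, sigma);

  // Reverse-mode sweep: fills in the adjoints of mu and sigma.
  lp.grad();
  std::cout << "dlp/dmu = " << mu.adj()
            << ", dlp/dsigma = " << sigma.adj() << "\n";

  // An MLIR-based backend would need equivalents (value + gradient) for
  // hundreds of functions like this one, plus the ODE solvers, HMMs, etc.
  stan::math::recover_memory();
}
```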

On the other hand, the stuff I’ve seen people use MLIR for previously could have huge potential upsides for Stan, like automatic vectorization or GPU-ization of code. So it’s definitely an interesting area!

I think my overall concern would be how this changes the installation process. We already have a lot of users who struggle to even get a working C++ toolchain. Any initiative that made this even more onerous (like requiring that toolchain to be LLVM-based?) would likely be a nonstarter. On the other hand, if a better IR allowed us to compile directly to assembly without another program installed, or to even “interpret” Stan code for faster model iteration, it could answer some feature requests that are nearly as old as the project itself.

I’d be happy to hop on a call sometime to discuss more if you’re interested!


Hi Brian,

That is indeed a big compilation target. How much of this is still present in the current MIR? Reusing code, especially C++, might be possible, but I’d have to look deeper into it.

Regarding installation, that is definitely a concern. It might be possible to provide a statically linked compiler that includes all LLVM dependencies, but I have not tried this out yet. That said, it is actually relatively easy to JIT-compile to LLVM IR and then execute it.
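As a rough illustration (the names model.ll and log_prob are hypothetical here, and the ORC API differs a bit between LLVM versions), the in-process part could look something like this:

```cpp
#include <llvm/ExecutionEngine/Orc/LLJIT.h>
#include <llvm/IRReader/IRReader.h>
#include <llvm/Support/SourceMgr.h>
#include <llvm/Support/TargetSelect.h>
#include <memory>

int main() {
  llvm::InitializeNativeTarget();
  llvm::InitializeNativeTargetAsmPrinter();

  // "model.ll" stands in for LLVM IR produced by an MLIR lowering pipeline.
  auto ctx = std::make_unique<llvm::LLVMContext>();
  llvm::SMDiagnostic err;
  std::unique_ptr<llvm::Module> mod = llvm::parseIRFile("model.ll", err, *ctx);

  auto jit = llvm::cantFail(llvm::orc::LLJITBuilder().create());
  llvm::cantFail(jit->addIRModule(
      llvm::orc::ThreadSafeModule(std::move(mod), std::move(ctx))));

  // Look up the model's (hypothetical) log density symbol and call it
  // directly, with no external C++ toolchain involved.
  auto addr = llvm::cantFail(jit->lookup("log_prob"));
  auto log_prob = addr.toPtr<double (*)(const double*)>();

  double params[] = {0.5, 1.2};
  double lp = log_prob(params);
  (void)lp;
}
```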

I’m not really a frequent Discourse user; could you send me a private message to set up a call? I don’t think I have permission to do that myself.

All of it. Calls into the standard library aren’t really touched by the compiler; we just generate code and let the C++ compiler handle template and overload resolution.

Our current MIR is not really that much lower than the AST in many ways. It has typing and overload information, and a bit of the surface-level syntax has been desugared, but we leave a lot of high-level constructs in place because we know we’re targeting a high-level language. We do some light optimization (the most interesting/bespoke of which is determining whether we can use struct-of-arrays matrix types in autodiff), do some tree rewriting for IO code, and then generate the C++.
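To make the “generate code and let C++ sort it out” part concrete, the emitted code for a statement like `target += normal_lpdf(y, mu, sigma);` has roughly this shape (a simplified sketch; the real generated code carries much more bookkeeping, and the names here are illustrative only):

```cpp
#include <stan/math.hpp>

// Simplified sketch of the shape of code stanc emits for one model
// statement; the function and variable names are illustrative.
template <bool propto__, typename T__>
T__ log_prob_fragment(const Eigen::Matrix<double, -1, 1>& y,
                      const T__& mu, const T__& sigma) {
  stan::math::accumulator<T__> lp_accum__;

  // Stan: target += normal_lpdf(y, mu, sigma);
  // The compiler only emits the call; picking the right overload
  // (double vs. var, scalar vs. vector y) and dropping constants when
  // propto__ is true are left to the C++ compiler and Stan Math.
  lp_accum__.add(stan::math::normal_lpdf<propto__>(y, mu, sigma));

  return lp_accum__.sum();
}
```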

Hi @Time0o and welcome to the Stan forums!

The Stan developer who has been thinking about this the most is @Matthijs. He was thinking of working at the layer above XLA (the foundation of TensorFlow and JAX).

The transpiler generates the same code and we have a flag that’ll turn on the GPU options at compile time for some functions. We don’t have a way to fully run in-kernel on GPU.

Those and PyMC all compile down to JAX right now. I think that’s where people are landing for the time being as a target, rather than XLA. One issue is that there are not a lot of compiler developers working on these projects.

Is TVM meant to be a competitor to JAX? I’d never even heard of it before, but that’s not saying much.

It would certainly be a challenge. PyMC and NumPyro are largely graphical modeling languages embedded in Python, whereas something like Stan is a standalone language. Then you have popular alternatives like Turing.jl in Julia and NIMBLE in R, as well as standalone packages like ADMB/TMB. None of these projects is well funded other than TensorFlow Probability and JAX.

If you look at something like JAX, the library is pretty extensive.

What’s missing is the implicit solvers we have in Stan: 1D integrators, root finders, differential algebraic equation solvers, ordinary differential equation solvers, HMMs, etc.

What they have a lot more of is reshaping and mapping operations—our map_rect and reduce_sum are comparatively weak.

That doesn’t really contradict what I was saying. If you didn’t allow JAX to compile the code into more elementary operations, it would also be quite hard to represent arbitrary JAX code with something like TOSA mentioned in the original post.

The advantage JAX has over Stan (in this area, at least) is that reasoning about the library code is much more similar to just reasoning about more user code, since it’s all in Python (or in JAX primitives). The Stan compiler knows nothing about inv_erfc besides that it exists, whereas JAX can look at the definition, inline it, transform it, etc. This almost-homoiconicity gives JAX a lot of its power.

An equivalently powerful Stan compiler, built on the current Stan library, would need to include most if not all of a C++17-compliant frontend to be able to do the same kinds of transformations before lowering, I think.
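To make the asymmetry concrete, all the Stan compiler can do with a library function today is emit an opaque call into Stan Math; the fragment below is an illustration of that point, not actual generated code:

```cpp
#include <stan/math.hpp>

// Illustrative only: what "real y = inv_erfc(x);" boils down to from the
// compiler's point of view.
template <typename T__>
T__ apply_inv_erfc(const T__& x) {
  // The Stan compiler knows inv_erfc's signature and nothing else, so it
  // cannot inline it, fuse it with neighbouring operations, or reason about
  // its derivative; all of that lives behind this call in templated C++.
  return stan::math::inv_erfc(x);
}
```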