Using Python code in a model

When working with team members, I was asked how to reuse parts of Stan models, or even some of our existing Python code. After some reflection, I wrote a Stan C++ user function that invokes Python code to compute the model's log density (lp) and its gradient.
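On the Python side, the function called from C++ only has to map a parameter vector to the log density and its gradient. Here is a minimal sketch. It assumes a hypothetical entry point `lp_and_grad` and an iid standard normal target, since the actual interface of the user function isn't shown here. A finite-difference check is included because the gradient is coded by hand:

```python
def lp_and_grad(theta):
    """Log density (up to a constant) of iid standard normals, plus its gradient.

    Hypothetical entry point: the C++ user function would pass theta in and
    copy lp and grad back into Stan's autodiff types.
    """
    lp = -0.5 * sum(t * t for t in theta)
    grad = [-t for t in theta]
    return lp, grad


def check_grad(f, theta, eps=1e-6):
    """Largest absolute gap between f's gradient and central finite differences."""
    _, grad = f(theta)
    worst = 0.0
    for i in range(len(theta)):
        up = list(theta)
        down = list(theta)
        up[i] += eps
        down[i] -= eps
        fd = (f(up)[0] - f(down)[0]) / (2 * eps)
        worst = max(worst, abs(fd - grad[i]))
    return worst


print(lp_and_grad([0.5, -1.0, 2.0]))  # (-2.625, [-0.5, 1.0, -2.0])
print(check_grad(lp_and_grad, [0.5, -1.0, 2.0]))
```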

I'm not sure it has much value as-is beyond a proof of principle (it leaks memory, for example), but I was brainstorming whether there are use cases that would justify cleaning it up, such as:

a) reuse an existing codebase written in Python or exposed through a Python interface
b) use AD from an ML library like PyTorch or TensorFlow, but with Stan's NUTS
c) avoid having to cram all your data into Stan's numerical data structures
d) use Numba's OpenMP support for easier parallelization
e) debug the model with pdb instead of print()
f) ?
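Use cases (c) and (e) come almost for free, because the callback is ordinary Python. It can close over whatever data structures are convenient, and it can stop in the debugger. Here is a sketch built on a few assumptions:

* The callback returns `(lp, grad)`.
* The data are made-up ragged groups of different sizes. Stan would need these padded, or flattened with index arrays.
* The model has one normal mean per group, unit noise, and a flat prior.

`make_lp_and_grad` is a hypothetical helper:

```python
def make_lp_and_grad(groups, debug=False):
    """Build an lp/grad callback that closes over ragged dict-of-lists data.

    groups maps a group name to its observations; theta holds one mean per
    group, in sorted-name order. Model: y ~ normal(mu_g, 1), flat prior on mu.
    """
    names = sorted(groups)

    def lp_and_grad(theta):
        if debug:
            breakpoint()  # by default drops into pdb with theta and groups in scope
        lp = 0.0
        grad = []
        for name, mu in zip(names, theta):
            ys = groups[name]
            lp += -0.5 * sum((y - mu) ** 2 for y in ys)
            grad.append(sum(y - mu for y in ys))
        return lp, grad

    return lp_and_grad


f = make_lp_and_grad({"a": [1.0, 2.0, 3.0], "b": [10.0]})
print(f([2.0, 9.0]))  # (-1.5, [0.0, 1.0])
```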

What do you think?

Even if it is completely useless, it was quite fun to write.

Sounds cool.

Any idea what is leaking?

Any idea whether Python → Pythran → C++ would be a viable way to create user functions automatically from Python source?

I didn't decrement reference counts (Py_DECREF) as the Python C API requires, so I'm guessing the NumPy arrays returned from the Python functions are never freed. The next step would be to write a Cython shim that implements the function in the pycall namespace, and Cython would handle the reference counting automatically. pybind11 could be another easy approach, since it knows how to map Eigen types to NumPy arrays (while avoiding memory leaks).
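The leak mechanism is easy to reproduce without any C++. Every new reference the embedding code receives needs a matching `Py_DECREF`, or the object is never freed. The sketch below simulates this from Python via `ctypes.pythonapi`, so it only works on CPython. `Result` is a hypothetical stand-in for the returned NumPy array, and the explicit `Py_IncRef` plays the role of the new reference that the C++ side owns:

```python
import ctypes
import gc
import weakref

# CPython-only: call the C API's reference-counting functions directly.
_incref = ctypes.pythonapi.Py_IncRef
_incref.argtypes = [ctypes.py_object]
_incref.restype = None
_decref = ctypes.pythonapi.Py_DecRef
_decref.argtypes = [ctypes.py_object]
_decref.restype = None


class Result:
    """Hypothetical stand-in for the array a Python callback hands back to C++."""

    def __init__(self, values):
        self.values = values


def simulate_call(release):
    """Return True if the result object is still alive once the 'call' is over."""
    obj = Result([1.0, 2.0])
    alive = weakref.ref(obj)
    _incref(obj)  # the C++ side now owns one (new) reference to the result
    # ... copy obj.values into Stan's data structures here ...
    if release:
        _decref(obj)  # the step that was missing: hand the reference back
    del obj
    gc.collect()
    return alive() is not None


print(simulate_call(release=True))   # False: the object was freed
print(simulate_call(release=False))  # True: the object leaked
```

Cython and pybind11 both emit the matching decrement for you, which is why either would remove this class of bug.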

A quick read through the docs suggests yes, indeed. I was already looking at Numba's C callback generation, and I think the shared libraries Pythran generates would be similar. Either choice would sidestep the overhead of entering the Python interpreter, as if the function had been written by hand in C++.
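For the Python → Pythran → C++ route, the source stays plain Python plus a `#pythran export` comment. Below is a sketch of what such a module might look like, assuming a normal likelihood with a hand-coded gradient. The module and function names are hypothetical. The file runs unchanged under CPython, and `pythran normal_lp.py` should compile it to a native extension module. Wiring that module into a Stan user function is not shown:

```python
# normal_lp.py (hypothetical module name)
import math

#pythran export normal_lp_grad(float list, float, float)
def normal_lp_grad(y, mu, sigma):
    """Normal log likelihood (constant dropped) and its gradient in (mu, sigma)."""
    lp = 0.0
    d_mu = 0.0
    d_sigma = 0.0
    for yi in y:
        z = (yi - mu) / sigma
        lp += -math.log(sigma) - 0.5 * z * z
        d_mu += z / sigma                   # d/dmu of -0.5 * z**2
        d_sigma += (z * z - 1.0) / sigma    # d/dsigma of -log(sigma) - 0.5 * z**2
    return lp, d_mu, d_sigma


print(normal_lp_grad([1.0, 2.0, 3.0], 2.0, 1.0))  # (-1.0, 0.0, -1.0)
```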

The primary disadvantage is having to code the gradient by hand. The pystanpy example currently uses autograd to avoid that, but on trivial models it's very slow compared to Stan. If the arrays are big enough (larger than the L3 cache), we could hope the slowdown would be tolerable. I don't care too much, though, since we haven't found the NUTS implementations elsewhere (PyMC3, Pyro, the Julia packages) to converge nearly as fast as Stan's. For example, we had a model where Rhat dropped to 1.1 after 200 iterations with Stan versus 1000 with the others. So a 5x slower gradient evaluation is just fine if it means we can still use Stan.
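The trade-off in the last example is back-of-the-envelope arithmetic. Assume wall time scales as iterations to convergence $N$ times cost per gradient evaluation $c$. Also assume, as a simplification, that both samplers take a similar number of gradient evaluations per iteration:

$$
\frac{T_{\text{Stan}}}{T_{\text{other}}}
= \frac{N_{\text{Stan}}\, c_{\text{Stan}}}{N_{\text{other}}\, c_{\text{other}}}
= \frac{200 \times 5c}{1000 \times c}
= 1
$$

In that example, then, a 5x slower gradient roughly breaks even. Any smaller slowdown would leave Stan ahead.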

The best option may eventually be JAX, a next-generation autograd package backed by XLA (an optimizing linear algebra compiler used by TensorFlow). Just from your regular NumPy functions, it generates parallelized native code along with C callbacks, which would support fast calls from Stan's core algorithms.
