Parse and convert Stan code to a generalized format

I am a Python user trying to compare my library code with Stan. Is there an obvious way to parse Stan code and extract all the relevant information automatically, so that I can convert all the logic (e.g. parameter names, sizes, constraints) into my Python code?

P.S.: I am also a GSoC 2022 contributor with TensorFlow, and I am willing to create an open-source parser if I can get some tips and support from the Stan developers.

I guess you want to have a look at the Stan transpiler that generates C++ code from a Stan model: GitHub - stan-dev/stanc3: The Stan transpiler (from Stan to C++ and beyond).
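
As a first step toward extracting parameter names, sizes, and constraints, newer versions of stanc3 also have an --info flag that prints a machine-readable summary of a model's blocks. Here is a rough sketch of driving it from Python, assuming a stanc binary on your PATH and that your version emits JSON from --info (the exact schema varies across releases, so inspect the output of your install):

import json
import subprocess

def stan_model_info(stan_file):
    # Ask stanc3 for a structured summary of the model's blocks.
    result = subprocess.run(
        ["stanc", "--info", stan_file],
        capture_output=True, text=True, check=True,
    )
    return json.loads(result.stdout)

info = stan_model_info("bernoulli.stan")  # placeholder file name
print(info["parameters"])  # parameter names/types; key names may vary by version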

Edit: you can also check GitHub - stan-dev/stan2tfp: Stan2TFP is a work-in-progress alternative backend for Stanc3 which targets TensorFlow Probability. It generates TensorFlow Probability code from a Stan model.


I suggest that you look at the YAPS project - GitHub - IBM/yaps: A surface language for programming Stan models using python syntax
It’s a system which translates Python code to Stan.


It also translates Stan to their Python representation, which is a neat feature, though it is probably outdated by now, given that the last commit is four years old.

There is also a fork at GitHub - deepppl/stanc3: Rewriting the Stan compiler in OCaml which targets Pyro/NumPyro (by @louis-mandel)


Looking around, this seems the most relevant to my question: I would like to write a Stan model and then run a bayeux backend. That repo currently supports TFP, PyMC, and NumPyro models, so my first thought was to convert the Stan model to either TFP or NumPyro and then convert from that to bayeux.

Looking at the above repos, though, it appears they have been abandoned?

What is the current state of the art for transpiling Stan to new backends?

I am not aware of any non-orphaned project that does this at the moment.

If the goal is just to get access to a log density, you can use GitHub - roualdes/bridgestan: BridgeStan provides efficient in-memory access through Python, Julia, and R to the methods of a Stan model. That will still use the C++ backend, though, which means it will primarily be CPU-bound.
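
For reference, here is a rough sketch of what that looks like from Python, following the BridgeStan quickstart (the file paths are placeholders, and constructor details may differ between versions):

import numpy as np
import bridgestan as bs

# Compile the model and bind the data (placeholder paths).
model = bs.StanModel("bernoulli.stan", "bernoulli.data.json")

# Evaluate the log density and gradient on the unconstrained scale.
theta_unc = np.random.randn(model.param_unc_num())
lp, grad = model.log_density_gradient(theta_unc, jacobian=True)
print(lp, grad)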

Yeah, the idea here is to grab the log density (and presumably the priors) and convert that to a bayeux-compatible model.

Thanks for the link, I will check it out!

Everything that's available is represented in the middle layer of our OCaml transpiler.

What specifically do you want to get out of it? Are you looking to define the exact same density in NumPy/SciPy or JAX, for example? Or translate the Stan code to NumPyro code where possible?

With o1, you can do a lot of this translation automatically. I'm using it regularly to convert Python density definitions to Stan and vice versa. The challenge is dealing with the automatic variable transforms and the generated quantities block in Stan programs.
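
To make the transform issue concrete: Stan samples constrained parameters on an unconstrained scale, so any translation has to reproduce both the change of variables and its log-Jacobian adjustment. A minimal sketch of the lower-bound case (the function name here is my own, not Stan's):

import jax.numpy as jnp

def lower_bound_constrain(u, lower=0.0):
    # Map unconstrained u to (lower, inf), as Stan does for <lower=...>
    # parameters, and return the log-Jacobian adjustment alongside it.
    x = lower + jnp.exp(u)
    log_jacobian = u  # log |dx/du| = u, since dx/du = exp(u)
    return x, log_jacobian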

And if you target a less expressive language, then not everything will have a translation.

As far as I know, there are no active Stan projects to generate Pyro, PyMC, TFP, JAX, or any other target output.

Specifically I wanted to write a program in Stan then run a sampler from the bayeux library.

I think that just means you need a JAX function to compute the unconstrained log density.
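
If I am reading the bayeux README right, its entry point is a Model that wraps a JAX-compatible log density plus a test point, so a sketch with a stand-in density (the translated Stan density would go here) looks like:

import jax
import jax.numpy as jnp
import bayeux as bx

# Stand-in unconstrained log density; swap in the translated Stan model.
model = bx.Model(
    log_density=lambda x: -0.5 * jnp.sum(x * x),
    test_point=jnp.zeros(3),
)

idata = model.mcmc.numpyro_nuts(seed=jax.random.PRNGKey(0))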

To give you an example of what ChatGPT with o1 (not pro) can do, I asked it to transform the following Stan program into a JAX function; I'll show the result below:

data {
  int<lower=0> N;
  int<lower=0> M;
  array[N] int<lower=0, upper=M> y;
}
parameters {
  real mu;
  real<lower=0> sigma;
  vector[N] logit_theta;
}
model {
  mu ~ normal(0, 1);
  sigma ~ lognormal(0, 1);
  logit_theta ~ normal(mu, sigma);
  y ~ binomial_logit(M, logit_theta);
}

Here’s the JAX that it produced after I asked it to use built-in functions rather than unfolding all the densities manually (it codes like a physicist, not like a statistician!).

import jax
import jax.numpy as jnp
from numpyro.distributions import Normal, LogNormal, Binomial

def log_density(params, y, M):
    # Unpack parameters
    mu, unconstrained_sigma, logit_theta = params
    N = y.shape[0]

    # Transform sigma to constrained scale
    sigma = jnp.exp(unconstrained_sigma)

    # Priors
    # mu ~ Normal(0,1)
    lp = Normal(0,1).log_prob(mu)
    # sigma ~ LogNormal(0,1) but we must add the Jacobian for the transform from unconstrained
    lp += LogNormal(0,1).log_prob(sigma) + unconstrained_sigma
    # logit_theta[i] ~ Normal(mu, sigma)
    lp += jnp.sum(Normal(mu, sigma).log_prob(logit_theta))

    # Likelihood
    # y[i] ~ Binomial(M, logit=logit_theta[i])
    lp += jnp.sum(Binomial(total_count=M, logits=logit_theta).log_prob(y))

    return lp

You could presumably coach it into making slightly better use of the pytree abstraction in JAX. And you could have it just bind the data rather than take it as an argument. We’ve been looking at how to best define JAX models in the Stan style, and have some recommendations we’re about to publish.
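
For example (my own sketch, not those recommendations), passing the parameters as a dict pytree and closing over the data yields a single-argument density that jax.grad, and samplers like bayeux's, can consume directly:

import jax
import jax.numpy as jnp
from numpyro.distributions import Normal, LogNormal, Binomial

def make_log_density(y, M):
    # Bind the data once; the returned closure only takes parameters.
    def log_density(params):
        sigma = jnp.exp(params["log_sigma"])
        lp = Normal(0, 1).log_prob(params["mu"])
        lp += LogNormal(0, 1).log_prob(sigma) + params["log_sigma"]
        lp += jnp.sum(Normal(params["mu"], sigma).log_prob(params["logit_theta"]))
        lp += jnp.sum(Binomial(total_count=M, logits=params["logit_theta"]).log_prob(y))
        return lp
    return log_density

# Usage: gradients flow through the dict pytree automatically.
y, M = jnp.array([3, 5, 2]), 10
log_density = make_log_density(y, M)
params = {"mu": 0.0, "log_sigma": 0.0, "logit_theta": jnp.zeros(3)}
grads = jax.grad(log_density)(params)  # a dict with the same structure as params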

Here’s a link to the ChatGPT session: ChatGPT - JAX function from Stan
