I think that just means you need a JAX function to compute the unconstrained log density.
To give you an example of what ChatGPT version o1 (not pro) can do, I asked it to transform the following Stan program to a JAX function and I’ll show the result after:
data {
  int<lower=0> N;
  int<lower=0> M;
  array[N] int<lower=0, upper=M> y;
}
parameters {
  real mu;
  real<lower=0> sigma;
  vector[N] logit_theta;
}
model {
  mu ~ normal(0, 1);
  sigma ~ lognormal(0, 1);
  logit_theta ~ normal(mu, sigma);
  y ~ binomial_logit(M, logit_theta);
}
Here’s the JAX that it produced after I asked it to use built-in functions rather than unfolding all the densities manually (it codes like a physicist, not like a statistician!).
import jax
import jax.numpy as jnp
from numpyro.distributions import Normal, LogNormal, Binomial

def log_density(params, y, M):
    # Unpack parameters
    mu, unconstrained_sigma, logit_theta = params
    N = y.shape[0]
    # Transform sigma to constrained scale
    sigma = jnp.exp(unconstrained_sigma)
    # Priors
    # mu ~ Normal(0,1)
    lp = Normal(0,1).log_prob(mu)
    # sigma ~ LogNormal(0,1) but we must add the Jacobian for the transform from unconstrained
    lp += LogNormal(0,1).log_prob(sigma) + unconstrained_sigma
    # logit_theta[i] ~ Normal(mu, sigma)
    lp += jnp.sum(Normal(mu, sigma).log_prob(logit_theta))
    # Likelihood
    # y[i] ~ Binomial(M, logit=logit_theta[i])
    lp += jnp.sum(Binomial(total_count=M, logits=logit_theta).log_prob(y))
    return lp
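The payoff is that JAX then gives you gradients of the unconstrained log density for free via jax.grad. Here’s a minimal sketch of calling it, assuming the function and imports above; the data and parameter values are arbitrary placeholders, not anything from the original example.

# Minimal usage sketch, assuming the log_density function and imports above.
# The data and initial parameter values here are made up for illustration.
y = jnp.array([3, 7, 5])            # N = 3 observed counts
M = 10                              # number of trials per observation
params = (0.0, 0.0, jnp.zeros(3))   # (mu, unconstrained_sigma, logit_theta)

lp = log_density(params, y, M)               # scalar unconstrained log density
grads = jax.grad(log_density)(params, y, M)  # gradients w.r.t. the params tuple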
You could presumably coach it into making slightly better use of the pytree abstraction in JAX. And you could have it just bind the data rather than take it as an argument. We’ve been looking at how best to define JAX models in the Stan style, and have some recommendations we’re about to publish.
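Just as a sketch of what those two changes might look like (not the forthcoming recommendation), you could take the parameters as a dict-valued pytree and close over the data:

# Hypothetical restructuring: parameters as a dict pytree, data bound in a
# closure. Reuses jnp, Normal, LogNormal, Binomial from the imports above and
# the y, M defined in the usage sketch.
def make_log_density(y, M):
    def log_density(params):
        sigma = jnp.exp(params["unconstrained_sigma"])
        lp = Normal(0, 1).log_prob(params["mu"])
        lp += LogNormal(0, 1).log_prob(sigma) + params["unconstrained_sigma"]
        lp += jnp.sum(Normal(params["mu"], sigma).log_prob(params["logit_theta"]))
        lp += jnp.sum(Binomial(total_count=M, logits=params["logit_theta"]).log_prob(y))
        return lp
    return log_density

log_density_bound = make_log_density(y, M)   # data is now baked in
init = {"mu": 0.0, "unconstrained_sigma": 0.0, "logit_theta": jnp.zeros(3)}
grads = jax.grad(log_density_bound)(init)    # a dict with the same structure as init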
Here’s a link to the ChatGPT session: ChatGPT - JAX function from Stan