Speed of evaluating (gradients of) log probabilities in pystan 2.x vs 3

This is more of an FYI than anything else. In our variational inference package VIABEL, we have been using pystan to evaluate (gradients of) model log probabilities. While this worked very well with pystan 2.x, with pystan 3 evaluation is about 60x slower, and thus is essentially unusable. Here are timing results from VIABEL (10K iterations, centered 8 schools example):

I know we are not the only ones using pystan in this way. That said, I realize in the short or even medium term there may not be much that can be done, given the architecture of pystan 3.

cc @ariddell @avehtari


Not sure how PyStan 3 does this, but everything gets piped through some http interface, doesn’t it?

Maybe the interface could be extended so that a batch of parameters can be passed in a single call?

That being said, I’d love the ability to easily compute batches of gradients of log probabilities with pystan/cmdstanpy.

This is odd. Something doesn’t seem right.

You’re just calling an 8 schools model’s log_prob function 10k times?

You might try writing a simple httpstan test case to show the slow speed. If you can’t reproduce it there, then the problem is in the pystan 3 code and a fix should be easy.
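A minimal timing harness for such a test case might look like this. `time_per_call` is a hypothetical helper, and `toy_grad` is only a stand-in so the snippet runs on its own. In practice you would pass something like the httpstan module's gradient function together with its arguments.

```python
import time
from typing import Any, Callable


def time_per_call(fn: Callable[..., Any], *args: Any, n: int = 1000) -> float:
    """Return the mean wall-clock time of fn(*args), in seconds, over n calls."""
    start = time.perf_counter()
    for _ in range(n):
        fn(*args)
    return (time.perf_counter() - start) / n


def toy_grad(x: list) -> list:
    """Stand-in for a real log-density gradient: gradient of -0.5 * sum(x**2)."""
    return [-xi for xi in x]


if __name__ == "__main__":
    per_call = time_per_call(toy_grad, [0.1] * 10, n=10_000)
    print(f"{per_call * 1e6:.2f} µs per call")
```

Running the same harness against pystan 3 and against the underlying httpstan call would show which layer the time goes to.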


Actually 100K times (10 samples per iteration). One of my grad students has been running these tests. Here is what she gets using ipython:

For pystan version 2.19
%time fit.grad_log_prob(param.tolist())
CPU times: user 136 µs, sys: 29 µs, total: 165 µs
Wall time: 180 µs

For pystan version 3
%time model.grad_log_prob(param.tolist())
CPU times: user 15.7 ms, sys: 2.92 ms, total: 18.6 ms
Wall time: 16.6 ms
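To put these per-call numbers in perspective, here is the arithmetic for 100K evaluations. The inputs are the single-call `%time` wall times above, so treat the results as rough.

```python
# Wall time per grad_log_prob call, from the %time measurements above
pystan2_s = 180e-6   # pystan 2.19: 180 µs
pystan3_s = 16.6e-3  # pystan 3: 16.6 ms

n_evals = 100_000  # 10K iterations x 10 samples per iteration

print(f"slowdown per call: {pystan3_s / pystan2_s:.0f}x")     # 92x
print(f"pystan 2.19 total: {n_evals * pystan2_s:.0f} s")       # 18 s
print(f"pystan 3 total:    {n_evals * pystan3_s / 60:.1f} min")  # 27.7 min
```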


The performance hit is most certainly due to the current intermediate httpstan interface.

Reading the draws from a CSV file (i.e., one call for many draws) recovers the factor of roughly 100; see stan-dev/cmdstan issue #1012, "Add `log_prob_grad` method whose interface mimicks `generate_quantities`".

I guess being able to request a batch of evaluations would restore performance.


Actually, could you make do with batches, @jhuggins?

I think it currently works only for one draw per call.

Yes, I think this is what PyStan supports. The question is whether @jhuggins needs the gradient at a prespecified set of points, or whether the points depend on the previous points.
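To make the distinction concrete, here is a toy sketch of a stochastic-gradient variational inference loop. It uses a mean-field Gaussian with the reparameterization trick, which is an assumption for illustration and not necessarily what VIABEL does. Within one iteration all K draws are known up front, so their gradients could be sent as one batch; the next iteration's draws depend on the updated variational parameters, so iterations themselves cannot be batched. `grad_log_p` is a toy stand-in for a Stan model's `grad_log_prob`, and `batch_grads` is a hypothetical placeholder for a batched backend call.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def grad_log_p(x: Sequence[float]) -> Vector:
    """Toy target: independent normals with mean 3 and sd 1 in each coordinate."""
    return [-(xi - 3.0) for xi in x]


def batch_grads(grad_fn: Callable[[Sequence[float]], Vector],
                draws: Sequence[Vector]) -> List[Vector]:
    """Evaluate all gradients for one iteration. With a batched backend this
    would be a single call instead of len(draws) separate calls."""
    return [grad_fn(z) for z in draws]


def fit_meanfield(grad_fn: Callable[[Sequence[float]], Vector], dim: int,
                  n_iter: int = 2000, k: int = 10, lr: float = 0.01,
                  seed: int = 0) -> Tuple[Vector, Vector]:
    """Maximize the ELBO of q = N(mu, diag(sigma^2)); returns (mu, sigma)."""
    rng = random.Random(seed)
    mu = [0.0] * dim
    log_sigma = [0.0] * dim
    for _ in range(n_iter):
        # The draws for this iteration depend on the current (mu, log_sigma) ...
        eps = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(k)]
        draws = [[m + math.exp(s) * e for m, s, e in zip(mu, log_sigma, ep)]
                 for ep in eps]
        # ... but within the iteration they are all known, so one batch suffices.
        grads = batch_grads(grad_fn, draws)
        for d in range(dim):
            g_mu = sum(g[d] for g in grads) / k
            # Reparameterization gradient w.r.t. log_sigma, plus the entropy term (+1)
            g_ls = (sum(g[d] * ep[d] for g, ep in zip(grads, eps))
                    * math.exp(log_sigma[d]) / k + 1.0)
            mu[d] += lr * g_mu
            log_sigma[d] += lr * g_ls
    return mu, [math.exp(s) for s in log_sigma]
```

With `fit_meanfield(grad_log_p, dim=2)` the fitted means end up near 3 and the standard deviations near 1.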


@Funko_Unko the samples within a single iteration can be evaluated all at once (so 10 samples in my original example). I guess we could hope that requesting 10 evaluations at once would provide a 10x speedup, which would leave a ~5-10x performance gap compared to pystan 2.x.
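This estimate can be checked with a simple cost model in which each call pays a fixed overhead and each draw a fixed compute cost. The model, and the assumption that pystan 3's extra time is entirely per-call overhead on top of pystan 2's compute cost, are simplifications for illustration; the per-call times are the `%time` measurements quoted earlier in the thread.

```python
def time_per_eval(overhead_s: float, compute_s: float, batch_size: int) -> float:
    """Cost per evaluation when batch_size draws share one call's overhead."""
    return overhead_s / batch_size + compute_s


pystan2 = 180e-6  # pystan 2.19 per-call time, treated as pure compute

# Hypothetical split: everything pystan 3 spends beyond pystan 2 is overhead.
pystan3_overhead = 16.6e-3 - pystan2

for batch in (1, 10):
    ratio = time_per_eval(pystan3_overhead, pystan2, batch) / pystan2
    print(f"batch of {batch:>2}: {ratio:.1f}x slower than pystan 2.19")
```

This prints about 92.2x for single draws and 10.1x for batches of 10. Starting instead from the ~60x seen in VIABEL, a batch of 10 leaves roughly a 7x gap, so the 5-10x range is consistent with both measurements.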

Hm. Yeah that is unfortunate.

Also, I just checked, and the cmdstan implementation actually has a larger overhead per call/batch than the current pystan implementation. I.e., it's much faster than pystan if you pass a large batch, which pystan (currently) has to process one draw at a time, but slower if you pass one draw after another. I wonder what exactly introduces the overhead.
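When one interface has the larger per-call overhead but the smaller per-draw cost, the break-even batch size follows from equating the two costs: o_a + n·c_a = o_b + n·c_b gives n* = (o_a - o_b) / (c_b - c_a). The helper below is a hypothetical sketch, and the numbers in the usage line are made up, since no overhead/per-draw breakdown was measured in this thread.

```python
def break_even_batch(overhead_a: float, per_draw_a: float,
                     overhead_b: float, per_draw_b: float) -> float:
    """Batch size at which interfaces A and B cost the same per call.

    A is assumed to have the larger per-call overhead and the smaller
    per-draw cost, so A wins for batches larger than the returned size.
    """
    if per_draw_b <= per_draw_a:
        raise ValueError("B must have the larger per-draw cost for a crossover")
    return (overhead_a - overhead_b) / (per_draw_b - per_draw_a)


# Hypothetical numbers (seconds), for illustration only:
#   A: a batch interface with 50 ms startup and 50 µs per draw
#   B: a one-draw-per-call interface, modeled as 0 overhead and 3 ms per draw
print(break_even_batch(overhead_a=50e-3, per_draw_a=50e-6,
                       overhead_b=0.0, per_draw_b=3e-3))
```

With these made-up values the batch interface only pays off beyond roughly 17 draws per call.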


Ok I had some time to test this.

For this application you should probably tap into httpstan directly.

In the examples, feel free to use different input values; I just used the first sampled draw. (Note that param_constrained must hold each parameter in its correct shape, i.e. an n-dimensional parameter must be passed as an n-dimensional object.)

import stan

# schools_code and schools_data hold the centered eight schools program and data
posterior = stan.build(schools_code, data=schools_data)
fit = posterior.sample(num_chains=4, num_samples=1000)

param_constrained = {key: fit[key][:, 0].tolist() for key in fit.param_names}
param_unconstrained = posterior.unconstrain_pars(param_constrained)
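As a cross-check on the shapes going into and out of `unconstrain_pars`, the transform for this model can be written out by hand. The sketch assumes the usual centered eight schools parameter block, `real mu; real<lower=0> tau; vector[J] theta;`, in that order. Stan maps a `<lower=0>` parameter to the unconstrained scale with `log`, leaves unbounded parameters unchanged, and concatenates everything in declaration order. `unconstrain_eight_schools` is a hypothetical helper, not part of pystan.

```python
import math
from typing import Dict, List, Sequence, Union

Param = Union[float, Sequence[float]]


def _as_scalar(x: Param) -> float:
    """Accept either 4.0 or [4.0] for a scalar parameter."""
    if isinstance(x, (list, tuple)):
        if len(x) != 1:
            raise ValueError("scalar parameter must have exactly one value")
        return float(x[0])
    return float(x)


def unconstrain_eight_schools(params: Dict[str, Param]) -> List[float]:
    """Hand-rolled unconstraining transform for the centered eight schools model.

    Assumes the parameter block declares, in order:
        real mu; real<lower=0> tau; vector[J] theta;
    """
    mu = _as_scalar(params["mu"])
    tau = _as_scalar(params["tau"])
    theta = [float(t) for t in params["theta"]]  # 1-d: stays a list of length J
    if tau <= 0:
        raise ValueError("tau must be positive")
    # Unbounded -> identity; lower bound 0 -> log; concatenated in declaration order
    return [mu, math.log(tau)] + theta


example = {"mu": 4.0, "tau": 1.0, "theta": [4.0] * 8}
print(unconstrain_eight_schools(example))  # 10 values: mu, log(tau), then theta
```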

pystan solution (call log_prob from pystan)

%timeit -n 1000 posterior.log_prob(param_unconstrained)
# 2.71 ms ± 38.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit -n 1000 posterior.grad_log_prob(param_unconstrained)
# 2.7 ms ± 46.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

httpstan solution

import httpstan
module = httpstan.models.import_services_extension_module(posterior.model_name)

%timeit -n 1000 module.log_prob(posterior.data, param_unconstrained, True)
# 38.5 µs ± 5.64 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit -n 1000 module.log_prob_grad(posterior.data, param_unconstrained, True)
# 43.3 µs ± 7.75 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
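The httpstan gradient call is roughly 60x faster than `posterior.grad_log_prob` (43.3 µs vs 2.7 ms), which is about the slowdown originally reported. For repeated calls from an optimizer or VI loop, it can be convenient to bind the data and the `adjust_transform` flag once. The sketch below relies only on the `log_prob_grad(data, params, adjust_transform)` call shown above; `make_grad_fn` is a hypothetical helper, and the stub module exists only so the snippet runs without a compiled model.

```python
import types
from typing import Any, Callable, Dict, Sequence


def make_grad_fn(module: Any, data: Dict[str, Any],
                 adjust_transform: bool = True) -> Callable[[Sequence[float]], Any]:
    """Return f(params) that calls module.log_prob_grad(data, params, adjust_transform).

    `module` is meant to be the object returned by
    httpstan.models.import_services_extension_module(...), as in the example above.
    """
    def grad_fn(params: Sequence[float]) -> Any:
        # params are passed through unchanged (e.g. a plain list of floats)
        return module.log_prob_grad(data, params, adjust_transform)
    return grad_fn


# Stub standing in for a compiled extension module (illustration only):
# gradient of a standard normal log density, ignoring the data.
stub = types.SimpleNamespace(
    log_prob_grad=lambda data, params, adjust: [-p for p in params]
)

grad = make_grad_fn(stub, data={})
print(grad([0.5, -1.0]))  # [-0.5, 1.0]
```

With a real model this would be `grad = make_grad_fn(module, posterior.data)` followed by `grad(param_unconstrained)`.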