Fit.extract() takes a long time

I am using pystan to produce samples from the posterior predictive distribution on new data, given a set of parameter samples that were produced by a fit to some training data.

The idea is to store the parameter samples from MCMC in a python object and then pass them to a different piece of stan code in the data block, alongside the new data. The samples from the posterior predictive distribution are then created in the generated quantities block. This works fine, except fit.extract() takes a very long time for larger samples.
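The packing step can be sketched in plain Python. The sketch below is illustrative only (the names `mu_draws`, `ppc_data`, `n_draws`, and `ppc_model` are hypothetical, not from the original post): draws extracted from the first fit are placed in the data dict for a second Stan program whose data block declares them.

```python
import numpy as np

# Stand-in for draws extracted from the first fit's posterior
# (in practice these would come from something like fit.extract()["mu"]).
mu_draws = np.random.default_rng(0).normal(size=1000)

# Pack the draws, alongside the new data, into the data dict for the
# second Stan program, whose data block would declare n_draws, mu,
# n_new, and y_new.
y_new = np.array([0.3, -1.2, 0.7])
ppc_data = {
    "n_draws": len(mu_draws),
    "mu": mu_draws,
    "n_new": len(y_new),
    "y_new": y_new,
}

# ppc_model.sampling(data=ppc_data, algorithm="Fixed_param", ...) would
# then produce posterior predictive draws in generated quantities.
```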

As a simple example, take the following stan code

data {
    int<lower=0> n;
}
generated quantities {
    vector[n] theta;
    for (j in 1:n) theta[j] = normal_rng(0., 1.);
}

I run stan in Fixed_param mode and then do fit.extract(). For reference, I attach a graph of the time it takes to do the sampling in stan, and how long fit.extract() takes, as a function of the length of the generated vector (i.e. n in the above stan code). When the vector is ~ 100,000 elements long, the sampling still only takes 0.2s or so, but extract takes a few hundred seconds.
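Timings like these can be collected with a small harness around `time.perf_counter`. This is a generic sketch, not the script behind the attached graph; the commented calls show where the model from the post would slot in, and a cheap stand-in keeps the sketch self-contained:

```python
import time

def time_call(fn, *args, **kwargs):
    """Return (result, elapsed_seconds) for a single call to fn."""
    t0 = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - t0

# In the setting of the post one would time, for each n:
#   _, t_sample  = time_call(model.sampling, data={'n': n}, chains=1,
#                            iter=1, algorithm="Fixed_param")
#   _, t_extract = time_call(fit.extract)
# Cheap stand-in so the sketch runs on its own:
total, elapsed = time_call(sum, range(100000))
```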

I wanted to post in this forum before I dug into the problem further - am I doing something wrong / is this behaviour expected?

pystan_timings.pdf (13.2 KB)

This is a good, clean stress test. Thanks!

100,000 params times, say, 4,000 iterations takes roughly 3 GiB of
memory. Moving that much memory around (between threads or processes) is
going to take some time. That said, I think we can make it much faster.
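The back-of-the-envelope figure follows from 100,000 float64 values per draw times 4,000 draws at 8 bytes each:

```python
n_params = 100_000
n_draws = 4_000
bytes_total = n_params * n_draws * 8   # 8 bytes per float64 draw
gib = bytes_total / 2**30
print(f"{gib:.1f} GiB")                # prints 3.0 GiB
```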


I’m not quite sure I have explained what I am experiencing clearly enough. To be more concrete, if I do:

import numpy as np
import pystan

stan_code = """
data {
    int<lower=0> n;
}
generated quantities {
    vector[n] theta;
    for (j in 1:n) theta[j] = normal_rng(0., 1.);
}
"""

model = pystan.StanModel(model_code=stan_code)
fit = model.sampling(data={'n': 100000}, chains=1, iter=1, algorithm="Fixed_param")

# very slow (~ 100s)
theta = fit["theta"].ravel()

# much faster (~ 10ms)
theta1 = np.array([float(i) for i in fit.sim["samples"][0].chains.values()][:-1])

theta and theta1 are identical (admittedly the second approach to extracting the vector is uglier to read), and yet there seems to be a factor of 10,000 difference in the speed. What causes the slowdown? Is the second method unsafe in some way?
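The fast path simply walks the per-parameter sample dict directly. The pattern can be illustrated on a stand-in `OrderedDict` shaped like `fit.sim["samples"][0].chains` (flat parameter names mapping to draws, with `lp__` as the final entry); the values here are plain floats for simplicity, whereas the real object holds per-draw arrays:

```python
from collections import OrderedDict
import numpy as np

# Stand-in for fit.sim["samples"][0].chains with iter=1: one entry per
# flat parameter name, with the log posterior lp__ appended last.
chains = OrderedDict(
    [(f"theta[{j}]", float(j)) for j in range(1, 6)] + [("lp__", 0.0)]
)

# Same one-liner as in the post: take each parameter's value and drop
# the trailing lp__ entry.
theta1 = np.array([float(v) for v in chains.values()][:-1])
print(theta1)  # [1. 2. 3. 4. 5.]
```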

I just checked and it looks like in the first case extract() is called
behind the scenes. This will be fixed in PyStan 3 when the
permuted=True default behavior vanishes from PyStan and RStan.

Give fit.extract('theta', permuted=False) a try.

fit.extract("theta", permuted=False) takes about the same time as fit.extract() unfortunately!

That is strange. It shouldn’t take that long. Want to open an issue on GitHub?

I wish this were easier to debug/profile/optimize. Unfortunately all
this code is buried inside a Cython file template, stanfit4model.pyx.

Yes, I had a look at that Cython code to figure out the fast hack to get the data out. I had a quick try at profiling but didn’t see anything immediate. If I get some time I can look into it a bit more. I can write a GitHub issue, sure!

Added an issue here.

Hi, I answered in the github already, but wanted to give you some tips here on the forum.

If you run your code in Jupyter Lab / Notebook, you can try IPython magics:

  • %%time to time a cell (not super accurate, but close enough)

  • %time to time a line (not super accurate, but close enough)

    %time fit = sm.sampling()

  • line_profiler for a cell

So running %%prun theta = fit.extract('theta', permuted=False) will show you the problematic parts. (Edit: %%prun needs to be on the first line alone.)

         70187 function calls in 5.080 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10001    4.863    0.000    4.863    0.000 {built-in method pystan._misc.get_samples}
        1    0.165    0.165    5.080    5.080 {method 'extract' of 'stanfit4anon_model_128b9b5fb61235c403e1332c7718cebe_3286899518068726564.StanFit4Model' objects}
    10001    0.015    0.000    4.878    0.000
    10001    0.013    0.000    0.023    0.000

Hi, I wanted to make one correction to the line_profiler step.

prun is more of a function profiler; to use a line profiler, do the following:

python -m pip install line_profiler

In notebook

%load_ext line_profiler

And then for a function

%lprun -u 1e-6 -f myfunc -f myinnerfunc myfunc(...)

Where myfunc is the main function and myinnerfunc is a function called inside myfunc. Any number of -f func arguments is fine.

The -u 1e-6 flag forces timings to be shown in microseconds.
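Outside a notebook, the same function-level view that %%prun gives can be had from the stdlib cProfile module. A minimal sketch, using a trivial stand-in function (`build` is hypothetical, not from the thread):

```python
import cProfile
import io
import pstats

def build(n):
    """Trivial stand-in for the expensive call being profiled."""
    return [i * i for i in range(n)]

profiler = cProfile.Profile()
profiler.enable()
build(100_000)
profiler.disable()

# Sort by internal time, like the prun output shown earlier in the thread,
# and keep the top 5 entries.
out = io.StringIO()
pstats.Stats(profiler, stream=out).sort_stats("tottime").print_stats(5)
report = out.getvalue()
print(report)
```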
