I am using pystan to produce samples from the posterior predictive distribution on new data, given a set of parameter samples that were produced by a fit to some training data.
The idea is to store the parameter samples from MCMC in a Python object and then pass them to a different piece of Stan code in the data block, alongside the new data. The samples from the posterior predictive distribution are then created in the generated quantities block. This works fine, except that `fit.extract()` takes a very long time for larger samples.
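Concretely, the workflow I mean looks something like this (a rough sketch with a toy model and invented names, not my actual code):

```python
import numpy as np
import pystan

# Sketch only: toy models and invented variable names, not the real code.
train_code = """data {
  int<lower=0> N;
  vector[N] y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
model {
  y ~ normal(mu, sigma);
}"""

pp_code = """data {
  int<lower=0> S;    // number of posterior draws
  vector[S] mu;      // parameter draws passed in from the first fit
  vector[S] sigma;
}
generated quantities {
  vector[S] y_new;
  for (s in 1:S) y_new[s] = normal_rng(mu[s], sigma[s]);
}"""

y = np.random.normal(1.0, 2.0, size=50)
fit = pystan.StanModel(model_code=train_code).sampling(data={'N': 50, 'y': y})

draws = fit.extract()  # the slow step this thread is about

pp_fit = pystan.StanModel(model_code=pp_code).sampling(
    data={'S': len(draws['mu']), 'mu': draws['mu'], 'sigma': draws['sigma']},
    chains=1, iter=1, algorithm="Fixed_param")
```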
As a simple example, take the following Stan code:
```stan
data {
  int<lower=0> n;
}
generated quantities {
  vector[n] theta;
  for (j in 1:n) theta[j] = normal_rng(0., 1.);
}
```
I run Stan in `Fixed_param` mode and then do `fit.extract()`. For reference, I attach a graph of the time it takes to do the sampling in Stan, and how long `fit.extract()` takes, as a function of the length of the generated vector (i.e. `n` in the above Stan code). When the vector is ~100,000 elements long, the sampling still only takes 0.2 s or so, but the extract takes a few hundred seconds.
I wanted to post in this forum before digging into the problem further: am I doing something wrong, or is this behaviour expected?
pystan_timings.pdf (13.2 KB)
This is a good, clean stress test. Thanks!
100,000 parameters times, say, 4,000 iterations takes roughly 3 GiB of memory. Moving that much memory around (between threads or processes) is going to take some time. That said, I think we can make it much faster.
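The arithmetic behind that estimate, assuming 8-byte doubles, is just:

```python
# 100,000 parameters x 4,000 draws x 8 bytes per float64
n_params, n_draws, bytes_per_double = 100_000, 4_000, 8
print(n_params * n_draws * bytes_per_double / 2**30)  # ~2.98 GiB
```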
I’m not quite sure I have explained what I am experiencing clearly enough. To be more concrete, if I do:
```python
import numpy as np
import pystan

stan_code = """data {
  int<lower=0> n;
}
generated quantities {
  vector[n] theta;
  for (j in 1:n) theta[j] = normal_rng(0., 1.);
}"""

model = pystan.StanModel(model_code=stan_code)
fit = model.sampling(data={'n': 100000}, chains=1, iter=1, algorithm="Fixed_param")

# very slow (~100 s)
theta = fit["theta"].ravel()

# much faster (~10 ms)
theta1 = np.array([float(i) for i in fit.sim["samples"][0].chains.values()][:-1])
```
`theta` and `theta1` are identical (admittedly the second approach to extracting the vector is uglier to read), and yet there seems to be a factor of ~10,000 difference in speed. What causes the slowdown? Is the second method unsafe in some way?
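(By "identical" I mean that a quick check along these lines passes; a sketch, which only holds for a single chain and a single iteration, where both orderings coincide:)

```python
# theta comes from fit["theta"], theta1 from the raw per-chain dict above
assert np.array_equal(theta, theta1)
```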
I just checked, and it looks like in the first case `extract()` is called behind the scenes. This will be fixed in PyStan 3, when the `permuted=True` default behavior vanishes from PyStan and RStan.
Give `fit.extract('theta', permuted=False)` a try.
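If it helps: my understanding (from the PyStan 2 docs, so treat this as an assumption) is that `permuted=False` skips the per-parameter reordering and returns a plain array instead of a dict, along these lines:

```python
# assumption: permuted=False returns an ndarray of shape
# (iterations, chains, flat parameters + lp__) instead of a dict
arr = fit.extract(permuted=False)
theta_draws = arr[:, 0, :-1]  # chain 0, dropping the trailing lp__ column
```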
fit.extract("theta", permuted=False)
takes about the same time as fit.extract()
unfortunately!
That is strange. It shouldn't take that long. Want to open an issue on GitHub?
I wish this were easier to debug/profile/optimize. Unfortunately, all this code is buried inside a Cython file template, `stanfit4model.pyx`.
Yes, I had a look at that Cython code to figure out the fast hack to get the data out. I had a quick try at profiling but didn't see anything immediate. If I get some time I can look into it a bit more. I can write a GitHub issue, sure!
Hi, I answered on GitHub already, but wanted to give you some tips here on the forum.
If you run your code in Jupyter Lab / Notebook, you can try IPython magics:

- Time a cell (not super accurate, but close enough):

```python
%%time
# rest_of_the_code
```

- Time a line (not super accurate, but close enough):

```python
%time fit = sm.sampling()
```

- line_profiler for a cell: `%%prun`

So running `%%prun theta = fit.extract('theta', permuted=False)` will show you the problematic parts. (edit: `%%prun` needs to be on the first line alone; see the cell sketched below.)
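With that edit, the cell looks like:

```python
%%prun
theta = fit.extract('theta', permuted=False)
```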
```
         70187 function calls in 5.080 seconds

   Ordered by: internal time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    10001    4.863    0.000    4.863    0.000 {built-in method pystan._misc.get_samples}
        1    0.165    0.165    5.080    5.080 {method 'extract' of 'stanfit4anon_model_128b9b5fb61235c403e1332c7718cebe_3286899518068726564.StanFit4Model' objects}
    10001    0.015    0.000    4.878    0.000 misc.py:911(_get_samples)
    10001    0.013    0.000    0.023    0.000 shape_base.py:115(atleast_3d)
   ...
```
Hi, I wanted to make one correction to the line_profiler step.
`prun` is more like a function profiler; to use a line profiler, do the following:

```
python -m pip install line_profiler
```

In the notebook:

```python
%load_ext line_profiler
```

And then, for a function:

```python
%lprun -u 1e-6 -f myfunc -f myinnerfunc myfunc(...)
```

where `myfunc` is the main function and `myinnerfunc` is a function called inside `myfunc`. Any number of `-f func` flags is fine. The `-u 1e-6` forces the timings to be shown in microseconds.
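A self-contained toy version of that pattern (the function bodies here are made up purely for illustration):

```python
%load_ext line_profiler

def myinnerfunc(n):
    # deliberately cheap work so the per-line timings are easy to read
    return sum(range(n))

def myfunc(m):
    # calls the inner function m times
    return [myinnerfunc(i) for i in range(m)]

# line-by-line timings for both functions, reported in microseconds
%lprun -u 1e-6 -f myfunc -f myinnerfunc myfunc(500)
```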