Hi all,
I am trying to sample from the following log posterior using PyStan:
\log \mathrm{Posterior}(a) = \log \mathrm{DirichletPrior}(a) + \sum \log(M \cdot a) - \sum \log(C \cdot a)
where a = (a_1, a_2, \ldots, a_n) is a simplex vector: a_i \geq 0 and \sum_i a_i = 1. The problem arises when the matrix M is large, e.g. 100 columns and 1,000,000 rows; the columns correspond to the model parameters and the rows to the data (observations).
The Stan code I copied below works fine for small numbers of data points and parameters. However, at the large sizes mentioned above, the process is killed because it exceeds the available RAM.
I do not have this problem when, for instance, I compute the maximum of the posterior with the SciPy package (see the sketch after the Stan code below).
Do you know how to solve this issue?
I hope you can help me find a way to use Stan in this large-data case. Thank you for any attention you pay to this topic.
functions {
  // Log likelihood: sum over rows of log(Mp * v) minus log(Mc * v).
  real P(int N_j, vector v, matrix Mp, matrix Mc) {
    return sum(log(Mp * v)) - sum(log(Mc * v));
  }
}
data {
  int<lower=0> Nj;      // number of data points
  int<lower=0> Ni;      // number of isochrones
  matrix[Nj, Ni] Pij;   // probability matrix
  matrix[Nj, Ni] Cij;   // normalization matrix
}
parameters {
  simplex[Ni] a;        // a_i >= 0, sum(a) = 1
}
model {
  target += dirichlet_lpdf(a | rep_vector(1., Ni));  // flat Dirichlet prior
  target += P(Nj, a, Pij, Cij);                      // log likelihood
}
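For reference, here is a minimal sketch of the SciPy approach I mean (the matrices here are random placeholders standing in for my real Pij and Cij):

import numpy as np
from scipy.optimize import minimize

# Placeholder data; in practice M and C are the large Pij and Cij matrices.
rng = np.random.default_rng(0)
Nj, Ni = 1000, 10
M = rng.random((Nj, Ni))
C = rng.random((Nj, Ni)) + 1.0

def neg_log_posterior(a):
    # The flat Dirichlet prior contributes only a constant, so it is omitted.
    return -(np.sum(np.log(M @ a)) - np.sum(np.log(C @ a)))

# Maximize over the simplex: a_i >= 0 and sum(a) = 1.
a0 = np.full(Ni, 1.0 / Ni)
result = minimize(
    neg_log_posterior,
    a0,
    method="SLSQP",
    bounds=[(0.0, 1.0)] * Ni,
    constraints=[{"type": "eq", "fun": lambda a: np.sum(a) - 1.0}],
)
print(result.x)  # MAP estimate of a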
Hi @Andres3146, and welcome to the Stan forums!
Assuming your P function produces a proper density (I couldn't understand what it's doing as written), the problem is most likely due to PyStan buffering all the draws.
But if a is of dimension 1 million, you're going to have problems with the way we define the simplex constraint; it doesn't scale well to this size. Plus, it's going to generate around 10MB of data per draw, so if you take the default 4K draws from 4 chains, you'll have roughly 40GB of data in the draws. How many draws are you taking, and how much memory does Python have? Stan wasn't really designed to scale to that level, and a lot of our processes, like posterior analyses, are going to choke (or at least take a very long time) processing this.
To speed things up, you can remove the statement target += dirichlet_lpdf(a | rep_vector(1., Ni)); without affecting sampling or optimization (in Stan or elsewhere), because a Dirichlet with all concentration parameters equal to 1 is uniform over simplexes, which is the default you get just from declaring a to be a simplex.
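Concretely, with unit concentration parameters the Dirichlet density reduces to a constant on the simplex, so dropping it only shifts the log posterior by a constant:

\mathrm{Dirichlet}(a \mid \mathbf{1}_n) = \frac{\Gamma\left(\sum_{i=1}^{n} 1\right)}{\prod_{i=1}^{n} \Gamma(1)} \prod_{i=1}^{n} a_i^{1-1} = \Gamma(n) = (n-1)!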
As an aside, CmdStanPy should be more scalable, as it doesn't run Stan in the same process as Python and it streams the draws to disk as they are taken. But I'm afraid that still won't help you read them back in if there are 40GB of data.
Hello, and thank you very much for the detailed response to my question! I apologize for my late reply.
To clarify, I’m not a mathematician, but I’m working under the assumption that the function P is a well-defined probability density. My aim here is to analyze some astronomical data statistically, and I realize now that I only shared the final equation without much context. If it would help, I’d be happy to provide additional details.
In response to your points, the dimension of the "a" parameter in my case ranges between 10 and 100, not as high as a million. I'm currently using 5,000 draws across 4 chains, and I'll explore whether expanding the available memory (I can increase both the RAM and the disk space) might help address the issue.
I appreciate your suggestion about removing the Dirichlet prior, and I’ll definitely try running a computation without it to see if it improves performance.
As for the recommendation to use CmdStanPy, I have only worked with Stan from within Python and am not yet familiar with how to run Stan separately. If there are any resources or documentation you could suggest for this, it would be greatly appreciated.
Thanks again for your support and advice!
CmdStanPy is another Python interface to Stan, so you can still run everything from within Python; the difference is that it runs the sampler in a separate process and streams the draws to disk.
Here is a case study introducing CmdStanPy (also plotnine, which is an awesome plotting library): Multilevel regression modeling with CmdStanPy and plotnine
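To give you a flavor, here is a minimal sketch (the file names simplex_model.stan and data.json are hypothetical; save the Stan program and your data under whatever names you like):

from cmdstanpy import CmdStanModel

# Compile the Stan program (hypothetical file name).
model = CmdStanModel(stan_file="simplex_model.stan")

# Sample; draws are streamed to CSV files on disk rather than
# buffered in memory, which is what helps at scale.
fit = model.sample(
    data="data.json",  # hypothetical data file with Nj, Ni, Pij, Cij
    chains=4,
    iter_sampling=1000,
)
print(fit.summary())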