I have a possibly naive question for the forum. I have been trying to implement a Generalized Dirichlet distribution (due to Connor & Mosimann, 1969) as a prior for a simplex-valued parameter, say \pi, with \sum_{j=1}^{K} \pi_j = 1. The constructive definition of the generalized Dirichlet distribution is below.
Put simply, it allows a general Beta(a,b) distribution for each of the “broken sticks” instead of the Beta(\alpha/K, \alpha (1-j/K)) distribution for the j-th stick that leads to the symmetric Dir(\alpha/K, \ldots, \alpha/K) distribution.
My question is: how do I specify the Generalized Dirichlet as a prior in Stan? The snippet below is my initial attempt, following the definitions but keeping the Beta parameters matched to the symmetric Dirichlet. Is this the correct way? Should I also add the Jacobian somewhere?
data {
  int<lower=1> n; // number of observations
  int<lower=1> K; // number of categories
  int<lower=0,upper=n> y[K]; // data
  real<lower=0> alpha; // Dirichlet shape
}
parameters {
  simplex[K] pi; // category probabilities
}
transformed parameters {
  // stick-breaking construction
  real<lower=0,upper=1> Z[K-1];
  Z[1] = pi[1];
  for (j in 2:(K-1)) {
    Z[j] = pi[j] / (1 - sum(pi[1:(j-1)]));
  }
}
model {
  for (j in 1:(K-1)) {
    // j * 1.0 / K forces real division; j / K on two ints would truncate to 0
    Z[j] ~ beta(alpha/K, alpha * (1 - j * 1.0 / K)); // Dirichlet
  }
  y ~ multinomial(pi);
}
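On the Jacobian question: the prior is placed on Z, which is a nonlinear function of the declared parameter \pi, so a change-of-variables term would normally be needed. The following is a sketch of that term, derived only from the stick-breaking map in the snippet above; it is worth re-deriving before relying on it.

```latex
\begin{aligned}
\pi_j &= Z_j \prod_{i<j} (1 - Z_i), \qquad j = 1, \ldots, K-1, \\
\frac{\partial \pi_j}{\partial Z_j} &= \prod_{i<j} (1 - Z_i) = 1 - \sum_{i<j} \pi_i,
\qquad \frac{\partial \pi_j}{\partial Z_i} = 0 \ \text{ for } i > j, \\
\log \left| \det \frac{\partial Z}{\partial \pi} \right|
&= - \sum_{j=2}^{K-1} \log \Bigl( 1 - \sum_{i<j} \pi_i \Bigr).
\end{aligned}
```

The matrix \partial \pi / \partial Z is lower triangular, so its determinant is the product of the diagonal entries. In a Stan model this log-determinant would be added to the target alongside the Beta statements on Z.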
Thanks a lot, @bgoodri!
This is a more direct and straightforward approach … I am getting a few error messages about negative values in \pi with this implementation, but that could be something I’ve missed and should be easily fixable.
Alternatively, you could work on the log scale by using an ordered sequence and applying the inv_logit link to replace the simplex, exploiting inv_logit(x) = 1 - inv_logit(-x). See how a categorical logit model is implemented “old school”. Just my 2 cents.
This approach is quadratic in K, so if K is tiny, no big deal. Otherwise, you want to keep a running sum in this loop to avoid re-adding the prefixes of pi each iteration of the loop.
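The running-sum idea can be sketched as follows. This is a minimal, unverified sketch rather than a definitive implementation: it uses current Stan array syntax (`array[K] int y` where the thread writes `int y[K]`), and the names `remaining` and `log_jac` are illustrative. It also accumulates the change-of-variables term for the map from pi to Z in the same pass, assuming the prior is meant to be placed on Z.

```stan
data {
  int<lower=1> n;                      // number of observations
  int<lower=2> K;                      // number of categories
  array[K] int<lower=0, upper=n> y;    // counts per category
  real<lower=0> alpha;                 // Dirichlet shape
}
parameters {
  simplex[K] pi;                       // category probabilities
}
transformed parameters {
  vector<lower=0, upper=1>[K - 1] Z;   // stick-breaking fractions
  real log_jac = 0;                    // log |det dZ/dpi| (illustrative name)
  {
    real remaining = 1;                // equals 1 - sum(pi[1:(j-1)]) at step j
    for (j in 1:(K - 1)) {
      Z[j] = pi[j] / remaining;
      log_jac -= log(remaining);
      remaining -= pi[j];              // O(1) update instead of re-summing the prefix
    }
  }
}
model {
  for (j in 1:(K - 1)) {
    // j * 1.0 / K forces real division
    Z[j] ~ beta(alpha / K, alpha * (1 - j * 1.0 / K));
  }
  target += log_jac;
  y ~ multinomial(pi);
}
```

Each iteration now does constant work, so the loop is linear in K. Subtracting many small pi[j] from `remaining` can still lose precision when K is large and the tail probabilities are tiny.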
That’s a really helpful tip. Thank you. My K \approx 700, so computation is definitely a big concern … but numerical stability seems to be becoming more of an issue with this stick-breaking construction, especially with highly unbalanced and long-tailed count distributions.
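One option that may help with the stability concern is to declare the stick-breaking fractions as the parameters and build log probabilities directly, so that very small category probabilities never have to be represented on the natural scale. The sketch below is an assumption-laden illustration, not a method confirmed in this thread: the Beta shapes `a` and `b` are hypothetical data inputs, and the likelihood is written as the multinomial log kernel (dropping the constant multinomial coefficient).

```stan
data {
  int<lower=2> K;                      // number of categories
  array[K] int<lower=0> y;             // counts per category
  vector<lower=0>[K - 1] a;            // first Beta shapes (hypothetical input)
  vector<lower=0>[K - 1] b;            // second Beta shapes (hypothetical input)
}
parameters {
  vector<lower=0, upper=1>[K - 1] Z;   // stick-breaking fractions
}
transformed parameters {
  vector[K] log_pi;                    // log category probabilities
  {
    real log_remaining = 0;            // log of the stick length still unbroken
    for (j in 1:(K - 1)) {
      log_pi[j] = log(Z[j]) + log_remaining;
      log_remaining += log1m(Z[j]);
    }
    log_pi[K] = log_remaining;         // last category takes what is left
  }
}
model {
  Z ~ beta(a, b);                      // generalized Dirichlet prior implied on pi
  // multinomial log likelihood up to an additive constant that depends only on y
  target += dot_product(to_vector(y), log_pi);
}
```

Because the Beta priors act on declared parameters here, no extra Jacobian term is needed beyond what Stan applies for the bounds on Z.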
Hello. I implemented a compound Generalized Dirichlet Multinomial and the corresponding _rng function.
Question: When I sample from the model, and use that data as training data, in an attempt to infer the parameters, I cannot seem to recover them. Moreover, if I sample from the distribution using the returned parameters, the results are not similar to the training data.
Perhaps there is a bug in the model? I cannot seem to find literature to compute the _rng so I used the generative process as above. Maybe that is the issue?
Any assistance would be greatly appreciated!
functions {
  real dirichlet_multinomial_lpmf(int[] y, vector alpha) {
    real alpha_plus = sum(alpha);
    return lgamma(alpha_plus) + sum(lgamma(alpha + to_vector(y)))
           - lgamma(alpha_plus+sum(y)) - sum(lgamma(alpha));
  }
  int[] dirichlet_multinomial_rng(vector alpha, int N) {
    return multinomial_rng(dirichlet_rng(alpha), N);
  }
  real generalized_dirichlet_multinomial_lpmf(int[] y, vector alpha, vector beta_) {
    // y is num_categories dimensional, alpha and beta are num_categories-1 dimensional
    int D = dims(alpha)[1]; // D = num_categories-1
    vector[D+1] x = cumulative_sum(to_vector(y));
    vector[D] z;
    vector[D] alpha_prime = alpha + to_vector(y)[1:D];
    vector[D] beta_prime;
    z[1] = x[D+1];
    z[2:D] = rep_vector(x[D+1],D-1) - x[1:D-1];
    beta_prime = beta_ + z;
    return (lgamma(x[D]+1) - sum(lgamma(to_vector(y)+rep_vector(1,D+1)))
            + sum(lgamma(alpha_prime)) - sum(lgamma(alpha))
            + sum(lgamma(beta_prime)) - sum(lgamma(beta_))
            - sum(lgamma(alpha_prime + beta_prime)) + sum(lgamma(alpha+beta_))
    );
  }
  int[] generalized_dirichlet_multinomial_rng(vector alpha, vector beta_, int N) {
    int D = dims(alpha)[1];
    vector[2] tmp;
    int out[D+1];
    out[1] = N;
    for (n in 1:D) {
      tmp[1] = alpha[n];
      tmp[2] = beta_[n];
      out[n:n+1] = dirichlet_multinomial_rng(tmp,out[n]);
    }
    return out;
  }
}
data {
  int<lower=1> num_obs; // number of observations
  int<lower=1> num_test_obs; // number of test observations
  int<lower=1> num_categories; // number of categories
  int<lower=0> obs[num_obs,num_categories]; // data
  int<lower=0, upper=1> run_estimation;
}
parameters {
  vector<lower=0>[num_categories-1] alpha_std;
  vector<lower=0>[num_categories-1] beta_std;
}
transformed parameters {
  vector<lower=0>[num_categories-1] alpha;
  vector<lower=0>[num_categories-1] beta_;
  alpha = exp(1.0 + alpha_std);
  beta_ = exp(1.0 + beta_std);
}
model {
  alpha_std ~ normal(0,1);
  beta_std ~ normal(0,1);
  if (run_estimation==1) {
    for (n in 1:num_obs) {
      obs[n] ~ generalized_dirichlet_multinomial(alpha,beta_);
    }
  }
}
generated quantities {
  int y_test[num_test_obs,num_categories]; // test data
  if (run_estimation==0) {
    for (n in 1:num_test_obs) {
      y_test[n] = generalized_dirichlet_multinomial_rng(alpha,beta_,1000);
    }
  }
}
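One way to cross-check a hand-written lpmf for the compound distribution is to write it as a chain of beta-binomials, mirroring the generative process used in the _rng above: category j takes a beta-binomial share of whatever count remains. This is a minimal sketch assuming Stan's built-in `beta_binomial_lpmf` and current array syntax; the function name `gdm_sequential_lpmf` is hypothetical.

```stan
functions {
  // Generalized Dirichlet multinomial log pmf as a sum of beta-binomial terms.
  // y has D + 1 entries; alpha and beta_ have D entries.
  real gdm_sequential_lpmf(array[] int y, vector alpha, vector beta_) {
    int D = rows(alpha);
    int remaining = sum(y);            // count not yet assigned to a category
    real lp = 0;
    for (j in 1:D) {
      lp += beta_binomial_lpmf(y[j] | remaining, alpha[j], beta_[j]);
      remaining -= y[j];
    }
    return lp;                         // y[D + 1] is determined by the remainder
  }
}
```

The binomial coefficients of the chain multiply out to the multinomial coefficient, so this should agree with a closed-form lgamma expression on the same y, alpha and beta_. A disagreement would point to an indexing error in one of the two.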
real generalized_dirichlet_multinomial_lpmf(int[] y, vector alpha, vector beta_) {
  // y is num_categories dimensional, alpha and beta are num_categories-1 dimensional
  int D = dims(alpha)[1]; // D = num_categories-1
  vector[D+1] x = cumulative_sum(to_vector(y)); // x[j] = y[1] + ... + y[j]; x[D+1] is the total count
  vector[D] z = rep_vector(x[D+1],D) - x[1:D]; // z[j] = total count in the categories after j
  vector[D] alpha_prime = alpha + to_vector(y)[1:D];
  vector[D] beta_prime;
  beta_prime = beta_ + z;
  // multinomial coefficient, then the ratio of Beta functions for each of the D sticks
  return (lgamma(x[D+1]+1) - sum(lgamma(to_vector(y)+rep_vector(1,D+1)))
          + sum(lgamma(alpha_prime)) - sum(lgamma(alpha))
          + sum(lgamma(beta_prime)) - sum(lgamma(beta_))
          - sum(lgamma(alpha_prime + beta_prime)) + sum(lgamma(alpha+beta_))
  );
}
Much better results. Still, I would welcome any resources on how to efficiently generate samples from the generalized Dirichlet multinomial (I copied MATLAB code for the CRAN rgdirmn.R function)…
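On generating samples: the sequential construction used earlier in the thread can also be written with one beta-binomial draw per category, which is linear in the number of categories. This is a minimal sketch assuming Stan's built-in `beta_binomial_rng` and current array syntax; the name `gdm_rng` is hypothetical, and it is not taken from the rgdirmn.R code mentioned above.

```stan
functions {
  // Draw one count vector of length D + 1 with total N.
  // alpha and beta_ hold the D pairs of Beta shapes.
  array[] int gdm_rng(vector alpha, vector beta_, int N) {
    int D = rows(alpha);
    array[D + 1] int out;
    int remaining = N;                 // count still to be distributed
    for (j in 1:D) {
      out[j] = beta_binomial_rng(remaining, alpha[j], beta_[j]);
      remaining -= out[j];
    }
    out[D + 1] = remaining;            // last category takes the remainder
    return out;
  }
}
```

As with any `_rng` function, this can only be called from the generated quantities or transformed data blocks, for example `y_test[n] = gdm_rng(alpha, beta_, 1000);`.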
The Dirichlet (and Generalized Dirichlet) can take any parameter values as input, as long as they are positive. From your line simplex ~ generalized_dirichlet_multinomial(alphas, betas),
I am assuming you want the density of the Generalized Dirichlet itself, and not the compound distribution? If that is the case, you would have to rewrite the _lpmf function.
@stemangiola Sorry this is not an answer by any means, and excuse me if it’s a silly question, but do you have data as proportions or counts? Also, I think the GD offers a more general dependence structure but can be difficult when dealing with too many categories.
Also, what I originally wanted to do (but didn’t finish) was to look at variable selection for the integrated Dirichlet Multinomial model for microbiome data, something like this: https://github.com/duncanwadsworth/dmbvs (but without the spike-and-slab), and thought it would be cool to explore Generalized Dirichlet Multinomial as well. We have count data and the integrated models are more flexible for handling overdispersion.
From deconvolution I infer proportions from mixed counts of different cell types (similar to what you might have with metagenomics, I guess). So my proportions are parameters themselves, and the Dirichlet regression is hierarchical.
I have at most 20 or so categories, so I don’t need to do variable selection, just a fairly robust generalised Dirichlet regression framework that I can plug into my Stan model.
In case this is still relevant for people, here is an implementation of the two methods described on the Wiki page, yielding the same result. Feel free to optimize the code to your liking and implement the vectorized version for the tilde notation.
functions {
  real gendir_wong_s_lpdf(vector x, vector a, vector b) {
    // x is a simplex with k+1 entries
    // a, b are the length-k parameter vectors of the CM distribution
    int k = rows(x) - 1;
    real res = 0; // log density, accumulated on the log scale to avoid underflow
    for (i in 1:(k-1)) {
      res += (a[i]-1)*log(x[i]) + (b[i]-a[i+1]-b[i+1])*log1m(sum(x[1:i])) - lbeta(a[i], b[i]);
    }
    res += (a[k]-1)*log(x[k]) + (b[k]-1)*log1m(sum(x[1:k])) - lbeta(a[k], b[k]);
    return res;
  }
  // RNG function for Wong method
  vector gendir_wong_s_rng(vector a, vector b) {
    int k = rows(a);
    vector[k+1] x;
    real s = beta_rng(a[1], b[1]); // running total of the mass already assigned
    x[1] = s;
    for (j in 2:k) {
      x[j] = beta_rng(a[j], b[j]) * (1 - s);
      s += x[j];
    }
    x[k+1] = 1 - sum(x[1:k]);
    return x;
  }
  // CM distribution (classical)
  real gendir_s_lpdf(vector x, vector a, vector b) {
    // x is a simplex with k entries
    // a, b are the length-(k-1) parameter vectors of the CM distribution
    int k = rows(x);
    real res = 0; // log density, accumulated on the log scale to avoid underflow
    // sum(x[1:k]) is 1 on the simplex, so the middle term below is 0
    res += (a[1]-1)*log(x[1]) + (b[1]-(a[1]+b[1]))*log(sum(x[1:k])) - lbeta(a[1], b[1]);
    for (i in 2:(k-1)) {
      res += (a[i]-1)*log(x[i]) + (b[i-1]-(a[i]+b[i]))*log(sum(x[i:k])) - lbeta(a[i], b[i]);
    }
    res += (b[k-1]-1)*log(x[k]);
    return res;
  }
} // end of functions block