# Generalized Dirichlet Distribution as a prior

Hi Everyone,

I have a possibly naive question for the forum. I have been trying to implement a Generalized Dirichlet distribution (due to Connor & Mosimann, 1969) as a prior for a simplex-valued parameter, say \pi, with \sum_{j=1}^{K} \pi_j = 1. The constructive definition of the generalized Dirichlet distribution is below.
Put simply, it allows a general Beta(a,b) distribution for each of the “broken sticks” instead of the Beta(\alpha/K, \alpha (1-k/K)) distribution for the k-th stick, which leads to the symmetric Dir(\alpha/K, \ldots, \alpha/K) distribution.

My question is: how do I specify the Generalized Dirichlet as a prior in Stan? The snippet below is my initial attempt, following the definition but keeping the Beta parameters matched to the symmetric Dirichlet. Is this the correct way? Should I also add the Jacobian somewhere?

data {
  int<lower=1> n; // number of observations
  int<lower=1> K; // number of categories
  int<lower=0, upper=n> y[K]; // data
  real<lower=0> alpha; // Dirichlet shape
}
parameters {
  simplex[K] pi; // category probabilities
}
transformed parameters {
  // stick-breaking construction
  real<lower=0, upper=1> Z[K-1];
  Z[1] = pi[1];
  for (j in 2:(K-1)) {
    Z[j] = pi[j] / (1 - sum(pi[1:(j-1)]));
  }
}
model {
  for (j in 1:(K-1)) {
    // j * 1.0 / K forces real division; the integer expression j/K rounds down to 0
    Z[j] ~ beta(alpha / K, alpha * (1 - j * 1.0 / K)); // Dirichlet
  }
  y ~ multinomial(pi);
}
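
On the Jacobian question: in the snippet above the prior is placed on Z, which is a nonlinear function of the declared parameter pi, so the log density would need the log absolute Jacobian determinant of the map from pi to Z. A short derivation, using only the stick-breaking definition from the post and treating \pi_1, \ldots, \pi_{K-1} as the free coordinates:

```latex
\[
Z_j = \frac{\pi_j}{1 - \sum_{i<j} \pi_i}, \qquad j = 1, \ldots, K-1,
\]
\[
\frac{\partial Z_j}{\partial \pi_j} = \frac{1}{1 - \sum_{i<j} \pi_i},
\qquad
\frac{\partial Z_j}{\partial \pi_l} = 0 \quad \text{for } l > j,
\]
\[
\log \left| \det \frac{\partial Z}{\partial \pi} \right|
  = -\sum_{j=2}^{K-1} \log\Bigl(1 - \sum_{i<j} \pi_i\Bigr).
\]
```

The matrix of partial derivatives is lower triangular, so the determinant is the product of the diagonal entries. In Stan this term would be added to `target`; declaring the stick fractions as the parameters and building pi from them avoids the adjustment altogether.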

If you want to do it that way, do

data {
  int<lower = 1> K;
  vector<lower = 0>[K - 1] a;
  vector<lower = 0>[K - 1] b;
}
parameters {
  vector<lower = 0, upper = 1>[K - 1] z;
}
transformed parameters {
  simplex[K] pi; // do the stick breaking thing
  pi[1] = z[1];
  // each piece is the fraction z[k] of the stick that is still left, so multiply
  for (k in 2:(K - 1)) pi[k] = z[k] * (1 - sum(pi[1:(k - 1)]));
  pi[K] = 1 - sum(pi[1:(K - 1)]);
}
model {
  target += beta_lpdf(z | a, b);
}

Thanks a lot @bgoodri !
This is a more direct and straightforward approach … I am getting a few error messages for negative values in \pi for this implementation, but could be something that I’ve missed and should be easily fixable.

It might not be numerically stable enough. You could try forcing elements of pi to zero if the previous elements sum to 1.

Alternatively, you could work on the log scale: replace the simplex with an ordered sequence, apply the inv_logit link, and exploit inv_logit(x) = 1 - inv_logit(-x). See how to implement a categorical logit model old school. Just my 2 cents.
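
One way to read the log-scale suggestion is sketched below. It is a variant chosen for illustration, not necessarily what the poster had in mind: it uses unconstrained logits `u` (a name introduced here) rather than an ordered sequence, and it assumes the `a`, `b` data from the earlier reply. The identity inv_logit(x) = 1 - inv_logit(-x) gives log(1 - z) without ever forming 1 - z.

```stan
data {
  int<lower=2> K;
  vector<lower=0>[K - 1] a;
  vector<lower=0>[K - 1] b;
}
parameters {
  vector[K - 1] u; // unconstrained logits of the stick fractions (assumed name)
}
transformed parameters {
  vector[K] log_pi; // log of the simplex, never leaves the log scale
  {
    vector[K - 1] log_z = log_inv_logit(u);          // log z[k]
    vector[K - 1] log_1mz = log_inv_logit(-u);       // log(1 - z[k])
    vector[K - 1] log_rem = cumulative_sum(log_1mz); // log of stick left after k breaks
    log_pi[1] = log_z[1];
    for (k in 2:(K - 1)) log_pi[k] = log_z[k] + log_rem[k - 1];
    log_pi[K] = log_rem[K - 1];
  }
}
model {
  // Beta(a, b) prior on z = inv_logit(u), with the log Jacobian
  // log z + log(1 - z) of the logit transform already folded in
  for (k in 1:(K - 1)) {
    target += a[k] * log_inv_logit(u[k]) + b[k] * log_inv_logit(-u[k])
              - lbeta(a[k], b[k]);
  }
}
```

`exp(log_pi)` recovers the simplex when it is needed on the natural scale; a likelihood that accepts log probabilities can use `log_pi` directly.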

This approach is quadratic in K, so if K is tiny, no big deal. Otherwise, you want to keep a running sum in this loop to avoid re-adding the prefixes of pi each iteration of the loop.
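
The running-sum idea can be sketched as follows. This is an illustration rather than code from the thread: the local variable `remaining` is a name introduced here, and the data and prior are carried over from the earlier reply.

```stan
data {
  int<lower=2> K;
  vector<lower=0>[K - 1] a;
  vector<lower=0>[K - 1] b;
}
parameters {
  vector<lower=0, upper=1>[K - 1] z; // stick fractions
}
transformed parameters {
  simplex[K] pi;
  {
    real remaining = 1; // stick length not yet assigned
    for (k in 1:(K - 1)) {
      pi[k] = z[k] * remaining; // take fraction z[k] of what is left
      remaining *= 1 - z[k];    // O(1) update instead of re-summing pi[1:k]
    }
    pi[K] = remaining;
  }
}
model {
  target += beta_lpdf(z | a, b);
}
```

Updating the remainder by multiplication makes the loop linear in K, and because each factor 1 - z[k] lies in [0, 1] the remainder cannot go negative through round-off the way a repeated subtraction can.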

That’s a really helpful tip. Thank you. My K \approx 700 so definitely computation is a big concern … but it seems numerical stability is becoming more of an issue with this stick-breaking construction especially with highly unbalanced and long-tailed count distributions.

Hello. I implemented a compound Generalized Dirichlet Multinomial and the corresponding _rng function.

Question: When I sample from the model, and use that data as training data, in an attempt to infer the parameters, I cannot seem to recover them. Moreover, if I sample from the distribution using the returned parameters, the results are not similar to the training data.

Perhaps there is a bug in the model? I cannot seem to find literature to compute the _rng so I used the generative process as above. Maybe that is the issue?

Any assistance would be greatly appreciated!

functions {

  real dirichlet_multinomial_lpmf(int[] y, vector alpha) {
    real alpha_plus = sum(alpha);

    return lgamma(alpha_plus) + sum(lgamma(alpha + to_vector(y)))
           - lgamma(alpha_plus + sum(y)) - sum(lgamma(alpha));
  }

  int[] dirichlet_multinomial_rng(vector alpha, int N) {
    return multinomial_rng(dirichlet_rng(alpha), N);
  }

  real generalized_dirichlet_multinomial_lpmf(int[] y, vector alpha, vector beta_) {
    // y is num_categories dimensional, alpha and beta are num_categories-1 dimensional
    int D = dims(alpha)[1]; // D = num_categories-1

    vector[D+1] x = cumulative_sum(to_vector(y));

    vector[D] z;

    vector[D] alpha_prime = alpha + to_vector(y)[1:D];
    vector[D] beta_prime;

    z[1] = x[D+1];
    z[2:D] = rep_vector(x[D+1], D-1) - x[1:D-1];
    beta_prime = beta_ + z;

    return (lgamma(x[D]+1) - sum(lgamma(to_vector(y) + rep_vector(1, D+1)))
            + sum(lgamma(alpha_prime)) - sum(lgamma(alpha))
            + sum(lgamma(beta_prime)) - sum(lgamma(beta_))
            - sum(lgamma(alpha_prime + beta_prime)) + sum(lgamma(alpha + beta_))
           );
  }

  int[] generalized_dirichlet_multinomial_rng(vector alpha, vector beta_, int N) {
    int D = dims(alpha)[1];
    vector[2] tmp;
    int out[D+1];

    out[1] = N;

    for (n in 1:D) {
      tmp[1] = alpha[n];
      tmp[2] = beta_[n];
      out[n:n+1] = dirichlet_multinomial_rng(tmp, out[n]);
    }

    return out;
  }

}

data {
  int<lower=1> num_obs; // number of observations
  int<lower=1> num_test_obs; // number of test observations
  int<lower=1> num_categories; // number of categories
  int<lower=0> obs[num_obs, num_categories]; // data

  int<lower=0, upper=1> run_estimation;
}

parameters {
  vector<lower=0>[num_categories-1] alpha_std;
  vector<lower=0>[num_categories-1] beta_std;
}

transformed parameters {
  vector<lower=0>[num_categories-1] alpha;
  vector<lower=0>[num_categories-1] beta_;

  alpha = exp(1.0 + alpha_std);
  beta_ = exp(1.0 + beta_std);
}

model {
  alpha_std ~ normal(0, 1);
  beta_std ~ normal(0, 1);

  if (run_estimation == 1) {
    for (n in 1:num_obs) {
      obs[n] ~ generalized_dirichlet_multinomial(alpha, beta_);
    }
  }
}

generated quantities {
  int y_test[num_test_obs, num_categories]; // test data

  if (run_estimation == 0) {
    for (n in 1:num_test_obs) {
      y_test[n] = generalized_dirichlet_multinomial_rng(alpha, beta_, 1000);
    }
  }
}

update: I found a bug in my code involving z:

real generalized_dirichlet_multinomial_lpmf(int[] y, vector alpha, vector beta_) {
  // y is num_categories dimensional, alpha and beta are num_categories-1 dimensional
  int D = dims(alpha)[1]; // D = num_categories-1

  vector[D+1] x = cumulative_sum(to_vector(y));

  vector[D] z = rep_vector(x[D+1], D) - x[1:D];

  vector[D] alpha_prime = alpha + to_vector(y)[1:D];
  vector[D] beta_prime;

  beta_prime = beta_ + z;

  return (lgamma(x[D+1]+1) - sum(lgamma(to_vector(y) + rep_vector(1, D+1)))
          + sum(lgamma(alpha_prime)) - sum(lgamma(alpha))
          + sum(lgamma(beta_prime)) - sum(lgamma(beta_))
          - sum(lgamma(alpha_prime + beta_prime)) + sum(lgamma(alpha + beta_))
         );
}


Much better results. Still, any resources on how to efficiently generate samples from the generalized Dirichlet multinomial would be welcome (I copied matlab code for the CRAN rgdirmn.R function)…
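
For cross-checking the corrected function, the pmf it appears to implement can be written as a product of beta-binomial stages. The notation below mirrors the code's variables (D = num_categories - 1, N is the total count, z_j is the count remaining after category j); it is offered as a sketch for checking the terms, not as a reference derivation.

```latex
\[
N = \sum_{j=1}^{D+1} y_j, \qquad z_j = \sum_{i=j+1}^{D+1} y_i, \qquad j = 1, \ldots, D,
\]
\[
P(y \mid \alpha, \beta)
  = \frac{N!}{\prod_{j=1}^{D+1} y_j!}
    \prod_{j=1}^{D} \frac{B(\alpha_j + y_j,\; \beta_j + z_j)}{B(\alpha_j, \beta_j)},
\]
\[
\log B(a, b) = \log\Gamma(a) + \log\Gamma(b) - \log\Gamma(a + b).
\]
```

Expanding each Beta function with the last identity gives the six `lgamma` sums in the return statement, and the multinomial coefficient gives `lgamma(x[D+1]+1)` minus the sum of `lgamma(y + 1)`.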

Draw a bunch of independent betas and run them through the transformation at

If I understand correctly, this is a Dirichlet with independent variance parameters for each component. Any chance you could make it take proportions as input?

e.g.,

simplex ~ generalized_dirichlet_multinomial(alphas, betas)

The Dirichlet (and Generalized Dirichlet) can take anything as input (as long as they are positive). By your line
simplex ~ generalized_dirichlet_multinomial(alphas, betas)

I am assuming you want the output of the Generalized Dirichlet, and not the compound distribution? If that is the case, you would have to rewrite the _lpmf function.

I was wondering whether there has been any progress on this topic.

I am performing a standard Dirichlet regression

real dirichlet_regression_lpdf(vector p, row_vector X, matrix alpha, real phi) {

  // Build sum to zero variable
  int c = cols(alpha);
  int r = rows(alpha);
  matrix[r, c] alpha_ = alpha;
  alpha_[1, c] = -sum(alpha_[1, 1:(c-1)]);

  // Calculate log prob
  return (dirichlet_lpdf(p | softmax(to_vector(X * alpha_)) * exp(phi) + 1));
}


However, I find that the biological variability clearly calls for multiple variance parameters, as in the generalised Dirichlet distribution.

I need N proportions as input with N means and N variances, and a log probability as output.

@stemangiola Sorry this is not an answer by any means, and excuse me if it’s a silly question, but do you have data as proportions or counts? Also, I think the GD offers a more general dependence structure but can be difficult when dealing with too many categories.

Also, what I originally wanted to do (but didn’t finish) was to look at variable selection for the integrated Dirichlet Multinomial model for microbiome data, something like this: https://github.com/duncanwadsworth/dmbvs (but without the spike-and-slab), and thought it would be cool to explore Generalized Dirichlet Multinomial as well. We have count data and the integrated models are more flexible for handling overdispersion.

Sorry if this is a pointless response. :(

From deconvolution I infer proportions from mixed counts of different cell types (similar to what you might have with metagenomics, I guess). So my proportions are parameters themselves, and the Dirichlet regression is hierarchical.

I have at most 20ish categories, so I don’t need to do variable selection, just a reasonably robust generalised Dirichlet regression framework that I can plug into my Stan model.

In case this is still relevant for people, here’s an implementation of the two methods described on the Wiki page, yielding the same result. Feel free to optimize the code to your liking and implement the vectorized version for the tilde notation.

functions {
  real gendir_wong_s_lpdf(vector x, vector a, vector b) {
    // x is simplex
    // a, b parameter vectors of CM distribution
    int k = rows(x) - 1;
    real res = 1;
    for (i in 1:(k-1)) {
      res *= pow(x[i], a[i]-1) * pow(1 - sum(x[1:i]), b[i]-a[i+1]-b[i+1]) / beta(a[i], b[i]);
    }
    res *= pow(x[k], a[k]-1) * pow(1 - sum(x[1:k]), b[k]-1) / beta(a[k], b[k]);
    return log(res);
  }
  // RNG function for Wong method
  vector gendir_wong_s_rng(vector a, vector b) {
    int k = rows(a);
    vector[k+1] x;
    real s = beta_rng(a[1], b[1]);
    x[1] = s;
    for (j in 2:k) {
      x[j] = beta_rng(a[j], b[j]) * (1 - s);
      s += x[j];
    }
    x[k+1] = 1 - sum(x[1:k]);
    return x;
  }
  // CM distribution (classical)
  real gendir_s_lpdf(vector x, vector a, vector b) {
    // x is simplex
    // a, b parameter vectors of CM distribution
    int k = rows(x);
    real res = 1;
    // i = 1: sum(x[1:k]) = 1, so its power term equals 1 and b[0]
    // (not a valid index in Stan) is never needed
    res *= pow(x[1], a[1]-1) / beta(a[1], b[1]);
    for (i in 2:(k-1)) {
      res *= pow(x[i], a[i]-1) * pow(sum(x[i:k]), b[i-1]-(a[i]+b[i])) / beta(a[i], b[i]);
    }
    res *= pow(1 - sum(x[1:(k-1)]), b[k-1]-1);
    return log(res);
  }

} // end of functions block
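
As one possible optimization of the Wong-form density above, the terms can be accumulated on the log scale with a running remainder, which avoids underflow in the product and the repeated prefix sums. The function name `gendir_wong_log_lpdf` and the local `remaining` are hypothetical names introduced for this sketch; the argument conventions (x a simplex of length k + 1, a and b of length k) are assumed to match `gendir_wong_s_lpdf`.

```stan
functions {
  // Wong-form generalized Dirichlet log density, summed on the log scale.
  real gendir_wong_log_lpdf(vector x, vector a, vector b) {
    int k = rows(x) - 1;
    real lp = 0;
    real remaining = 1; // holds 1 - sum(x[1:i]) after step i
    for (i in 1:(k - 1)) {
      remaining -= x[i];
      lp += (a[i] - 1) * log(x[i])
            + (b[i] - a[i + 1] - b[i + 1]) * log(remaining)
            - lbeta(a[i], b[i]);
    }
    // last term: 1 - sum(x[1:k]) is x[k + 1] on a simplex, so use it directly
    lp += (a[k] - 1) * log(x[k]) + (b[k] - 1) * log(x[k + 1])
          - lbeta(a[k], b[k]);
    return lp;
  }
}
```

With this in the functions block, a statement such as `x ~ gendir_wong_log(a, b);` or `target += gendir_wong_log_lpdf(x | a, b);` would use it in the model block. The work is linear in k, compared with the quadratic cost of re-summing `x[1:i]` at every step.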

I’m curious, did you do some more work on

or do you reckon this is already an optimised version?