I am trying to model a Hierarchical Dirichlet Process in Stan. The mathematical model is as follows:
G_j | G_0 , \alpha_0 \sim DP(\alpha_0, G_0) \quad j=1,2
G_0 \sim DP(2, \mathcal N(0, 5))
The data are drawn from two univariate normal distribution with different means.
Here is the model, I use a truncated version of the DP with the same number of components (H) for all the 3 of them:
data {
int<lower=0> H; // number of components in the DP
int<lower=0> J; // number of groups
int<lower=0> max_num_samples;
int<lower=0> num_samples[J]; // number of samples in each group
matrix[J, max_num_samples] samples;
}
parameters {
real<lower=0> alpha_0;
vector<lower=0, upper=1>[H-1] nus[J];
vector<lower=0, upper=1>[H-1] nu_top;
vector[H] means; // cluster means
}
transformed parameters {
simplex[H] weights[J];
simplex[H] weights_top;
real prod1_nu = 0;
for (j in 1:J) {
weights[j][1] = nus[j][1];
prod1_nu = 1 - nus[j][1];
for (h in 2:(H-1)) {
weights[j][h] = nus[j][h] * prod1_nu;
prod1_nu *= (1 - nus[j][h]);
}
weights[j][H] = fmax(0.0, 1 - sum(weights[j][1:(H-1)]));
}
weights_top[1] = nu_top[1];
prod1_nu = 1 - nu_top[1];
for (h in 2:(H-1)) {
weights_top[h] = nu_top[h] * prod1_nu;
prod1_nu *= (1 - nu_top[h]);
}
weights_top[H] = fmax(0.0, 1 - sum(weights_top[1:(H-1)]));
}
model {
// hyperparams for G0 | gamma, H
// H ~ N(0, 5);
real sigmaH = 5;
real gamma = 1.0;
// Top level DP
for (h in 1:H) {
means[h] ~ normal(0, sigmaH);
}
nu_top ~ beta(1, gamma);
alpha_0 ~ gamma(3,3);
// Bottom level DPs
for (j in 1:J) {
nus[j] ~ beta(1, alpha_0);
}
for (j in 1:J) {
for (i in 1:num_samples[j]) {
real partial_sums[H];
for (h in 1:H) {
partial_sums[h] = log(weights[j][h]) + normal_lpdf(samples[j, i] | means[h], 1.0);
}
target += log_sum_exp(partial_sums);
}
}
}
This together with this R code:
mu = 0
sigma = 1
J = 2
num_samples = 100
samples = matrix(nrow=J, ncol=num_samples)
for (j in 1:J) {
samples[j, ] = rnorm(num_samples, mu + 2*j, sigma)
}
dat = list(
H=10,
J=2,
max_num_samples=num_samples,
num_samples=c(num_samples,num_samples),
samples=samples)
fit = stan(file="hdp.stan", data=dat)
Produces divergent chains: if i try to plot the trace of the means, I get 4 flat lines.
Instead, if one fixes the hyperparameter \alpha_0 to be equal to 1, the model recovers the data (apart from a little bit of label switching).
Has anyone encountered the same problem?