Hierarchical Dirichlet Process: divergent transitions with hyperprior



I am trying to model a Hierarchical Dirichlet Process in Stan. The mathematical model is as follows:

G_j \mid G_0, \alpha_0 \sim \mathrm{DP}(\alpha_0, G_0), \quad j = 1, 2
G_0 \sim \mathrm{DP}(2, \mathcal N(0, 5))
\alpha_0 \sim \mathrm{Gamma}(3, 3)
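For reference, here is a sketch of one standard way to write this HDP with every DP truncated at H components. The notation is an assumption chosen to match the code: \beta for `weights_top`, \pi_j for `weights[j]`, \mu_h for `means`, and \gamma for the top-level concentration (2 in the model above, 1.0 in the Stan code). As in Stan's `normal`, 5 is a standard deviation. The Dirichlet line relies on the fact that a DP whose base measure is discrete, with weights \beta on H atoms, puts Dirichlet(\alpha_0 \beta) weights on those same atoms.

$$
\begin{aligned}
\nu_h &\sim \mathrm{Beta}(1, \gamma), \quad h = 1, \dots, H-1 \\
\beta_1 &= \nu_1, \quad \beta_h = \nu_h \prod_{l=1}^{h-1} (1 - \nu_l) \;\; (1 < h < H), \quad \beta_H = \prod_{l=1}^{H-1} (1 - \nu_l) \\
\mu_h &\sim \mathcal N(0, 5), \quad h = 1, \dots, H \\
\alpha_0 &\sim \mathrm{Gamma}(3, 3) \\
\pi_j \mid \alpha_0, \beta &\sim \mathrm{Dirichlet}(\alpha_0 \beta_1, \dots, \alpha_0 \beta_H), \quad j = 1, 2 \\
y_{ji} \mid \pi_j, \mu &\sim \sum_{h=1}^{H} \pi_{jh} \, \mathcal N(\mu_h, 1)
\end{aligned}
$$

The Stan code below takes a different approach. Each `nus[j]` gets an independent Beta(1, alpha_0) prior, so `weights_top` enters neither the likelihood nor the priors on `nus`.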

The data are drawn from two univariate normal distributions with different means, one per group.

Here is the model. I use a truncated stick-breaking version of the DP, with the same number of components (H) for all three DPs:

data {
  int<lower=2> H; // number of components in each truncated DP (at least 2)
  int<lower=0> J; // number of groups
  int<lower=0> max_num_samples;
  array[J] int<lower=0> num_samples; // number of samples in each group
  matrix[J, max_num_samples] samples;
}

parameters {
  real<lower=0> alpha_0;                         // concentration of the group-level DPs
  array[J] vector<lower=0, upper=1>[H - 1] nus;  // group-level stick proportions
  vector<lower=0, upper=1>[H - 1] nu_top;        // top-level stick proportions
  vector[H] means;                               // cluster means (shared atoms)
}

transformed parameters {
  array[J] simplex[H] weights;
  simplex[H] weights_top;

  {
    real remaining; // stick length not yet broken off (local, not saved)

    for (j in 1:J) {
      remaining = 1;
      for (h in 1:(H - 1)) {
        weights[j][h] = nus[j][h] * remaining;
        remaining *= 1 - nus[j][h];
      }
      // the last weight takes the rest of the stick, so the weights sum to 1
      // without clipping via fmax(0, 1 - sum(...))
      weights[j][H] = remaining;
    }

    remaining = 1;
    for (h in 1:(H - 1)) {
      weights_top[h] = nu_top[h] * remaining;
      remaining *= 1 - nu_top[h];
    }
    weights_top[H] = remaining;
  }
}

model {
  // base measure of the top-level DP: N(0, sigmaH), sigmaH is a standard deviation
  real sigmaH = 5;
  // top-level concentration; renamed from `gamma` to avoid clashing with
  // the name of the gamma distribution used for alpha_0 below
  real gamma_top = 1.0;

  // Top-level DP
  means ~ normal(0, sigmaH);
  nu_top ~ beta(1, gamma_top);

  // hyperprior on the group-level concentration (shape 3, rate 3)
  alpha_0 ~ gamma(3, 3);

  // Bottom-level DPs
  for (j in 1:J) {
    nus[j] ~ beta(1, alpha_0);
  }

  // Likelihood: marginalize each observation over the H components
  for (j in 1:J) {
    for (i in 1:num_samples[j]) {
      vector[H] partial_sums;
      for (h in 1:H) {
        partial_sums[h] = log(weights[j][h]) + normal_lpdf(samples[j, i] | means[h], 1.0);
      }
      target += log_sum_exp(partial_sums);
    }
  }
}

Fitting this model with the following R code:

library(rstan)

mu = 0
sigma = 1
J = 2
num_samples = 100

# one row per group; group j is drawn from N(mu + 2*j, sigma)
samples = matrix(nrow=J, ncol=num_samples)

for (j in 1:J) {
  samples[j, ] = rnorm(num_samples, mu + 2*j, sigma)
}

dat = list(
  H=10,
  J=J,
  max_num_samples=num_samples,
  num_samples=rep(num_samples, J),
  samples=samples)

fit = stan(file="hdp.stan", data=dat)

produces divergent transitions: when I plot the traces of the means, I get four flat lines, i.e. the chains do not move.

If instead I fix the hyperparameter \alpha_0 = 1, the model recovers the two group means (apart from some label switching).
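To see how much \alpha_0 shapes the group-level weights, the stick-breaking arithmetic from the `transformed parameters` block can be checked outside Stan. Below is a minimal NumPy sketch. The helper names `stick_breaking`, `expected_weights` and `monte_carlo_weights` are illustrative and not part of Stan or rstan. With \nu \sim \mathrm{Beta}(1, \alpha_0) independent, E[w_h] = \frac{1}{1+\alpha_0} \left(\frac{\alpha_0}{1+\alpha_0}\right)^{h-1} for h < H, and the last weight gets the expected remaining stick. For \alpha_0 < 1 the Beta(1, \alpha_0) density is also unbounded at \nu = 1. The Gamma(3, 3) hyperprior (shape 3, rate 3) has mean 1 and standard deviation 1/\sqrt{3} \approx 0.58, so it covers \alpha_0 values whose weight profiles differ noticeably.

```python
import numpy as np


def stick_breaking(nu):
    """Truncated stick-breaking weights from H - 1 proportions in (0, 1).

    Works on the last axis, so a (draws, H - 1) array gives (draws, H) weights.
    Weight h is nu[h] times the stick left after the earlier breaks; the last
    weight is whatever remains, so each row of weights sums to 1.
    """
    nu = np.asarray(nu, dtype=float)
    ones = np.ones(nu.shape[:-1] + (1,))
    remaining = np.concatenate([ones, np.cumprod(1.0 - nu, axis=-1)], axis=-1)
    return np.concatenate([nu * remaining[..., :-1], remaining[..., -1:]], axis=-1)


def expected_weights(alpha, H):
    """Prior mean of each weight when nu_h ~ Beta(1, alpha) independently.

    E[nu] = 1 / (1 + alpha) and E[1 - nu] = alpha / (1 + alpha); independence
    lets the expectation of each product factor into a product of expectations.
    """
    p = 1.0 / (1.0 + alpha)
    q = alpha / (1.0 + alpha)
    w = np.empty(H)
    w[:-1] = p * q ** np.arange(H - 1)
    w[-1] = q ** (H - 1)
    return w


def monte_carlo_weights(alpha, H, draws=100_000, seed=0):
    """Average stick-breaking weights over prior draws nu ~ Beta(1, alpha)."""
    rng = np.random.default_rng(seed)
    nus = rng.beta(1.0, alpha, size=(draws, H - 1))
    return stick_breaking(nus).mean(axis=0)


if __name__ == "__main__":
    H = 10
    for alpha in (0.1, 1.0, 3.0):
        print(f"alpha_0 = {alpha}:", np.round(expected_weights(alpha, H), 3))
```

Comparing the printed rows for \alpha_0 = 0.1, 1 and 3 shows how quickly the prior mass moves from the first component to the rest as \alpha_0 grows. The Monte Carlo helper is only there to confirm the closed-form means.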

Has anyone encountered the same problem?