Ordered simplex constraint transform

OK, I was overcomplicating things - the PDF I derived actually can be used directly without any worries about inverse CDFs!
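The PDF itself is not restated in this post. Judging from the `target +=` line commented "Add the lowest element log density" in the Stan function below, it appears to be the density of the smallest element of a uniform K-simplex, K (K - 1) (1 - K t)^(K - 2) on [0, 1/K]. The following is a minimal base-R sketch for checking that reading by simulation; the variable names and the choice of K = 4 are illustrative assumptions, not part of the original code.

```r
set.seed(1)
K <- 4
n_draws <- 1e5

# Flat Dirichlet draws via normalized exponentials (base R only)
g <- matrix(rexp(n_draws * K), ncol = K)
x <- g / rowSums(g)

# Smallest element of each simulated simplex
x_min <- apply(x, 1, min)

# CDF implied by the density K (K - 1) (1 - K t)^(K - 2) on [0, 1/K]
cdf_min <- function(t) 1 - (1 - K * t)^(K - 1)

# Largest gap between the empirical and the analytic CDF on a grid
t_grid <- seq(0, 1 / K, length.out = 50)
max_abs_diff <- max(abs(ecdf(x_min)(t_grid) - cdf_min(t_grid)))
max_abs_diff

# The same density implies E[min] = 1 / K^2
c(simulated = mean(x_min), theory = 1 / K^2)
```

If the reading is right, `max_abs_diff` should be small (on the order of the Monte Carlo error) and the two means should agree closely.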

Here’s ordered_simplex_min.stan:

functions {
 vector ordered_simplex_constrain_min_lp(vector y) {
    int Km1 = rows(y);
    vector[Km1 + 1] x;
    real remaining = 1; // Remaining amount to be distributed
    real base = 0; // The minimum for the next element
    for(i in 1:Km1) {
      int K_prime = Km1 + 2 - i; // Number of remaining elements
      //First constrain to [0; 1 / K_prime]
      real invlogity = inv_logit(y[i]);
      real x_cons = inv(K_prime) * invlogity;
      // Jacobian for the constraint
      target += -log(K_prime) + log(invlogity) + log1m(invlogity);

      // Add the lowest element log density
      target += log(K_prime - 1) +  log(K_prime) + (K_prime - 2) * log1m(K_prime*x_cons);

      x[i] = base + remaining * x_cons;
      base = x[i];
      //We added  remaining * x_cons to each of the K_prime elements yet to be processed
      remaining -= remaining * x_cons * K_prime; 
    }
    x[Km1 + 1] = base + remaining;

    return x;
 }
}
data {
  int K;
  array[K] int<lower=0> observed; // counts per category (array syntax for current Stan versions)
  real<lower=0> prior_alpha;
}


parameters {
  vector[K - 1] y;
}

transformed parameters {
  simplex[K] x = ordered_simplex_constrain_min_lp(y);
}

model {
  x ~ dirichlet(rep_vector(prior_alpha, K));
  observed ~ multinomial(x);
}
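To make the transform easier to poke at outside of Stan, here is a plain-R port of the constraining part only. It is a sketch: the function name `ordered_simplex_constrain_min` is made up for this example, the `target +=` statements are left out, and `plogis` stands in for `inv_logit`.

```r
# R port of the loop in ordered_simplex_constrain_min_lp (no log density terms)
ordered_simplex_constrain_min <- function(y) {
  Km1 <- length(y)
  x <- numeric(Km1 + 1)
  remaining <- 1  # Remaining amount to be distributed
  base <- 0       # The minimum for the next element
  for (i in seq_len(Km1)) {
    K_prime <- Km1 + 2 - i            # Number of remaining elements
    x_cons <- plogis(y[i]) / K_prime  # Constrained to (0, 1 / K_prime)
    x[i] <- base + remaining * x_cons
    base <- x[i]
    # remaining * x_cons was added to each of the K_prime elements left
    remaining <- remaining - remaining * x_cons * K_prime
  }
  x[Km1 + 1] <- base + remaining
  x
}

set.seed(42)
x_random <- ordered_simplex_constrain_min(rnorm(5))
x_zero <- ordered_simplex_constrain_min(c(0, 0, 0))

# Both properties the Stan declaration simplex[K] x relies on, plus ordering
sum(x_random)
is.unsorted(x_random)
```

For any real-valued `y` the result should sum to one and be non-decreasing, which is what makes it a valid ordered simplex.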

EDIT: For the record, I probably don’t understand why the code works. I would swear the code is missing a target += log(remaining) correction (for the x[i] = base + remaining * x_cons; line), but adding it makes the SBC fail… Also, I can see why it works for the uniform simplex distribution, but I am much less clear on whether it should work with the added Dirichlet prior… At some point I will learn enough to understand it, but today is not the day.
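A possible explanation, offered as a sketch of the algebra rather than a verified proof. The notation is introduced here and is not from the original code: $K$ is `Km1 + 1`, $u_i$ is `invlogity` at iteration $i$, $K'_i = K + 1 - i$ is `K_prime`, and $r_i$ is the value of `remaining` at the start of iteration $i$.

$$
r_1 = 1, \qquad r_{i+1} = r_i (1 - u_i) \quad\Rightarrow\quad r_i = \prod_{j < i} (1 - u_j)
$$

$$
x_i = x_{i-1} + \frac{r_i \, u_i}{K'_i} \quad (x_0 = 0), \qquad \frac{\partial x_i}{\partial u_i} = \frac{r_i}{K'_i}
$$

Since $x_i$ depends only on $u_1, \dots, u_i$, the Jacobian is triangular, and together with $\mathrm{d}u_i / \mathrm{d}y_i = u_i (1 - u_i)$ this gives

$$
\log |J| = \sum_{i=1}^{K-1} \left[ \log r_i - \log K'_i + \log u_i + \log(1 - u_i) \right]
$$

$$
\sum_{i=1}^{K-1} \log r_i = \sum_{i=1}^{K-1} \sum_{j < i} \log(1 - u_j) = \sum_{j=1}^{K-1} (K - 1 - j) \log(1 - u_j) = \sum_{j=1}^{K-1} (K'_j - 2) \log(1 - u_j)
$$

Because `K_prime * x_cons` equals $u_i$, the last sum is exactly what the `(K_prime - 2) * log1m(K_prime*x_cons)` term accumulates. Under this reading, the `log(remaining)` correction is already present in the "lowest element log density" line, and adding it explicitly would count it twice, which would be consistent with SBC failing when it is added. The leftover terms `log(K_prime - 1) + log(K_prime)` do not depend on `y`, so the function adds the log Jacobian of the transform plus a constant. That corresponds to a flat density on `x` over the ordered simplex, on top of which a Dirichlet prior and a multinomial likelihood can be added in the usual way.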

Simulator and testing code

And here’s our simulator and testing code:

First, we’ll test whether this works well as an implicit flat prior over ordered simplices:

library(SBC) # remotes::install_github("hyunjimoon/SBC")
library(cmdstanr)
library(MCMCpack)
library(ggplot2)

library(future)
plan(multisession)
options(SBC.min_chunk_size = 5)

m <- cmdstan_model("ordered_simplex_min.stan")
backend <- SBC_backend_cmdstan_sample(m, chains = 2)

generate_one_dataset <- function(N, K, prior_alpha = 1) {
  x_raw <- rdirichlet(1, alpha = rep(prior_alpha, K))
  x <- sort(x_raw)
  observed <- as.integer(rmultinom(1, size = N, prob = x))
  
  list(
    parameters = list(x = x),
    generated = list(K = K, observed = observed, prior_alpha = prior_alpha)
  )
}

datasets_flat <- generate_datasets(
    SBC_generator_function(generate_one_dataset, N = 20, K = 4),
    n_datasets = 1000)
  
res_flat <- compute_results(datasets_flat, backend)
  
plot_rank_hist(res_flat)
plot_ecdf_diff(res_flat)

plot_sim_estimated(res_flat, alpha = 0.2)


Looking good!

And now let’s try with a concentrated prior:

datasets_6 <- generate_datasets(
    SBC_generator_function(generate_one_dataset, N = 10, K = 4, prior_alpha = 6),
    n_datasets = 1000)
  
res_6 <- compute_results(datasets_6, backend)
  
plot_rank_hist(res_6)
plot_ecdf_diff(res_6)

plot_sim_estimated(res_6)

Also looking good, although we apparently can’t learn much about the parameters:


Note that for this kind of investigation, it is actually useful to have a weak likelihood (i.e. low N in the simulator): we are interested in whether the implied prior is correct, and the stronger the likelihood, the smaller the effect of the prior on the posterior and the harder it is to find a discrepancy with SBC. At the same time, I don’t think we would want to have no likelihood at all, as I am also interested in whether problems arise in the interaction between the prior and the likelihood.
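For a rough feel of how much the prior matters in the two setups above, one can borrow the conjugate result for the unordered Dirichlet-multinomial model, where the posterior mean is a weighted average of the prior mean 1/K and the observed proportion. This is only a back-of-the-envelope sketch: the ordering constraint breaks conjugacy, so the numbers are indicative at best, and `prior_weight` and the N = 1000 setting are made up for illustration.

```r
# Weight of the prior mean in the posterior mean of an unordered
# Dirichlet(alpha, ..., alpha) prior with a multinomial likelihood:
# E[x_i | data] = w * (1 / K) + (1 - w) * (n_i / N),  w = K * alpha / (K * alpha + N)
prior_weight <- function(N, K, alpha) {
  K * alpha / (K * alpha + N)
}

w_flat   <- prior_weight(N = 20,   K = 4, alpha = 1)  # flat-prior run above
w_conc   <- prior_weight(N = 10,   K = 4, alpha = 6)  # concentrated-prior run above
w_strong <- prior_weight(N = 1000, K = 4, alpha = 1)  # hypothetical large-N run

round(c(flat = w_flat, concentrated = w_conc, strong = w_strong), 3)
```

In the concentrated-prior run the prior carries roughly 70% of the weight, which fits with not learning much about the parameters there, while in a large-N run the prior would contribute well under 1% and a wrong implied prior would be hard to see.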

There is, however, likely room for improvement, as the implied geometry on y can look a bit weird - here’s one of the worse-looking pairs plots:
