Ragged array of simplexes

Hello gurus,

I would like to implement a model with a ragged array of simplexes.

In response to this query about a year ago, Bob mentioned (I think) that there was an intention to expose the simplex_constrain function and the log determinant of its Jacobian, which would facilitate this: https://groups.google.com/forum/#!msg/stan-users/9WTI41kBf4I/sb_WAt6yBQAJ

I was wondering if this was implemented? (Or perhaps I misunderstood.)

Many thanks,
Jeff


No

Many thanks Ben. I did a quick port of the C++ functions to Stan functions, pasted below in case it is useful to others. (It is not optimally efficient because it duplicates the transform calculation, but it works in a pinch.)

Thanks,
Jeff

  /**
   * Return the simplex corresponding to the specified free vector.
   * A simplex is a vector containing values greater than or equal
   * to 0 that sum to 1.  A vector with (K-1) unconstrained values
   * will produce a simplex of size K.
   *
   * The transform is based on a centered stick-breaking process.
   */
  vector simplex_constrain(vector y) {

    vector[rows(y)+1] x;
    int Km1;
    real stick_len;

    Km1 = rows(y);
    stick_len = 1.0;
    for (k in 1:Km1) {
      real z_k;
      z_k = inv_logit(y[k] - log(Km1 - k + 1));
      x[k] = stick_len * z_k;
      stick_len = stick_len - x[k];
    }
    x[Km1+1] = stick_len;

    return x;
  }

  /**
   * Return the log absolute Jacobian determinant of the simplex
   * transform defined in simplex_constrain().
   */
  real simplex_constrain_lj(vector y) {

    real lj;
    int Km1;
    real stick_len;
    
    lj = 0.0;
    Km1 = rows(y);
    stick_len = 1.0;
    for (k in 1:Km1) {
      real adj_y_k;
      adj_y_k = y[k] - log(Km1 - k + 1);
      lj = lj + log(stick_len) - log1p_exp(-adj_y_k) - log1p_exp(adj_y_k);
      stick_len = stick_len * (1.0 - inv_logit(adj_y_k));
    }
    
    return lj;
  }
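
If you want to double-check the Jacobian term in the port, here is a small Python sketch (standard library only; it mirrors the two Stan functions rather than calling Stan) that compares the analytic log determinant with a central-difference Jacobian of the map from y to the first K-1 simplex entries. Because x[k] depends only on y[1..k], that Jacobian is lower triangular, so the analytic formula is just the sum of the log diagonal terms. The helpers det and numeric_lj are hypothetical names added only for this check.

```python
import math


def simplex_constrain(y):
    """Centered stick-breaking transform: K-1 unconstrained reals -> K-simplex."""
    km1 = len(y)
    x = []
    stick_len = 1.0
    for k in range(km1):
        # 0-based k, so log(km1 - k) matches the Stan port's log(Km1 - k + 1)
        z_k = 1.0 / (1.0 + math.exp(-(y[k] - math.log(km1 - k))))
        x.append(stick_len * z_k)
        stick_len -= x[-1]
    x.append(stick_len)
    return x


def simplex_constrain_lj(y):
    """Analytic log |det J| of y -> x[0:K-1], same formula as the Stan port."""
    km1 = len(y)
    lj = 0.0
    stick_len = 1.0
    for k in range(km1):
        adj_y_k = y[k] - math.log(km1 - k)
        lj += (math.log(stick_len)
               - math.log1p(math.exp(-adj_y_k))
               - math.log1p(math.exp(adj_y_k)))
        stick_len *= 1.0 - 1.0 / (1.0 + math.exp(-adj_y_k))
    return lj


def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]
    n = len(a)
    d = 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(a[r][c]))
        if a[p][c] == 0.0:
            return 0.0
        if p != c:
            a[c], a[p] = a[p], a[c]
            d = -d
        d *= a[c][c]
        for r in range(c + 1, n):
            f = a[r][c] / a[c][c]
            for j in range(c, n):
                a[r][j] -= f * a[c][j]
    return d


def numeric_lj(y, h=1e-6):
    """log |det J| from a central-difference Jacobian of y -> x[0:K-1]."""
    n = len(y)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up, dn = list(y), list(y)
        up[j] += h
        dn[j] -= h
        xu, xd = simplex_constrain(up), simplex_constrain(dn)
        for i in range(n):
            jac[i][j] = (xu[i] - xd[i]) / (2.0 * h)
    return math.log(abs(det(jac)))


y = [0.3, -1.2, 0.8]
print(simplex_constrain(y), sum(simplex_constrain(y)))
print(simplex_constrain_lj(y), numeric_lj(y))
```

Note that y = 0 maps to the uniform simplex, which is what "centered" refers to.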

That is quite possibly correct but more complicated than necessary. Dealing with simplexes of different sizes is annoying, but it is actually the least annoying of the ragged arrays. The trick to avoiding having to deal with Jacobians is to remember that a Dirichlet distribution over simplexes can be constructed by normalizing a set of unit-scale but independent Gamma random variables: if gamma_k ~ Gamma(alpha_k, 1) independently for k = 1, ..., K, then gamma / sum(gamma) ~ Dirichlet(alpha).
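
A quick Monte Carlo check of this construction, as a Python sketch using only the standard library (the helper names are hypothetical): draw independent Gamma(alpha_k, 1) variables, normalize them, and compare the average of the normalized draws with the Dirichlet mean alpha_k / sum(alpha).

```python
import random


def dirichlet_via_gammas(alpha, rng):
    """One draw of pi = gamma / sum(gamma) with gamma_k ~ Gamma(alpha_k, 1)."""
    # random.gammavariate(shape, scale); scale 1.0 is the unit-scale case
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    total = sum(g)
    return [gk / total for gk in g]


def mc_means(alpha, n_draws=20000, seed=1):
    """Monte Carlo estimate of E[pi_k] under the gamma-normalization scheme."""
    rng = random.Random(seed)
    sums = [0.0] * len(alpha)
    for _ in range(n_draws):
        for k, p in enumerate(dirichlet_via_gammas(alpha, rng)):
            sums[k] += p
    return [s / n_draws for s in sums]


alpha = [1.0, 2.0, 3.0]
# Dirichlet(alpha) has E[pi_k] = alpha_k / sum(alpha) = [1/6, 1/3, 1/2]
print(mc_means(alpha))
```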


Thus, in Stan you declare a looooong vector of non-negative parameters

parameters {
  vector<lower=0>[sum(K)] gamma; // K is an int[] declared in data
  ...
}

and in the model block you break off however many of those you need at the moment, divide them by their sum, and use the resulting simplex in your likelihood, as in

model {
  int pos = 1;
  for (j in 1:J) {
    vector[K[j]] pi = segment(gamma, pos, K[j]);
    pi = pi / sum(pi); // simplex
    // target += some function of pi;  (placeholder for your likelihood)
    pos = pos + K[j];
  }
}
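
For readers who want to see the bookkeeping outside Stan, this Python sketch mirrors the loop above: walk through one long vector, take consecutive segments of length K[j], and normalize each one. ragged_simplexes is a hypothetical helper name, and indexing is 0-based in Python where Stan's pos starts at 1.

```python
def ragged_simplexes(gamma, sizes):
    """Split one long vector of positive values into consecutive segments
    of the given sizes and normalize each segment to sum to 1."""
    if len(gamma) != sum(sizes):
        raise ValueError("len(gamma) must equal sum(sizes)")
    simplexes = []
    pos = 0  # 0-based start of the current segment (Stan's pos starts at 1)
    for k in sizes:
        seg = gamma[pos:pos + k]      # like segment(gamma, pos, K[j])
        total = sum(seg)
        simplexes.append([g / total for g in seg])  # like pi = pi / sum(pi)
        pos += k                      # like pos = pos + K[j]
    return simplexes


gamma = [0.5, 1.5, 2.0, 1.0, 3.0, 0.2, 0.3, 0.5, 1.0]
sizes = [2, 3, 4]  # plays the role of the int array K in the data block
for pi in ragged_simplexes(gamma, sizes):
    print(pi, sum(pi))
```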

Then you put a Gamma prior on gamma

  target += gamma_lpdf(gamma | alpha, 1);
   // implies pi ~ dirichlet(segment(alpha, pos, K[j]))

where alpha is a looooong vector of known shape hyperparameters for the Dirichlet distributions that are concatenated together. If alpha = rep_vector(1, sum(K)), then each pi is uniform a priori over simplexes of size K[j] and you can get away with

  target += exponential_lpdf(gamma | 1);
  // implies pi ~ uniform over simplexes
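
The shortcut in the last step rests on two standard density facts (Stan's gamma_lpdf is parameterized by shape and inverse scale): a Gamma with shape 1 and inverse scale 1 is exactly an Exponential(1), and a Dirichlet with all-ones concentration is constant on the simplex.

```latex
\mathrm{Gamma}(g \mid 1, 1) = \frac{1^{1}}{\Gamma(1)}\, g^{1-1} e^{-1 \cdot g} = e^{-g} = \mathrm{Exponential}(g \mid 1),
\qquad
\mathrm{Dirichlet}(\pi \mid \mathbf{1}_K) = \frac{\Gamma(K)}{\Gamma(1)^{K}} \prod_{k=1}^{K} \pi_k^{0} = (K-1)!
```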

Unfortunately this will lead to a non-identified posterior given the over-parameterization. In general this approach works poorly with MCMC.

data {
  int<lower=2> K;
}
parameters {
  simplex[K] pi;
  vector<lower=0>[K] gamma;
}
model {
  target += exponential_lpdf(gamma | 1);
}
generated quantities {
  vector[K] pi_rescaled = gamma / sum(gamma);
}

yields

Inference for Stan model: simplexes.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

                 mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
pi[1]            0.19    0.00 0.17   0.00   0.06   0.15   0.28   0.61  4000 1.00
pi[2]            0.20    0.00 0.16   0.01   0.07   0.16   0.30   0.60  4000 1.00
pi[3]            0.20    0.00 0.17   0.01   0.07   0.16   0.30   0.60  4000 1.00
pi[4]            0.20    0.00 0.16   0.01   0.07   0.16   0.30   0.61  4000 1.00
pi[5]            0.20    0.00 0.16   0.01   0.07   0.16   0.29   0.60  4000 1.00
gamma[1]         1.01    0.02 1.01   0.03   0.29   0.69   1.39   3.71  4000 1.00
gamma[2]         1.01    0.02 1.01   0.03   0.29   0.70   1.39   3.76  4000 1.00
gamma[3]         1.00    0.02 0.98   0.03   0.30   0.70   1.39   3.52  4000 1.00
gamma[4]         1.00    0.02 0.97   0.02   0.28   0.70   1.40   3.61  4000 1.00
gamma[5]         1.00    0.02 0.98   0.02   0.31   0.70   1.39   3.65  4000 1.00
pi_rescaled[1]   0.20    0.00 0.16   0.01   0.07   0.16   0.30   0.59  4000 1.00
pi_rescaled[2]   0.20    0.00 0.16   0.01   0.07   0.16   0.29   0.60  4000 1.00
pi_rescaled[3]   0.20    0.00 0.16   0.01   0.07   0.16   0.29   0.60  4000 1.00
pi_rescaled[4]   0.20    0.00 0.17   0.01   0.07   0.16   0.30   0.61  4000 1.00
pi_rescaled[5]   0.20    0.00 0.16   0.01   0.07   0.16   0.29   0.59  4000 1.00
lp__           -18.29    0.07 2.44 -24.10 -19.64 -17.92 -16.50 -14.66  1191 1.01
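
As a sanity check on these numbers, the exact prior summaries can be worked out by hand, assuming K = 5 as the five rows per parameter suggest: each margin of a Dirichlet(1, ..., 1) is Beta(1, K - 1), and each gamma[k] is Exponential(1).

```latex
\pi \sim \mathrm{Dirichlet}(\mathbf{1}_K) \implies \pi_k \sim \mathrm{Beta}(1, K-1), \quad
\mathbb{E}[\pi_k] = \frac{1}{K}, \quad
\mathrm{Var}(\pi_k) = \frac{K-1}{K^2 (K+1)}, \quad
\operatorname{median}(\pi_k) = 1 - 2^{-1/(K-1)}
```

With K = 5 this gives mean 0.2, sd sqrt(4/150) ≈ 0.163 and median ≈ 0.159 for both pi and pi_rescaled, and mean 1, sd 1, median ln 2 ≈ 0.693 for gamma, in line with the printed columns.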

Not bad for a non-identified posterior given the over-parameterization that works poorly with MCMC.


I've seen it mentioned on the forums, here and in other threads, that this kind of non-identifiability can be problematic, but in practice it seems to work sometimes: it is even used in the unit vector transformation, and it is covered in the manual as an example in the "Reparameterizing a Student-t Distribution" section.

However, it seems like it certainly simplifies the code, and potentially could flatten out the posterior geometry. What sort of problems should you look out for when doing this?

There's a quick discussion in the manual. Stan should fail to fit if there's unbounded non-identifiability in the model.

Usually there's some kind of weak identifiability through priors (including bounds), in which case your mileage will vary depending on the specifics.

Thanks @bob_carpenter. I guess I'm still confused, since some sections of the manual (23.1, "Mitigating the Invariances") seem to say that non-identifiabilities are bad and should be eliminated, others say that it is fine if the non-identified quantity is constrained (unit vector), and others seem to have no issue with it at all (reparameterizing a Student-t distribution).

I feel like there is some nuance between these different cases that I am missing…

@jeffeaton - if you have performance issues, you might be able to do something like this, where the log Jacobian and the simplex are calculated at the same time, but then you would have to split the result up in the transformed parameters block.

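  /**
   * Single-pass version: returns a vector of size rows(y) + 2 whose
   * first element is the log absolute Jacobian determinant and whose
   * remaining rows(y) + 1 elements are the simplex.
   */
  vector simplex_constrain_new(vector y) {
    real lj = 0;
    vector[rows(y)+1] x;
    int Km1 = rows(y);
    real stick_len = 1.0;

    for (k in 1:Km1) {
      real eq_share = -log(Km1 - k + 1);  // centering offset
      real adj_y_k = y[k] + eq_share;
      real z_k = inv_logit(adj_y_k);
      x[k] = stick_len * z_k;
      lj = lj + log(stick_len) - log1p_exp(-adj_y_k) - log1p_exp(adj_y_k);
      stick_len = stick_len - x[k];
    }
    x[Km1+1] = stick_len;

    return append_row(lj, x);  // [log Jacobian; simplex]
  }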

I think you may have misinterpreted my post as being literal.

All well and good if you want to sample directly from a Dirichlet, which is not the relevant task here. When sampling directly there's no problem because the constraint is applied only after the fact. When using this as a prior, however, you have to confront the non-identifiability. The Gamma priors will weakly identify the posterior to a neighborhood, but within that neighborhood the data will constrain the posterior onto a nonlinear surface that will eventually obstruct efficient sampling. Consequently, in the relevant task the softmax'd gamma parameterization will work poorly with MCMC. As I said.
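
A small Python sketch of why the gamma scale is not identified by the data (the counts and gamma values are hypothetical): the multinomial log-likelihood depends on gamma only through gamma / sum(gamma), so it is constant along every ray c * gamma with c > 0, and only the prior distinguishes points on that ray.

```python
import math


def multinomial_loglik(x, gamma):
    """Multinomial log-likelihood at pi = gamma / sum(gamma), dropping the
    multinomial coefficient, which does not depend on gamma."""
    total = sum(gamma)
    return sum(xk * math.log(gk / total) for xk, gk in zip(x, gamma))


def exponential_logprior(gamma):
    """Sum of Exponential(1) log densities, log(exp(-g)) = -g."""
    return -sum(gamma)


x = [10, 20, 30, 40]          # hypothetical counts
gamma = [0.4, 0.8, 1.2, 1.6]  # hypothetical point in gamma space
for c in (0.5, 1.0, 2.0, 10.0):
    scaled = [c * g for g in gamma]
    # the likelihood column stays the same; only the prior column moves
    print(c, multinomial_loglik(x, scaled), exponential_logprior(scaled))
```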

I didn't realize your relevant tasks required 64 bit integers. With multinomial data that Stan can actually ingest, the phenomenon you are referring to manifests itself in the leapfrogger taking more and smaller steps and yields imprecise estimates of the irrelevant vector gamma, but MCMC sampling for the simplex pi is fine when you do

library(rstan)
post <- stan("simplexes.stan", data = list(K = 4, x = 2 * 
             rmultinom(1, size = .Machine$integer.max, prob = 1:4 / 10)[,1]), 
             control = list(max_treedepth = 17L))

to get

> print(post, digits = 4)
Inference for Stan model: simplexes.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

             mean se_mean     sd     2.5%      25%      50%      75%    97.5% n_eff   Rhat
gamma[1]   0.4009  0.0183 0.2186   0.0888   0.2469   0.3643   0.5115   1.0050   143 1.0113
gamma[2]   0.8018  0.0366 0.4373   0.1776   0.4938   0.7287   1.0232   2.0101   143 1.0113
gamma[3]   1.2027  0.0549 0.6560   0.2664   0.7407   1.0931   1.5347   3.0151   143 1.0113
gamma[4]   1.6036  0.0732 0.8746   0.3552   0.9877   1.4574   2.0462   4.0202   143 1.0113
pi[1]      0.1000  0.0000 0.0000   0.1000   0.1000   0.1000   0.1000   0.1000  2658 0.9998
pi[2]      0.2000  0.0000 0.0000   0.2000   0.2000   0.2000   0.2000   0.2000  4000 0.9994
pi[3]      0.3000  0.0000 0.0000   0.3000   0.3000   0.3000   0.3000   0.3000  4000 1.0005
pi[4]      0.4000  0.0000 0.0000   0.4000   0.4000   0.4000   0.4000   0.4000  4000 0.9994
lp__     -39.5860  0.0715 1.4965 -43.3719 -40.3148 -39.2481 -38.4819 -37.7376   438 1.0062

for simplexes.stan defined as

data {
  int<lower=2> K;
  int<lower=0,upper=2147483647> x[K];
}
parameters {
  vector<lower=0>[K] gamma;
}
transformed parameters {
  vector[K] pi = gamma / sum(gamma);
}
model {
  target += exponential_lpdf(gamma | 1);
  target += multinomial_lpmf(x | pi);
}
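
One way to see why the gamma margins come out at roughly K * pi with sd roughly sqrt(K) * pi is the standard Gamma-Dirichlet decomposition. Under independent Exponential(1) priors, the total S = sum(gamma) and the simplex pi are independent a priori, and the multinomial likelihood involves only pi, so the posterior of S is just its Gamma(K, 1) prior:

```latex
\gamma_k \overset{\text{iid}}{\sim} \mathrm{Gamma}(1, 1), \quad
S = \sum_{k=1}^{K} \gamma_k, \quad \pi = \gamma / S
\;\implies\;
p(S, \pi \mid x) \propto \mathrm{Gamma}(S \mid K, 1)\,\mathrm{Dirichlet}(\pi \mid \mathbf{1}_K)\,\mathrm{Multinomial}(x \mid \pi),
\qquad
\gamma_k = S\,\pi_k, \quad \mathbb{E}[S \mid x] = K, \quad \mathrm{sd}(S \mid x) = \sqrt{K}
```

With K = 4 and pi essentially fixed at (0.1, 0.2, 0.3, 0.4), this predicts posterior means of about (0.4, 0.8, 1.2, 1.6) and sds of about (0.2, 0.4, 0.6, 0.8) for gamma; the printed means match closely and the printed sds are somewhat larger, which is plausible with n_eff around 143.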

In exchange for substantial increases in computation: 2^17 leapfrog steps is a huge computational burden that will make many analyses impractical. I think that's a pretty fair criterion for "not working well".

2^17 for over 4 billion observed counts that pin the posterior distribution of pi down to many decimal places. If you only have a couple hundred thousand counts, then you don't have to change max_treedepth to get the same essentially perfect estimates.

Thanks @bgoodri, it's really helpful to see when these things work and how they break down (i.e., statistically valid but slow, or introducing bias).

Out of curiosity, I coded up the identifiable model

data {
  int<lower=2> K;
  int<lower=0,upper=2147483647> x[K];
}
parameters {
  vector<lower=0>[K-1] gamma;
}
transformed parameters {
  real sum_gamma = 1 + sum(gamma);
  vector[K] pi = append_row(1,gamma) / sum_gamma;
}
model {
  target += -K *log(sum_gamma);
  target += dirichlet_lpdf(pi | rep_vector(1,K));
  target += multinomial_lpmf(x | pi);
}
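
The -K * log(sum_gamma) term is the log absolute Jacobian determinant of the map from the K-1 free gamma values to pi[2:K] (the dirichlet_lpdf(pi | rep_vector(1,K)) term is a constant). Here is a Python sketch that checks this numerically with central differences, using only the standard library; to_pi, det and numeric_log_jacobian are hypothetical helper names.

```python
import math


def to_pi(gamma):
    """pi = append_row(1, gamma) / (1 + sum(gamma)), as in the Stan program."""
    s = 1.0 + sum(gamma)
    return [1.0 / s] + [g / s for g in gamma]


def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]
    n = len(a)
    d = 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(a[r][c]))
        if a[p][c] == 0.0:
            return 0.0
        if p != c:
            a[c], a[p] = a[p], a[c]
            d = -d
        d *= a[c][c]
        for r in range(c + 1, n):
            f = a[r][c] / a[c][c]
            for j in range(c, n):
                a[r][j] -= f * a[c][j]
    return d


def numeric_log_jacobian(gamma, h=1e-6):
    """log |det d pi[2:K] / d gamma| from a central-difference Jacobian."""
    n = len(gamma)
    jac = [[0.0] * n for _ in range(n)]
    for j in range(n):
        up, dn = list(gamma), list(gamma)
        up[j] += h
        dn[j] -= h
        pu, pd = to_pi(up)[1:], to_pi(dn)[1:]
        for i in range(n):
            jac[i][j] = (pu[i] - pd[i]) / (2.0 * h)
    return math.log(abs(det(jac)))


gamma = [2.0, 3.0, 4.0]
K = len(gamma) + 1
print(numeric_log_jacobian(gamma), -K * math.log(1.0 + sum(gamma)))
```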

which has comparable performance:

        mean se_mean   sd   2.5%    25%    50%    75%  97.5% n_eff Rhat
gamma[1]    2.00    0.00 0.00   2.00   2.00   2.00   2.00   2.00  1234    1
gamma[2]    3.00    0.00 0.00   3.00   3.00   3.00   3.00   3.00  1378    1
gamma[3]    4.00    0.00 0.00   4.00   4.00   4.00   4.00   4.00  1350    1
sum_gamma  10.00    0.00 0.00  10.00  10.00  10.00  10.00  10.00  1171    1
pi[1]       0.10    0.00 0.00   0.10   0.10   0.10   0.10   0.10  1171    1
pi[2]       0.20    0.00 0.00   0.20   0.20   0.20   0.20   0.20  2797    1
pi[3]       0.30    0.00 0.00   0.30   0.30   0.30   0.30   0.30  4000    1
pi[4]       0.40    0.00 0.00   0.40   0.40   0.40   0.40   0.40  4000    1
lp__      -38.69    0.03 1.16 -41.81 -39.26 -38.39 -37.84 -37.36  1851    1

but is roughly 800 times faster (about 0.34 seconds versus 275 seconds of total sampling time):

sum(sapply(post.2@sim$samples, function(z) attr(z, 'elapsed_time')))
[1] 0.33933
sum(sapply(post@sim$samples, function(z) attr(z, 'elapsed_time')))
[1] 275.2609

I wonder if there would be situations where the non-identified (over-parameterized) solution would perform better…


Thanks Ben for the suggestion to sample from gamma distributions and transform to avoid fussing with the Jacobians.
Sorry, I should have acknowledged your suggestion of this approach in the previous thread I referenced from last year. I went for the option of implementing the simplex functions with explicit constraints because I thought the extra parameters might reduce efficiency, but I didn't do any comparisons.

At some point I'll try to implement both in my model and report how the performance compares (though at the moment I think my sampling issues are more because I'm trying to fit a model that isn't really well identified from the data than because of the implementation of the simplex…).

Thanks also Aaron for the clever suggestion of returning the log Jacobian and the transform together from the Stan function.

Thanks,
Jeff

The identified versions are generally better, but we obviously haven't tried every case.

The distinguishing feature is that true unbounded non-identifiability leads to an improper posterior, as in the example from the problematic posteriors chapter of the manual:

y ~ normal(alpha + beta, sigma);

with no priors on alpha or beta. This is just the simplest case of a non-identified regression: it's two intercepts. If you add a prior, the posterior becomes proper, but the likelihood is still non-identified. It's of course better in this case just to use

y ~ normal(gamma, sigma);
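
A tiny Python illustration of that non-identification (the data and sigma are hypothetical): every (alpha, beta) pair with the same alpha + beta gives exactly the same log-likelihood, so without priors the posterior is flat along that line.

```python
import math


def normal_loglik(y, mu, sigma):
    """Sum of normal log densities of y given location mu and scale sigma."""
    return sum(-0.5 * math.log(2.0 * math.pi) - math.log(sigma)
               - 0.5 * ((yi - mu) / sigma) ** 2 for yi in y)


y = [1.2, 0.7, 1.9, 1.4]  # hypothetical data
sigma = 0.5
for alpha in (-10.0, 0.0, 10.0, 1000.0):
    beta = 1.3 - alpha  # keep alpha + beta fixed at 1.3
    print(alpha, beta, normal_loglik(y, alpha + beta, sigma))
```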

I really like this approach of sampling from a normalized set of gamma random variables so that the lengths of the simplexes can vary. I'm not sure how to actually turn this into a ragged array, though. It seems like pi is just a single simplex in the code examples. Am I missing something?

If, for example, you need a simplex of size K1 and a simplex of size K2, declare a positive_ordered[K1 + K2] vector, separately normalize the first K1 and the last K2 elements, and put unit-scale gamma priors on things accordingly.

Actually, if you know that you need exactly two simplexes, you should just do

simplex[K1] pi_1;
simplex[K2] pi_2;

but this general idea works when the number of simplexes needed is not known until runtime.


Thanks for the input Ben. I'm trying to get as far as I can with my hierarchical model before attending StanCon next week so I can ask better questions when I get there :).

The number of simplexes is a function of the number of observations, and they'll number in the hundreds or thousands.

Can you give me another hint on how your proposal works? If I normalize the elements of the positive_ordered vector, won't it lose its ordering?

Also, is the prior this function?

target += exponential_lpdf(gamma | 1);

I'm not sure how to split this off from the likelihood. Would it be something like:

data {
   int<lower=1> O;  // count of observations where each observation contains 1 or more sub-observations
   int<lower=1> S;  // count of all sub-observations
   int<lower=2> K;  // maximum size of the simplex for all sub-observations
}

parameters {
   vector<lower=0>[K] gamma[O];  //randomly sampled values from gamma distribution for each observation
}

transformed parameters {
   vector[K] pi[O];  // one normalized vector per observation
   for (o in 1:O)
      pi[o] = gamma[o] / sum(gamma[o]);
}

model {
   for (o in 1:O) {
      pi[o] ~ exponential(1);
   }
   // use pi in the likelihood function
}

Am I close?