Where to find stick breaking transform and Jacobian code?

I am writing a model where I need to implement some custom simplex parameters (see here) but would rather not have to re-implement the stick breaking transform and its Jacobian. Where can I find those functions in the Stan codebase, and how would I call them from a Stan program? Thanks!

I found the stick breaking transform code here and here. But I could use help importing these Stan library functions into my Stan program.

I’m attaching two Stan programs that fit the following simple model:

Z = \sum_{i=1}^{k} X_i
\log(X_i) \sim \mathrm{Normal}(0, 1)

One version (simplex_k.stan) uses the built-in simplex type. The other version (manual_k.stan) implements the simplex manually, using the less efficient transform that @aaronjg proposed in this thread.
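
For readers without the attachments, here is a minimal sketch of how the built-in-simplex version could be written. It is a reconstruction from the model statement above, not the contents of simplex_k.stan; the variable names s and x are assumptions, while z and k match the data list passed to cmdstanr below.

```stan
data {
  int<lower=1> k;        // number of terms in the sum
  real<lower=0> z;       // observed total Z
}
parameters {
  simplex[k] s;          // shares of the total (assumed name)
}
transformed parameters {
  vector[k] x = z * s;   // X_i, constrained to sum to z
}
model {
  // log(X_i) ~ Normal(0, 1) is the same as X_i ~ lognormal(0, 1).
  // x is a linear function of s, so the Jacobian is constant
  // and no extra adjustment is needed.
  x ~ lognormal(0, 1);
}
```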

You can test the models like so:

simplex_k_model = cmdstanr::cmdstan_model("simplex_k.stan")

simplex_k_fit = simplex_k_model$sample(data=list(z=5,k=100),
                                  seed=123,
                                  chains=4,
                                  parallel_chains=4,
                                  iter_warmup=1000,
                                  iter_sampling=1000)

And similarly for manual_k. The inference is mathematically equivalent but I find that the simplex model (which uses the stick breaking transform under the hood) has about 5x more effective samples per second.

What I would like: another manual version that uses the stick breaking transform code in the stan codebase, so I can adapt it to situations where the built-in simplex type doesn’t work (e.g. ragged array of simplexes). Thank you!
simplex_k.stan (253 Bytes)
manual_k.stan (401 Bytes)

Hi, @potash and sorry this didn’t get answered before.

We (Sean Pinkney, Seth Axen, Nicholas Siccha, Meenal Jhajaria, and I) are about to wrap up our paper comparing different simplex transforms. If you want to do the simplex transform yourself from scratch, you can just implement it directly in Stan.

The description of the transforms we use is in the Reference Manual: Constraint Transforms

It looks like your manual program is using a different transform than the stick-breaking transform.

We are also about to expose jacobian += in the transformed parameters block and also expose all of our built-in transforms for just this kind of application. As is, I’d implement as a function like this:

vector to_simplex_lp(vector y) {
  N = rows(y);

  // stick-breaking ratios
  vector[N] z;
  for (n in 1:N) {
    z[n] = inv_logit(y[n] - log(N - n));
  }

  // construct simplex x by stick-breaking
  vector[N + 1] x;
  x[1] = z[1];
  x_sum = x[1];
  for (n in 2:N) {
    x[n] = (1 - x_sum) * z[n];
    x_sum += x[n];
  }
  x[N + 1] = 1 - sum(x[1:N]);

  // change of variables adjustment
  target += log(z) + log1m(z);
  for (n in 1:N) {
    target += log1m(sum(x[1:n - 1]));
  }

  return y;
}

I may have messed up some of the algebra, but the basic structure will work (though you can eliminate some of the redundant calculation at the risk of making it less readable).

Then you can use it in transformed parameters or in the model block like this:

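// in the parameters block
vector[N] y;
// in transformed parameters or the model block (simplex is a reserved word, so use another name)
vector[N + 1] x = to_simplex_lp(y);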

and the Jacobian will get added to the log density.


Thanks @Bob_Carpenter. I fixed a couple of typos in your code, and the following function works. I am attaching my simple model that uses it:

vector to_simplex_lp(vector y) {
  int N = rows(y);

  // stick-breaking ratios
  vector[N] z;
  for (n in 1:N) {
    z[n] = inv_logit(y[n] - log(N + 1 - n));
  }

  // construct simplex x by stick-breaking
  vector[N + 1] x;
  x[1] = z[1];
  real x_sum = x[1];
  for (n in 2:N) {
    x[n] = (1 - x_sum) * z[n];
    x_sum += x[n];
  }
  x[N + 1] = 1 - x_sum;

  // change of variables adjustment: log z + log(1 - z) + log(remaining stick)
  target += log(z) + log1m(z);
  for (n in 1:N) {
    target += log1m(sum(x[1:n - 1]));
  }

  return x;
}
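
For anyone checking the algebra, the change of variables term can be written out as below. This is a sketch of the standard argument rather than anything taken from the Stan source: each x_n depends only on y_1, ..., y_n, so the Jacobian is lower triangular and its determinant is the product of the diagonal entries.

```latex
z_n = \operatorname{logit}^{-1}\bigl(y_n - \log(N + 1 - n)\bigr), \qquad
x_n = \Bigl(1 - \sum_{m=1}^{n-1} x_m\Bigr) z_n, \qquad n = 1, \dots, N

x_{N+1} = 1 - \sum_{m=1}^{N} x_m

\frac{\partial x_n}{\partial y_n} = \Bigl(1 - \sum_{m=1}^{n-1} x_m\Bigr) z_n (1 - z_n)

\log \lvert \det J \rvert = \sum_{n=1}^{N} \Bigl[ \log z_n + \log(1 - z_n) + \log\Bigl(1 - \sum_{m=1}^{n-1} x_m\Bigr) \Bigr]
```

The first two terms correspond to target += log(z) + log1m(z), and the third to the loop over log1m(sum(x[1:n - 1])).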

FWIW, in my demo it runs a little (~20%) slower than the built-in simplex type, but the inference looks good and the effective samples per iteration are the same.
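
Since the original motivation was a ragged array of simplexes, here is a minimal sketch of how the function might be applied to that case. The data names G and K, the flat storage layout, and the empty model block are all assumptions for illustration. This version also tracks the remaining stick length in a running variable instead of recomputing sum(x[1:n - 1]) on every iteration, which should give the same Jacobian with less work.

```stan
functions {
  // Stick-breaking transform: maps N unconstrained values to an
  // (N + 1)-simplex and adds the log Jacobian to target.
  vector to_simplex_lp(vector y) {
    int N = rows(y);
    vector[N + 1] x;
    real remaining = 1;  // stick length not yet assigned
    for (n in 1:N) {
      real z = inv_logit(y[n] - log(N + 1 - n));
      x[n] = remaining * z;
      target += log(z) + log1m(z) + log(remaining);
      remaining -= x[n];
    }
    x[N + 1] = remaining;
    return x;
  }
}
data {
  int<lower=1> G;              // number of simplexes (hypothetical name)
  array[G] int<lower=2> K;     // size of each simplex (hypothetical name)
}
transformed data {
  int total = sum(K);
}
parameters {
  vector[total - G] y;         // unconstrained values, concatenated
}
transformed parameters {
  vector[total] x;             // all simplexes, concatenated
  {
    int y_pos = 1;
    int x_pos = 1;
    for (g in 1:G) {
      // simplex g uses K[g] - 1 unconstrained values
      x[x_pos:(x_pos + K[g] - 1)]
          = to_simplex_lp(segment(y, y_pos, K[g] - 1));
      y_pos += K[g] - 1;
      x_pos += K[g];
    }
  }
}
model {
  // priors and likelihood on the segments of x go here
}
```

Simplex g can then be read back with segment(x, start, K[g]), where start is the running offset used in the loop.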

Looking forward to your paper comparing simplex transforms and the updates you mentioned to Stan.

manual_k_stickbreak.stan (814 Bytes)


Thanks for fixing and following up. Yes, the straight-up C++ is a little more efficient than doing everything with autodiff in Stan, but the answer should be the same.