Some questions on the simplex transform

I was looking at implementing the simplex type in a model, rather than the non-identified softmax transform. I had some questions about the simplex transform and how it scales to large simplexes.

  • The source code refers to a “centered stick breaking process.” Is this the same as the regular stick-breaking process referred to in the manual?
  • Why is the transform done in the real scale rather than the log scale? It seems like log-scale might be more stable, especially if you are interested in log(p) in the end anyway. But maybe I’m overlooking something?
  • Given the way the simplex is constructed, it seems like the ordering of the variables might be important for numerical stability/efficiency. Is it?
  • It looks like a simplex of size k would have the derivative propagate through k-1 nodes, but if it were constructed through recursive splitting it would only have to propagate through 2*log(k) nodes, which might be more numerically stable for large simplexes.

Is my intuition correct here, or are these wrong or unimportant?

Yes, it’s the stick-breaking process in the manual, but I implemented it in a centered way so that (0, 0, …, 0) on the unconstrained scale leads to a symmetric simplex (i.e., representing a uniform distribution).
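For anyone following along, here is a minimal Python sketch of what "centered" means here. It is not the Stan source: the function names are made up for illustration, and the offset of -log(K - k) on the k-th break is taken from the transform as described in the Stan reference manual.

```python
import math

def inv_logit(u):
    # Numerically stable logistic function.
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def centered_stick_breaking(y):
    """Map K-1 unconstrained reals to a point on the K-simplex.

    Illustrative sketch only (hypothetical name, not the Stan API).
    """
    K = len(y) + 1
    x = []
    remaining = 1.0  # length of stick not yet allocated
    for k, y_k in enumerate(y):  # k = 0, ..., K-2
        # The offset -log(K - k - 1) (zero-based k) centers the transform:
        # with y_k = 0 the break takes a 1/(K - k) share of what is left.
        z = inv_logit(y_k - math.log(K - k - 1))
        x.append(remaining * z)
        remaining -= x[-1]
    x.append(remaining)  # last component gets whatever is left
    return x

uniform = centered_stick_breaking([0.0] * 4)  # every entry is close to 1/5
```

Without the offset, y = 0 would give (1/2, 1/4, 1/8, …), which is why the uncentered version is so lopsided for large K.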

All of our transforms map to the natural constrained scales. It would be possible to provide a log_simplex data structure either as a built-in or as a manual addition.
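A manual version could look roughly like the following. This is only a sketch under my own assumptions: the function names are hypothetical, nothing like this is built in, and it reuses the centered stick-breaking offsets while accumulating everything in log space so that log(p) never passes through a value that has underflowed to zero.

```python
import math

def log_inv_logit(u):
    # Stable log(1 / (1 + exp(-u))).
    if u >= 0:
        return -math.log1p(math.exp(-u))
    return u - math.log1p(math.exp(u))

def log_stick_breaking(y):
    """Return log(x) for the centered stick-breaking simplex x.

    Hypothetical helper for illustration, not an existing Stan function.
    """
    K = len(y) + 1
    log_x = []
    log_remaining = 0.0  # log of the unallocated stick length
    for k, y_k in enumerate(y):
        u = y_k - math.log(K - k - 1)
        log_x.append(log_remaining + log_inv_logit(u))
        # log(1 - inv_logit(u)) equals log_inv_logit(-u).
        log_remaining += log_inv_logit(-u)
    log_x.append(log_remaining)
    return log_x

# A component whose probability underflows on the natural scale
# still has a finite log value here.
tiny = log_stick_breaking([-800.0, 0.0])
```

The Jacobian adjustment would need to be worked out separately for the log-scale target; this only shows the forward map.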

Yes, the ordering can matter when you have K in the thousands. There was a thread about this on Discourse but I can’t find it.

You can work with completely unconstrained values (log odds) directly through something like our categorical_logit distribution. Why would you want to go to just the log scale?

Splitting in a binary rather than linear fashion sounds like it could be promising if the Jacobians are more stable. They could still be arranged so that everything’s lower-triangular and the determinant remains tractable. Usually underflow’s not an issue, though, if that’s what you’re worrying about.
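To make the binary-splitting idea concrete, here is a rough Python sketch. It is purely hypothetical (nothing like this exists in Stan as far as this thread establishes), and the centering offset log(n_left / n_right) is my own choice so that all-zero inputs still give the uniform simplex.

```python
import math

def inv_logit(u):
    # Numerically stable logistic function.
    if u >= 0:
        return 1.0 / (1.0 + math.exp(-u))
    e = math.exp(u)
    return e / (1.0 + e)

def tree_simplex(y):
    """Map K-1 unconstrained reals to a K-simplex by recursive splitting.

    Hypothetical sketch: each internal node of a balanced binary tree
    uses one parameter to divide its mass between its two children.
    """
    K = len(y) + 1
    x = [0.0] * K
    params = iter(y)  # a tree with K leaves has exactly K-1 internal nodes

    def split(lo, hi, mass):
        n = hi - lo
        if n == 1:
            x[lo] = mass
            return
        n_left = n // 2
        n_right = n - n_left
        # Offset centers the split: with parameter 0 the left branch
        # receives n_left / n of the mass.
        frac = inv_logit(next(params) + math.log(n_left / n_right))
        mid = lo + n_left
        split(lo, mid, mass * frac)
        split(mid, hi, mass * (1.0 - frac))

    split(0, K, 1.0)
    return x

uniform = tree_simplex([0.0] * 6)  # every entry is close to 1/7
```

Each component is a product of about log2(K) factors rather than up to K-1, which is the property the question was after. Whether the Jacobian is actually better behaved would have to be checked.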

It would be easy enough to try. The code for the transforms is all very modular in the implementations (look for simplex_constrain and simplex_unconstrain functions).

I wasn’t aware of categorical_logit, and I was looking for its equivalent for the multinomial.

Though, categorical_logit seems to operate on the unconstrained k scale rather than the k-1 scale.
CategoricalLogit(y|β) = Categorical(y|softmax(β)) - which is valuable for the specification of my model now, but less so when I move to the k-1 parameterization to resolve the identifiability issues.
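To spell out the identifiability point in code (a small Python sketch with made-up helper names, not Stan functions): softmax gives the same probabilities when a constant is added to every element of β, so the k-vector has one redundant degree of freedom. Pinning one element to zero is one common way to get a k-1 parameterization.

```python
import math

def softmax(beta):
    # Subtract the max before exponentiating to avoid overflow.
    m = max(beta)
    e = [math.exp(b - m) for b in beta]
    s = sum(e)
    return [v / s for v in e]

def softmax_pinned(beta_free):
    """k-1 free parameters; the last logit is fixed at 0 (illustrative)."""
    return softmax(list(beta_free) + [0.0])

p1 = softmax([1.0, 2.0, 3.0])
p2 = softmax([11.0, 12.0, 13.0])  # shifted by 10: same probabilities as p1
p3 = softmax_pinned([1.0, 2.0])   # identified: no free shift remains
```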

I’ll keep the binary tree thing in mind if underflow becomes an issue. I presume this would start to show up as divergent transitions if it becomes an issue?

It actually might be interesting to allow arbitrary trees in the simplex construction, and then the unconstrained parameters could be meaningful. This sort of stuff comes up as “balances” in compositional data in geology and biology.

Yes, categorical_logit takes K-vectors, so it’s not identified. It’d be easy enough to add a K - 1 version.

We do run into problems out in the 1000s for simplexes and this balanced tree construction should be more effective there at making things symmetric. The trick in implementing it would be getting all the boundary conditions right in the recursion.

What sort of problems do you run into? I’m using a simplex with ~500-2000 parameters and want to keep an eye out for it. I am getting more divergent transitions than when I was using the weakly identifiable model but I’m not sure if that’s due to the parameter transformation or some other related change… It would be good to know what specifically to keep an eye out for - and if it turns out to be a problem, I can look into coding up the balanced tree type transformation.

See: https://github.com/stan-dev/stan/issues/2273

Thanks! With a working test case as well. It looks like Jussi was running into that with optimizing. Does this come up with HMC as well?

I haven’t checked.

Thanks. I’m running it now - it’s a slow test case - and will post back here when it finishes.

Gradient evaluation took 0.66874 seconds
1000 transitions using 10 leapfrog steps per transition would take 6687.4 seconds.

I didn’t see any major issue running that model with HMC. The treedepth was very high at first but settled down after about 100 samples. There were no major changes between the beginning and end of the simplex in terms of estimated masses, Rhat, or parameter bias.

Good news and glad that you’re testing. I wonder why optimization is less robust. It may have something to do with the way L-BFGS approximates the Hessian interacting poorly with the simplex transform or with adaptation not being fast enough (I’d think it’d be the opposite and the problem would be too-fast adaptation).