Ragged array of simplexes

I meant to say just vector<lower=0>[K1 + K2]; in general, the elements do not need to be ordered. The exponential prior corresponds to a gamma prior with a shape of 1 and a scale of 1, which implies that the resulting (normalized) simplex vector is uniformly distributed. You can use other shapes, but the prior goes on the thing you declare in parameters, not on the simplexes you construct in the transformed parameters or model block.
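To make the gamma-to-Dirichlet argument concrete, here is a small Monte Carlo check written in Python rather than Stan (dirichlet_via_gammas and mc_mean are hypothetical names chosen for illustration). It normalizes independent Gamma(\alpha_k, 1) draws, which is what dividing zeta by its sum does inside the model, and compares the component means with the Dirichlet(\alpha) means \alpha_k / \sum_j \alpha_j. It only checks first moments, so treat it as a sanity check rather than a proof.

```python
import random

def dirichlet_via_gammas(alpha, rng):
    """One simplex: normalize independent Gamma(alpha_k, scale=1) draws."""
    # random.gammavariate(a, b) takes a shape a and a *scale* b; b = 1.0 here.
    zeta = [rng.gammavariate(a, 1.0) for a in alpha]  # plays the role of the declared parameter
    total = sum(zeta)
    return [z / total for z in zeta]                  # plays the role of pi = zeta / sum(zeta)

def mc_mean(alpha, draws=20000, seed=1):
    """Monte Carlo estimate of E[pi] for the normalized-gamma construction."""
    rng = random.Random(seed)
    sums = [0.0] * len(alpha)
    for _ in range(draws):
        for k, p in enumerate(dirichlet_via_gammas(alpha, rng)):
            sums[k] += p
    return [s / draws for s in sums]

# Shape 1 everywhere (the exponential prior): uniform simplex, each mean near 1/3.
print(mc_mean([1.0, 1.0, 1.0]))
# Other shapes: Dirichlet(2, 5, 3) has means 0.2, 0.5, 0.3.
print(mc_mean([2.0, 5.0, 3.0]))
```

Changing the shapes changes the implied Dirichlet concentration, while the prior statement itself stays on the positive vector.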

Your example would not parse. There is a more complicated example at


where zeta is what you call gamma, and I am using the segment function to extract the nc elements of zeta starting at position zeta_mark.
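The ragged bookkeeping can be sketched in Python (an illustration, not Stan; segment below is a hypothetical stand-in that mimics Stan's 1-based segment(v, start, length)). All positive values live in one packed vector, each group's start position advances by the previous group's size, and each extracted segment is normalized into its own simplex.

```python
def segment(v, start, length):
    """Mimic Stan's segment(v, start, length): 1-based start, `length` elements."""
    return v[start - 1:start - 1 + length]

def ragged_simplexes(zeta, counts):
    """Split a packed positive vector into one normalized simplex per group."""
    if sum(counts) != len(zeta):
        raise ValueError("group sizes must add up to len(zeta)")
    simplexes = []
    zeta_mark = 1                 # 1-based position of the current group's first element
    for nc in counts:
        pi = segment(zeta, zeta_mark, nc)
        total = sum(pi)
        simplexes.append([p / total for p in pi])
        zeta_mark += nc           # the next group starts right after this one
    return simplexes

zeta = [1.0, 3.0, 2.0, 2.0, 4.0]       # K1 + K2 = 2 + 3 positive values
print(ragged_simplexes(zeta, [2, 3]))  # [[0.25, 0.75], [0.25, 0.25, 0.5]]
```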


Thanks again Ben! So I think I've managed to incorporate your feedback into a model (that compiles). I'm pasting it below for anybody else who has the same question.

My remaining question is how to integrate the shape parameter like you've mentioned. As you point out, this example code is using a uniform distribution, but I really need to modulate the shape. Do I need an additional variable in transformed parameters, or am I supposed to do some kind of initialization of zeta?

data {
   int<lower=1> M;               // num observations
   int<lower=1> N;                // number of sub-observations for all observations
   real<lower=0> y[N];               //Observations where each row is a subobservation
   int<lower=1,upper=M> observation[N];    //observation id for subobservation n
   int<lower=1> row_to_subobservation_map[N];      //subobservation index for each row
   int<lower=1> count_of_subobservations_per_observation[M];
}

parameters {
   vector<lower=0,upper=1.0>[N] zeta;      // subobservation dist over observation
} 

model {
   
   for (n in 1:N) {
      
      int zeta_mark = n-row_to_subobservation_map[n]+1;
      int nc = count_of_subobservations_per_observation[observation[n]];
      
      vector[nc] pi = segment(zeta,zeta_mark,nc);  // gamma(zeta | shape, 1)
      pi = pi / sum(pi);                           // thus dirichlet(pi | shape)

      target += log(y[n]*pi[row_to_subobservation_map[n]]); // likelihood

  }
}

zeta should not have an upper bound of 1. You just need to do

target += gamma_lpdf(zeta | shape, 1);

where shape is a positive vector that you pass in as data.

I'm still struggling here. Do I really want to increment target with both pi and zeta? For some reason that seems off to me, but I'm pretty new at this, so it's quite possible.

When I use (a multivariate version of) the model below, I get divergent-transition warnings. Is this the model you're proposing? If so, I'll look into issues with incorporating my additional feature variables; otherwise I'll try to fix my model.

data {
   int<lower=1> M;               // num observations
   int<lower=1> N;                // number of sub-observations for all observations
   real<lower=0> y[N];               //Observations where each row is a subobservation
   int<lower=1,upper=M> observation[N];    //observation id for subobservation n
   int<lower=1> row_to_subobservation_map[N];      //subobservation index for each row
   int<lower=1> count_of_subobservations_per_observation[M];
   real<lower=0> shape;     // subobservation weight prior (a gamma shape must be positive)

}

transformed data {
   // symmetric Dirichlet prior: the same shape for every element of zeta
   vector<lower=0>[N] shape_vec = rep_vector(shape, N);
}

parameters {
   vector<lower=0>[N] zeta;      // subobservation dist over observation
} 

transformed parameters {

}

model {
   
   for (n in 1:N) {
      
      int zeta_mark = n-row_to_subobservation_map[n]+1;
      int nc = count_of_subobservations_per_observation[observation[n]];
      
      vector[nc] pi = segment(zeta,zeta_mark,nc);  // gamma(zeta | shape, 1)
      pi = pi / sum(pi);                           // thus dirichlet(pi | shape)

      target += log(y[n]*pi[row_to_subobservation_map[n]]); // likelihood
      
  }
  target += gamma_lpdf(zeta | shape, 1.);
}

No, just by zeta.

So where does pi fit in if I'm only inferring zeta? How does this enforce the simplex constraint?

The line pi = pi / sum(pi); ensures that pi sums to 1, and pi was already nonnegative because zeta is.


Bumping this old thread in hopes that someone can help me clarify one specific point. I would like to use @aaronjg's identifiable formulation below, but I don't entirely understand it.

data {
 int<lower=2> K;
 int<lower=0,upper=2147483647> x[K];
}
parameters {
 vector<lower=0>[K-1] gamma;
}
transformed parameters {
 real sum_gamma = 1 + sum(gamma);
 vector[K] pi = append_row(1,gamma) / sum_gamma;
}
model {
 target += -K *log(sum_gamma);
 target += dirichlet_lpdf(pi | rep_vector(1,K));
 target += multinomial_lpmf(x | pi);
}

Specifically, I gather that -K *log(sum_gamma) is the log Jacobian of the transform from gamma to pi, but I'm not sure why that's the case. I've tried to work it out, but TBH I get a bit lost already with formulating an invertible mapping from K-1 dimensions to K. Clearly the simplex pi is effectively only (K-1)-dimensional, but if you drop, say, pi[1] it seems like you would still need to include sum_gamma in order to solve for gamma, and then you're back to K dimensions. Regardless of how I set it up, the total derivative winds up with a bunch of 1/sum_gamma^2 terms, and I don't see how it simplifies to the expression in the code.

Any hints would be greatly appreciated.

Think about the gamma vector as K-dimensional: there are the K-1 declared gammas and then the 1 that is implicitly there as well. The function from \mathbb{R}^K to \mathbb{R}^K just divides each component of this vector by sum_gamma, so the Jacobian is diag(1/sum_gamma), a K \times K matrix. The determinant of the Jacobian is therefore (1/sum_gamma)^K, and the log-determinant is -K * log(sum_gamma).
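A numerical sanity check of that log-determinant, in Python with central finite differences (transform, jacobian_fd and det are hypothetical helpers, not part of Stan). Because \pi_1 = 1/sum_gamma is fixed once the other elements are known, the check treats the K-1 free gammas as inputs and \pi_2, ..., \pi_K as outputs, which keeps the map square; this choice of coordinates is an assumption of the sketch.

```python
import math

def transform(gamma):
    """pi_2..pi_K as functions of the K-1 free gammas (pi_1 = 1/sum_gamma is implied)."""
    s = 1.0 + sum(gamma)
    return [g / s for g in gamma]

def jacobian_fd(f, x, h=1e-6):
    """Central finite-difference Jacobian, J[i][j] = d f_i / d x_j."""
    n = len(x)
    cols = []
    for j in range(n):
        xp = list(x)
        xm = list(x)
        xp[j] += h
        xm[j] -= h
        cols.append([(a - b) / (2 * h) for a, b in zip(f(xp), f(xm))])
    return [[cols[j][i] for j in range(n)] for i in range(n)]

def det(m):
    """Determinant by Gaussian elimination with partial pivoting."""
    a = [row[:] for row in m]
    n = len(a)
    d = 1.0
    for c in range(n):
        p = max(range(c, n), key=lambda r: abs(a[r][c]))
        if a[p][c] == 0.0:
            return 0.0
        if p != c:
            a[c], a[p] = a[p], a[c]
            d = -d
        d *= a[c][c]
        for r in range(c + 1, n):
            factor = a[r][c] / a[c][c]
            for k in range(c, n):
                a[r][k] -= factor * a[c][k]
    return d

gamma = [0.5, 2.0, 1.5]           # K = 4, so K - 1 = 3 free parameters
K = len(gamma) + 1
S = 1.0 + sum(gamma)              # sum_gamma
log_det = math.log(abs(det(jacobian_fd(transform, gamma))))
print(log_det, -K * math.log(S))  # the two numbers should agree to many digits
```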

Thanks for the reply. That \mathbb{R}^K \to \mathbb{R}^K mapping makes sense, and it's invertible because one element of the simplex \pi is the reciprocal of the sum \Gamma := \sum_{i=1}^K \gamma_{i}. I'm still not sure I understand the Jacobian; for example,

J_{11} = \frac{\partial \pi_{1}}{\partial \gamma_{1}} = \frac{\partial}{\partial \gamma_{1}} \left[\frac{\gamma_{1}}{\Gamma}\right] = \frac{1 \cdot \Gamma - \gamma_{1} \cdot 1}{\Gamma^2} = \frac{1}{\Gamma} - \frac{\gamma_{1}}{\Gamma^2}.

Maybe I just need remedial calculus?
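One way to connect that entrywise derivative to the code, as a derivation sketch using the code's indexing rather than the K-dimensional indexing above: let \gamma_1, \dots, \gamma_{K-1} be the free parameters, \Gamma = 1 + \sum_{k=1}^{K-1} \gamma_k, \pi_1 = 1/\Gamma and \pi_{j+1} = \gamma_j / \Gamma. Dropping \pi_1, which is determined by the others, leaves a square (K-1) \times (K-1) Jacobian. Its diagonal entries are exactly the 1/\Gamma - \gamma_j/\Gamma^2 computed above, and its off-diagonal entries are -\gamma_j/\Gamma^2, so the matrix is not diagonal; the matrix determinant lemma, \det(I + u v^\top) = 1 + v^\top u, collapses it anyway:

```latex
J_{jk} = \frac{\partial \pi_{j+1}}{\partial \gamma_k}
       = \frac{\delta_{jk}}{\Gamma} - \frac{\gamma_j}{\Gamma^2},
\qquad
J = \frac{1}{\Gamma}\left(I - \frac{\gamma\,\mathbf{1}^\top}{\Gamma}\right)

\det J = \Gamma^{-(K-1)} \left(1 - \frac{\mathbf{1}^\top \gamma}{\Gamma}\right)
       = \Gamma^{-(K-1)} \cdot \frac{\Gamma - (\Gamma - 1)}{\Gamma}
       = \Gamma^{-K}

\log \lvert \det J \rvert = -K \log \Gamma
```

This is the target += -K *log(sum_gamma) term in the code: the 1/\Gamma^2 terms do not vanish entry by entry, they cancel in the determinant through 1 - (\Gamma - 1)/\Gamma = 1/\Gamma.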

I've been using a different approach based on simplex amalgamation. If K_{0} is the dimension of the largest simplex you'll need, you can declare an array of length-K_{0} simplexes and form a smaller one, say of length K, by amalgamating the last (or first) K_{0} - K + 1 elements. An appropriate Dirichlet prior ensures that the amalgamation is simplex-uniform, which is what I want. Here's a toy example:

data {
  int<lower=2> K0;          // length of original simplex
  int<lower=1,upper=K0> K;  // length of amalgamated simplex
  int<lower=0> x[K];        // multinomial counts
}
parameters {
  simplex[K0] gamma;
}
transformed parameters {
  vector<lower=0>[K] pi = append_row(head(gamma, K - 1), sum(tail(gamma, K0 - K + 1)));
}
model {
  // prior on gamma that implies pi ~ Dir(rep_vector(1,K))
  gamma ~ dirichlet(append_row(rep_vector(1, K - 1), rep_vector(1.0/(K0 - K + 1), K0 - K + 1)));

  // likelihood
  x ~ multinomial(pi);
}
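The claim that this prior makes pi uniform rests on the Dirichlet aggregation property: summing components of a Dirichlet vector gives a Dirichlet whose concentration for the merged component is the sum of the merged \alpha's. Here each of the K_{0} - K + 1 merged entries gets 1/(K_{0} - K + 1), so they sum to 1. Below is a Python sketch (not Stan; amalgamation_alpha, amalgamate and draw_dirichlet are hypothetical helpers) that applies the same amalgamation to \alpha and then checks the merged component by simulation against the Dirichlet(1, 1, 1) marginal, a Beta(1, 2) with mean 1/3 and variance 1/18.

```python
import random

def amalgamation_alpha(K0, K):
    """Concentration for gamma chosen so that pi ~ Dirichlet(1, ..., 1) after amalgamation."""
    m = K0 - K + 1                          # number of elements merged into pi[K]
    return [1.0] * (K - 1) + [1.0 / m] * m

def amalgamate(gamma, K):
    """pi = (gamma[1..K-1], sum of the remaining tail), mirroring the Stan code."""
    return gamma[:K - 1] + [sum(gamma[K - 1:])]

def draw_dirichlet(alpha, rng):
    """Dirichlet draw via normalized Gamma(alpha_k, scale=1) variates."""
    g = [rng.gammavariate(a, 1.0) for a in alpha]
    s = sum(g)
    return [x / s for x in g]

K0, K = 6, 3
alpha = amalgamation_alpha(K0, K)
print(amalgamate(alpha, K))   # merged concentration: [1.0, 1.0, 1.0]

rng = random.Random(7)
draws = [amalgamate(draw_dirichlet(alpha, rng), K) for _ in range(20000)]
last = [d[-1] for d in draws]  # the amalgamated component pi[K]
mean = sum(last) / len(last)
var = sum((x - mean) ** 2 for x in last) / (len(last) - 1)
print(mean, var)               # should be near 1/3 and 1/18 (about 0.0556)
```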

The amalgamated K_{0} - K + 1 elements of \gamma are nonidentified, but bounded because they're just sampled from the simplex prior. Seems to work fine in practice, and I find it easier to wrap my head around (YMMV).

@ebuhle: For the case of N simplexes with different dimensions, I would need an array of simplexes (each of length K_0)

parameters {
  simplex[K0] gamma[N];
}

and then each will amalgamate differently. Would this approach cause identifiability problems?

@Fabian_Crespo, that is actually my real use case, and it works fine. I need multiple simplices of each "reduced" size K, so I construct the corresponding Dirichlet \alpha parameters in transformed data rather than doing it on the fly.
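A sketch of that transformed-data bookkeeping, written in Python only to show the arithmetic (alpha_table and the sizes in Ks are hypothetical): one row of K_{0} concentrations per simplex, each row keeping 1 for the first K_n - 1 entries and splitting a total of 1 across the K_{0} - K_n + 1 entries that get amalgamated. In a Stan program the same rows would be filled once in transformed data and passed to dirichlet for each simplex.

```python
def alpha_table(K0, Ks):
    """One Dirichlet concentration row (length K0) per simplex with reduced size Ks[n]."""
    table = []
    for K in Ks:
        if not 1 <= K <= K0:
            raise ValueError("each reduced size must satisfy 1 <= K <= K0")
        m = K0 - K + 1                             # entries merged into the last element
        table.append([1.0] * (K - 1) + [1.0 / m] * m)
    return table

K0 = 5
Ks = [2, 5, 3]                                     # hypothetical reduced sizes, N = 3
for K, row in zip(Ks, alpha_table(K0, Ks)):
    merged = row[:K - 1] + [sum(row[K - 1:])]      # concentration after amalgamation
    print(K, row, merged)                          # merged should be K ones
```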

I implemented a model using Dirichlet amalgamations, but it is not mixing well: there are divergent transitions, and the Rhat values are far from 1. Here is the output of diagnostic_summary() from cmdstanr, followed by the posterior summary:

Warning: 1286 of 5000 (26.0%) transitions ended with a divergence.
See https://mc-stan.org/misc/warnings for details.

Warning: 3639 of 5000 (73.0%) transitions hit the maximum treedepth limit of 10.
See https://mc-stan.org/misc/warnings for details.

$num_divergent
[1]  113 1173    0    0

$num_max_treedepth
[1] 1137    2 1250 1250

$ebfmi
[1] 0.6848023 0.8351373 1.9718730 1.3153290
 variable                                 mean     median     sd    mad         q5        q95 rhat ess_bulk ess_tail
 lp__                               -422968.28 -422966.00  10.40   8.90 -422988.00 -422954.00 1.17       16       34
 hyper_parameters_exponential[1]          0.03       0.01   0.10   0.01       0.00       0.07 1.42        8       17
 hyper_parameters_exponential[2]          0.01       0.00   0.02   0.00       0.00       0.04 1.95        5       19
 hyper_parameters_exponential[3]          0.02       0.02   0.03   0.01       0.00       0.05 1.19       20       64
 hyper_parameters_exponential[4]          1.23       0.02  11.85   0.02       0.00       0.23 1.70       10       13
 gammas[1]                              112.15     135.45  63.79  73.16      15.00     212.39 1.73        6       33
 gammas[2]                              154.23     136.42  98.46 117.14      44.01     327.95 1.83        5       18
 gammas[3]                               90.35      67.99  62.56  57.54      25.12     173.28 1.61        7       38
 gammas[4]                               39.80      43.15  16.49  11.39       3.17      61.53 1.69       11       12
 theta                                    2.75       2.83   0.70   0.85       1.30       3.47 1.70        6       20
 pop_sizes_proportion[1]                  0.16       0.18   0.05   0.05       0.08       0.22 1.49        8      114
 pop_sizes_proportion[2]                  0.17       0.18   0.07   0.09       0.07       0.29 1.99        5       13
 pop_sizes_proportion[3]                  0.21       0.23   0.07   0.06       0.09       0.32 1.71        6       14
 pop_sizes_proportion[4]                  0.45       0.46   0.05   0.04       0.37       0.54 1.56       12       85
 population_simplexes[1,1]                0.20       0.20   0.01   0.02       0.17       0.22 1.41        8       31
 population_simplexes[2,1]                0.21       0.22   0.02   0.02       0.17       0.24 1.21       14      121
 population_simplexes[3,1]                0.23       0.23   0.02   0.03       0.21       0.26 1.45        8      123
 population_simplexes[4,1]                0.34       0.35   0.07   0.06       0.19       0.43 1.65       10       31
 population_simplexes[1,2]                0.04       0.04   0.01   0.01       0.01       0.06 1.55       13       50
 population_simplexes[2,2]                0.02       0.03   0.01   0.01       0.00       0.04 1.60        7       53
 population_simplexes[3,2]                0.00       0.00   0.00   0.00       0.00       0.01 1.68       10       32
 population_simplexes[4,2]                0.07       0.07   0.01   0.01       0.04       0.09 1.59        9       31
 population_simplexes[1,3]                0.06       0.05   0.01   0.01       0.04       0.08 1.46        7       54
 population_simplexes[2,3]                0.02       0.02   0.02   0.02       0.00       0.07 1.32        9       36
 population_simplexes[3,3]                0.07       0.07   0.01   0.01       0.05       0.09 1.33       14       63
 population_simplexes[4,3]                0.00       0.00   0.00   0.00       0.00       0.00 1.36       10       37
 population_simplexes[1,4]                0.01       0.01   0.01   0.01       0.00       0.02 1.52       11       51
 population_simplexes[2,4]                0.09       0.09   0.02   0.03       0.05       0.12 1.33       10       65
 population_simplexes[3,4]                0.00       0.00   0.00   0.00       0.00       0.01 1.22       14      123
 population_simplexes[4,4]                0.00       0.00   0.00   0.00       0.00       0.00 1.39        8      149
 population_simplexes[1,5]                0.12       0.12   0.01   0.01       0.10       0.13 1.44        8       63
 population_simplexes[2,5]                0.01       0.00   0.01   0.00       0.00       0.02 1.19       15       53
 population_simplexes[3,5]                0.01       0.01   0.00   0.00       0.00       0.01 1.60        9       68
 population_simplexes[4,5]                0.00       0.00   0.00   0.00       0.00       0.00 1.53       12       62
 population_simplexes[1,6]                0.58       0.58   0.01   0.02       0.57       0.60 1.34       10      111
 population_simplexes[2,6]                0.22       0.21   0.03   0.03       0.17       0.26 1.57        7       14
 population_simplexes[3,6]                0.04       0.04   0.01   0.02       0.02       0.06 1.55        7       85
 population_simplexes[4,6]                0.05       0.05   0.01   0.01       0.03       0.07 1.52       10       33
 population_simplexes[1,7]                0.00       0.00   0.00   0.00       0.00       0.00 1.49        7       11
 population_simplexes[2,7]                0.00       0.00   0.00   0.00       0.00       0.00 1.57       29       63
 population_simplexes[3,7]                0.01       0.01   0.01   0.01       0.00       0.02 1.57        7       30
 population_simplexes[4,7]                0.02       0.02   0.00   0.00       0.01       0.03 1.70       17       27
 population_simplexes[1,8]                0.00       0.00   0.00   0.00       0.00       0.00 1.83        5       13
 population_simplexes[2,8]                0.31       0.31   0.03   0.02       0.24       0.34 1.18       17       70
 population_simplexes[3,8]                0.16       0.16   0.01   0.01       0.14       0.18 1.22       13       83
 population_simplexes[4,8]                0.00       0.00   0.00   0.00       0.00       0.00 1.56        7       42
 population_simplexes[1,9]                0.00       0.00   0.00   0.00       0.00       0.00 1.57        9       45
 population_simplexes[2,9]                0.00       0.00   0.02   0.00       0.00       0.01 1.44       14       31
 population_simplexes[3,9]                0.00       0.00   0.00   0.00       0.00       0.01 1.19       21       84
 population_simplexes[4,9]                0.10       0.10   0.02   0.02       0.06       0.13 1.59       10       32
 population_simplexes[1,10]               0.00       0.00   0.00   0.00       0.00       0.00 1.73       10       38
 population_simplexes[2,10]               0.00       0.00   0.02   0.00       0.00       0.02 1.18       17      160
 population_simplexes[3,10]               0.37       0.37   0.02   0.03       0.33       0.40 1.30       10       70
 population_simplexes[4,10]               0.00       0.00   0.00   0.00       0.00       0.00 1.65       62       49
 population_simplexes[1,11]               0.00       0.00   0.00   0.00       0.00       0.00 1.34        9       47
 population_simplexes[2,11]               0.00       0.00   0.02   0.00       0.00       0.01 1.34       50      102
 population_simplexes[3,11]               0.00       0.00   0.00   0.00       0.00       0.01 1.49        7       36
 population_simplexes[4,11]               0.00       0.00   0.00   0.00       0.00       0.00 1.31       10      112
 population_simplexes[1,12]               0.00       0.00   0.00   0.00       0.00       0.00 1.27       22       39
 population_simplexes[2,12]               0.00       0.00   0.02   0.00       0.00       0.01 1.48        7       41
 population_simplexes[3,12]               0.01       0.00   0.03   0.00       0.00       0.06 1.36       17       83
 population_simplexes[4,12]               0.00       0.00   0.00   0.00       0.00       0.00 1.29       13      107
 population_simplexes[1,13]               0.00       0.00   0.00   0.00       0.00       0.00 1.31       14      179
 population_simplexes[2,13]               0.01       0.00   0.02   0.00       0.00       0.03 1.83        5       18
 population_simplexes[3,13]               0.00       0.00   0.01   0.00       0.00       0.01 1.56       21       36
 population_simplexes[4,13]               0.04       0.04   0.01   0.01       0.02       0.05 1.60       13       29
 population_simplexes[1,14]               0.00       0.00   0.00   0.00       0.00       0.00 1.28       48      108
 population_simplexes[2,14]               0.00       0.00   0.02   0.00       0.00       0.01 1.59       14       82
 population_simplexes[3,14]               0.00       0.00   0.01   0.00       0.00       0.00 1.39       15       33
 population_simplexes[4,14]               0.00       0.00   0.00   0.00       0.00       0.00 1.45       18       32
 population_simplexes[1,15]               0.00       0.00   0.00   0.00       0.00       0.00 1.19       15      111
 population_simplexes[2,15]               0.00       0.00   0.02   0.00       0.00       0.00 1.32       10       54
 population_simplexes[3,15]               0.00       0.00   0.01   0.00       0.00       0.01 1.89        5       16
 population_simplexes[4,15]               0.01       0.00   0.04   0.00       0.00       0.08 1.39        8      169
 population_simplexes[1,16]               0.00       0.00   0.00   0.00       0.00       0.00 1.53        7       50
 population_simplexes[2,16]               0.00       0.00   0.02   0.00       0.00       0.02 1.30       11      156
 population_simplexes[3,16]               0.00       0.00   0.01   0.00       0.00       0.00 1.52        7       30
 population_simplexes[4,16]               0.04       0.00   0.07   0.00       0.00       0.18 1.43        8       83
 population_simplexes[1,17]               0.00       0.00   0.00   0.00       0.00       0.00 1.20       14       20
 population_simplexes[2,17]               0.01       0.00   0.02   0.00       0.00       0.01 1.40        8      130
 population_simplexes[3,17]               0.00       0.00   0.01   0.00       0.00       0.01 1.40        9      106
 population_simplexes[4,17]               0.01       0.00   0.04   0.00       0.00       0.05 1.35        9       32
 population_simplexes[1,18]               0.00       0.00   0.00   0.00       0.00       0.00 1.20       14       63
 population_simplexes[2,18]               0.00       0.00   0.01   0.00       0.00       0.00 1.37       75       40
 population_simplexes[3,18]               0.00       0.00   0.02   0.00       0.00       0.02 1.11       29       76
 population_simplexes[4,18]               0.02       0.00   0.06   0.00       0.00       0.07 1.32       10       52
 population_simplexes[1,19]               0.00       0.00   0.00   0.00       0.00       0.00 1.25       12      109
 population_simplexes[2,19]               0.01       0.00   0.02   0.00       0.00       0.02 1.32       10      114
 population_simplexes[3,19]               0.00       0.00   0.01   0.00       0.00       0.00 1.77       26       95
 population_simplexes[4,19]               0.01       0.00   0.04   0.00       0.00       0.03 1.24       15       70
 population_simplexes[1,20]               0.00       0.00   0.00   0.00       0.00       0.00 1.35        9      109
 population_simplexes[2,20]               0.00       0.00   0.01   0.00       0.00       0.01 1.20       14       86
 population_simplexes[3,20]               0.01       0.00   0.02   0.00       0.00       0.04 1.35       15       74
 population_simplexes[4,20]               0.01       0.00   0.05   0.00       0.00       0.04 1.22       14      111
 population_simplexes[1,21]               0.00       0.00   0.00   0.00       0.00       0.00 1.74        8       63
 population_simplexes[2,21]               0.00       0.00   0.02   0.00       0.00       0.01 1.66       37       61
 population_simplexes[3,21]               0.01       0.00   0.02   0.00       0.00       0.03 1.98        8       19
 population_simplexes[4,21]               0.02       0.00   0.04   0.00       0.00       0.05 1.67        6       48
 population_simplexes[1,22]               0.00       0.00   0.00   0.00       0.00       0.00 1.49        8       42
 population_simplexes[2,22]               0.00       0.00   0.02   0.00       0.00       0.01 1.33       10       43
 population_simplexes[3,22]               0.00       0.00   0.01   0.00       0.00       0.00 1.23       13      128
 population_simplexes[4,22]               0.00       0.00   0.03   0.00       0.00       0.00 1.77        6       26
 population_simplexes[1,23]               0.00       0.00   0.00   0.00       0.00       0.00 1.17       24      156
 population_simplexes[2,23]               0.00       0.00   0.01   0.00       0.00       0.00 1.75        6       12
 population_simplexes[3,23]               0.00       0.00   0.01   0.00       0.00       0.01 1.48       33      134
 population_simplexes[4,23]               0.02       0.00   0.04   0.00       0.00       0.08 1.65        6       35
 population_simplexes[1,24]               0.00       0.00   0.00   0.00       0.00       0.00 1.45        9      104
 population_simplexes[2,24]               0.01       0.00   0.02   0.00       0.00       0.05 1.42       11       39
 population_simplexes[3,24]               0.00       0.00   0.01   0.00       0.00       0.00 1.21       43      113
 population_simplexes[4,24]               0.01       0.00   0.04   0.00       0.00       0.04 1.32       10      148
 population_simplexes[1,25]               0.00       0.00   0.00   0.00       0.00       0.00 1.47       15      109
 population_simplexes[2,25]               0.00       0.00   0.02   0.00       0.00       0.01 1.52       71      148
 population_simplexes[3,25]               0.00       0.00   0.01   0.00       0.00       0.01 1.76        6       46
 population_simplexes[4,25]               0.01       0.00   0.03   0.00       0.00       0.02 1.59       10       62
 population_simplexes[1,26]               0.00       0.00   0.00   0.00       0.00       0.00 1.24       13      157
 population_simplexes[2,26]               0.00       0.00   0.01   0.00       0.00       0.01 1.51        7       75
 population_simplexes[3,26]               0.01       0.00   0.02   0.00       0.00       0.03 1.33       10       34
 population_simplexes[4,26]               0.03       0.00   0.08   0.00       0.00       0.22 1.62        6       58
 population_simplexes[1,27]               0.00       0.00   0.00   0.00       0.00       0.00 1.66        6       27
 population_simplexes[2,27]               0.00       0.00   0.01   0.00       0.00       0.00 1.12       24       85
 population_simplexes[3,27]               0.00       0.00   0.01   0.00       0.00       0.00 1.54        7       47
 population_simplexes[4,27]               0.03       0.00   0.07   0.00       0.00       0.19 1.43        8       43
 population_simplexes[1,28]               0.00       0.00   0.00   0.00       0.00       0.00 1.34       10       31
 population_simplexes[2,28]               0.01       0.00   0.02   0.00       0.00       0.03 1.43        8       62
 population_simplexes[3,28]               0.00       0.00   0.02   0.00       0.00       0.02 1.47        8       89
 population_simplexes[4,28]               0.02       0.00   0.04   0.00       0.00       0.09 1.21       13       85
 population_simplexes[1,29]               0.00       0.00   0.00   0.00       0.00       0.00 1.50       12       43
 population_simplexes[2,29]               0.00       0.00   0.02   0.00       0.00       0.00 1.30       10       50
 population_simplexes[3,29]               0.00       0.00   0.02   0.00       0.00       0.03 1.30       41       55
 population_simplexes[4,29]               0.01       0.00   0.05   0.00       0.00       0.10 1.62       10       37
 population_simplexes[1,30]               0.00       0.00   0.00   0.00       0.00       0.00 1.30       15       21
 population_simplexes[2,30]               0.00       0.00   0.02   0.00       0.00       0.03 1.36        9       56
 population_simplexes[3,30]               0.00       0.00   0.01   0.00       0.00       0.00 1.56        7       31
 population_simplexes[4,30]               0.01       0.00   0.05   0.00       0.00       0.07 1.63       10       81
 population_simplexes[1,31]               0.00       0.00   0.00   0.00       0.00       0.00 1.59       42       71
 population_simplexes[2,31]               0.00       0.00   0.02   0.00       0.00       0.02 1.51       33       50
 population_simplexes[3,31]               0.00       0.00   0.02   0.00       0.00       0.02 1.34       10       66
 population_simplexes[4,31]               0.01       0.00   0.05   0.00       0.00       0.08 1.48       12       50
 population_simplexes[1,32]               0.00       0.00   0.00   0.00       0.00       0.00 1.49       28       46
 population_simplexes[2,32]               0.00       0.00   0.02   0.00       0.00       0.01 1.36        9       81
 population_simplexes[3,32]               0.00       0.00   0.01   0.00       0.00       0.00 1.23       16       95
 population_simplexes[4,32]               0.04       0.00   0.06   0.00       0.00       0.15 1.33       10       43
 population_simplexes[1,33]               0.00       0.00   0.00   0.00       0.00       0.00 1.34       17       45
 population_simplexes[2,33]               0.00       0.00   0.02   0.00       0.00       0.02 1.19       16       72
 population_simplexes[3,33]               0.00       0.00   0.02   0.00       0.00       0.02 1.62        9       25
 population_simplexes[4,33]               0.02       0.00   0.06   0.00       0.00       0.14 1.65        6       31
 population_simplexes[1,34]               0.00       0.00   0.00   0.00       0.00       0.00 1.23       13       39
 population_simplexes[2,34]               0.00       0.00   0.02   0.00       0.00       0.02 1.43       13       28
 population_simplexes[3,34]               0.01       0.00   0.02   0.00       0.00       0.04 1.24       12       51
 population_simplexes[4,34]               0.01       0.00   0.05   0.00       0.00       0.08 1.32       10       32
 population_simplexes[1,35]               0.00       0.00   0.00   0.00       0.00       0.00 1.50       11       32
 population_simplexes[2,35]               0.00       0.00   0.02   0.00       0.00       0.00 1.44        8       32
 population_simplexes[3,35]               0.00       0.00   0.01   0.00       0.00       0.00 1.27       12      181
 population_simplexes[4,35]               0.01       0.00   0.04   0.00       0.00       0.05 1.15       20      232
 population_simplexes[1,36]               0.00       0.00   0.00   0.00       0.00       0.00 1.34        9       48
 population_simplexes[2,36]               0.00       0.00   0.02   0.00       0.00       0.02 1.35        9       59
 population_simplexes[3,36]               0.01       0.00   0.03   0.00       0.00       0.04 1.16       20       86
 population_simplexes[4,36]               0.01       0.00   0.04   0.00       0.00       0.05 1.12       24       73
 population_simplexes[1,37]               0.00       0.00   0.00   0.00       0.00       0.00 1.17       24       55
 population_simplexes[2,37]               0.00       0.00   0.01   0.00       0.00       0.01 1.42       10       55
 population_simplexes[3,37]               0.01       0.00   0.02   0.00       0.00       0.03 1.32       10       56
 population_simplexes[4,37]               0.01       0.00   0.05   0.00       0.00       0.08 1.50       20      115
 population_simplexes[1,38]               0.00       0.00   0.00   0.00       0.00       0.00 1.53        7       69
 population_simplexes[2,38]               0.00       0.00   0.02   0.00       0.00       0.01 1.13       39       39
 population_simplexes[3,38]               0.00       0.00   0.01   0.00       0.00       0.01 1.14       22      201
 population_simplexes[4,38]               0.01       0.00   0.04   0.00       0.00       0.02 1.42        8       46
 population_simplexes[1,39]               0.00       0.00   0.00   0.00       0.00       0.00 1.47        7       34
..............................................................................................................................................
 pi[1]                                    0.20       0.20   0.01   0.02       0.17       0.22 1.41        8       31
 pi[2]                                    0.04       0.04   0.01   0.01       0.01       0.06 1.55       13       50
 pi[3]                                    0.06       0.05   0.01   0.01       0.04       0.08 1.46        7       54
 pi[4]                                    0.01       0.01   0.01   0.01       0.00       0.02 1.52       11       51
 pi[5]                                    0.12       0.12   0.01   0.01       0.10       0.13 1.44        8       63
 pi[6]                                    0.58       0.58   0.01   0.02       0.57       0.60 1.34       10      111
 pi[7]                                    0.00       0.00   0.00   0.00       0.00       0.01 1.22       21       88
 pi[8]                                    0.21       0.22   0.02   0.02       0.17       0.24 1.21       14      121
 pi[9]                                    0.02       0.03   0.01   0.01       0.00       0.04 1.60        7       53
 pi[10]                                   0.02       0.02   0.02   0.02       0.00       0.07 1.32        9       36
 pi[11]                                   0.09       0.09   0.02   0.03       0.05       0.12 1.33       10       65
 pi[12]                                   0.01       0.00   0.01   0.00       0.00       0.02 1.19       15       53
 pi[13]                                   0.22       0.21   0.03   0.03       0.17       0.26 1.57        7       14
 pi[14]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.57       29       63
 pi[15]                                   0.31       0.31   0.03   0.02       0.24       0.34 1.18       17       70
 pi[16]                                   0.11       0.11   0.09   0.11       0.02       0.30 1.43        8       68
 pi[17]                                   0.23       0.23   0.02   0.03       0.21       0.26 1.45        8      123
 pi[18]                                   0.00       0.00   0.00   0.00       0.00       0.01 1.68       10       32
 pi[19]                                   0.07       0.07   0.01   0.01       0.05       0.09 1.33       14       63
 pi[20]                                   0.00       0.00   0.00   0.00       0.00       0.01 1.22       14      123
 pi[21]                                   0.01       0.01   0.00   0.00       0.00       0.01 1.60        9       68
 pi[22]                                   0.04       0.04   0.01   0.02       0.02       0.06 1.55        7       85
 pi[23]                                   0.01       0.01   0.01   0.01       0.00       0.02 1.57        7       30
 pi[24]                                   0.16       0.16   0.01   0.01       0.14       0.18 1.22       13       83
 pi[25]                                   0.00       0.00   0.00   0.00       0.00       0.01 1.19       21       84
 pi[26]                                   0.37       0.37   0.02   0.03       0.33       0.40 1.30       10       70
 pi[27]                                   0.00       0.00   0.00   0.00       0.00       0.01 1.49        7       36
 pi[28]                                   0.10       0.12   0.06   0.08       0.03       0.18 1.58        6       75
 pi[29]                                   0.34       0.35   0.07   0.06       0.19       0.43 1.65       10       31
 pi[30]                                   0.07       0.07   0.01   0.01       0.04       0.09 1.59        9       31
 pi[31]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.36       10       37
 pi[32]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.39        8      149
 pi[33]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.53       12       62
 pi[34]                                   0.05       0.05   0.01   0.01       0.03       0.07 1.52       10       33
 pi[35]                                   0.02       0.02   0.00   0.00       0.01       0.03 1.70       17       27
 pi[36]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.56        7       42
 pi[37]                                   0.10       0.10   0.02   0.02       0.06       0.13 1.59       10       32
 pi[38]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.65       62       49
 pi[39]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.31       10      112
 pi[40]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.29       13      107
 pi[41]                                   0.04       0.04   0.01   0.01       0.02       0.05 1.60       13       29
 pi[42]                                   0.00       0.00   0.00   0.00       0.00       0.00 1.45       18       32
 pi[43]                                   0.38       0.36   0.12   0.11       0.21       0.65 1.67       10       32
 unscaled_gammas[1]                     641.95     655.12 285.21 247.43     149.94    1122.69 1.50        8       27
 unscaled_gammas[2]                     840.23     792.33 286.86 283.93     280.06    1264.88 1.61        7       17
 unscaled_gammas[3]                     407.35     353.80 187.43 213.89     142.71     641.60 1.54        7       21
 unscaled_gammas[4]                      87.19      93.34  35.03  25.26       6.97     132.16 1.72       10       12

My question: how do I solve this? Do I just need more iterations? (The likelihood is computationally expensive to evaluate.)

More iterations aren't sufficient: divergences indicate regions of parameter space that are difficult for HMC to explore. As a result, samples containing divergences can give biased parameter estimates, because the sampler systematically avoids regions of high curvature.

I'd try out the re-parametrization options listed above to start. If that doesn't work, you might want to dig into model diagnostics (e.g. pairs plots, parcoord plots) to see if you can spot any degeneracies there.

It's tough to say anything concrete without the actual model and diagnostic plots in front of me.

PS: I don't expect it to fix all your problems, but divergences seem to depend on the order of the values in the Dirichlet simplex (I have no clue why this would be the case, but it's documented e.g. here, and I've noticed this myself...)


Thanks for the answer.

For the following simpler model with only one simplex vector, I got no divergent transitions. The simplex vector simplex1 is used to build a positive ordered vector called coal_event_times with values from 0 to torigin (by taking the cumulative sums of the simplex vector). The log determinant of the Jacobian of this transformation is n*log(torigin), where n = total_sample_size - 1 is the number of coalescent times. In this case, no explicit Dirichlet prior is set on simplex1. A custom distribution is placed on the transformed coal_event_times in the model block.

parameters{
  real<lower=0.0> hyper_parameters_exponential;

  real<lower=0.01> gamma;
  real<lower=0.0> torigin;
  real<lower=0.0> theta;

  simplex[total_sample_size] simplex1;

}
transformed parameters {
  //declarations
  positive_ordered[total_sample_size-1] coal_event_times;
  coal_event_times =head(cumulative_sum(simplex1),total_sample_size-1) * torigin;
}
model{
  hyper_parameters_exponential ~ gamma(0.001,0.001);
  gamma ~ exponential(hyper_parameters_exponential);
  theta ~ exponential(1);


  // Jacobian of simplex1 -> coal_event_times: the map is triangular with
  // torigin on the diagonal, so log|J| = (total_sample_size - 1) * log(torigin)
  target += (total_sample_size - 1) * log(torigin);

  coal_event_times ~ structured_coalescent([gamma]', [torigin]',
                                           [torigin]',
                                           [1.0]', {total_sample_size}, K);
  // ...........
  // compute likelihood using coal_event_times and data
  // target += ....
}
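A quick numerical check of the Jacobian claim above, written as a Python/NumPy sketch rather than Stan (the helper names `coal_times_from_simplex` and `numeric_log_jacobian` are hypothetical, and torigin is held fixed, as in the adjustment above). Because coal_event_times[i] = torigin * (simplex1[1] + ... + simplex1[i]), the Jacobian with respect to the first K-1 (free) simplex coordinates is lower triangular with torigin on the diagonal, so log|J| = (K-1) log(torigin):

```python
import numpy as np


def coal_times_from_simplex(simplex, t_origin):
    """Mirror of head(cumulative_sum(simplex1), K - 1) * torigin."""
    k = simplex.size
    return np.cumsum(simplex)[: k - 1] * t_origin


def numeric_log_jacobian(free, t_origin, eps=1e-6):
    """Central-difference log|det J| of the map from the K-1 free simplex
    coordinates (the last one is 1 - sum(free)) to the K-1 times."""
    n = free.size

    def f(u):
        return coal_times_from_simplex(np.append(u, 1.0 - u.sum()), t_origin)

    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(free + step) - f(free - step)) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


s = np.array([0.1, 0.25, 0.3, 0.2, 0.15])  # K = 5, illustrative values
t_origin = 3.5
times = coal_times_from_simplex(s, t_origin)
print(times)  # strictly increasing, all below t_origin
print(numeric_log_jacobian(s[:-1], t_origin))  # should match the next line
print((s.size - 1) * np.log(t_origin))
```

The finite-difference value agrees with (K-1) log(torigin) up to numerical error, which is the constant the simpler model adds to `target`.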

But with the ragged-array-of-simplexes model, where I try to apply the same cumulative-sum transformation to every simplex vector, I got divergent transitions. For that model I explicitly set a Dirichlet prior on population_simplexes, following the amalgamation method:

parameters{
  real<lower=0.0> hyper_parameters_exponential[N];

  vector<lower=0.001>[N] gammas;

  
  real<lower=0.0> theta;
  real<lower=0.0>  torigin_oldest_population;
  
  simplex[N] pop_sizes_proportion;
  simplex[total_sample_size] population_simplexes[N];
  vector<lower=0, upper=1>[N-1] positions_torigins_in_intervals; 

}
transformed parameters {
  //declarations

  vector[total_sample_size-1] coal_event_times_physical_time;
  vector[total_sample_size-1] coal_event_times_model_time;
  vector<lower=0.0>[N] torigins_in_model_time;
  vector<lower=0.0>[N] torigins_in_physical_time;
  vector<lower=0>[sum(number_internal_nodes)+N] pi;

  torigins_in_model_time[N] = torigin_oldest_population;
  torigins_in_physical_time[N] = torigins_in_model_time[N]* pop_sizes_proportion[N];
 

   for(i in 1:N){
       int start_idx=1;
       int end_idx=1;
       int no_internals ;
        if (i>=2){
          start_idx = cum_sample_sizes[i-1]+1+cum_number_child_pops[i-1];
          
        }
        end_idx = cum_sample_sizes[i]+cum_number_child_pops[i];
  
       no_internals = number_internal_nodes[i];
   

       pi[start_idx:end_idx] = append_row(head(population_simplexes[i], no_internals), sum(tail(population_simplexes[i], total_sample_size - no_internals  )));
       
      
   }

   for (k in 1:(N-1)) {
      int j= N-k;
      int idx_parent_pop = indexes_father_populations[j];
      vector [number_internal_nodes[idx_parent_pop]+1] pi_father_pop;
      vector[pos_parent_child_pop_MRCA_in_parent_pop[j]] head_cum_sum;
      real lower_coal_time=0;
      real upper_coal_time;
      real unif;
      int end_idx=1;
      int no_internals = number_internal_nodes[idx_parent_pop-1];
 
      int start_idx=1;
        if (idx_parent_pop>=2){
          start_idx = cum_sample_sizes[idx_parent_pop-1]+1+cum_number_child_pops[idx_parent_pop-1];
        }
       end_idx = cum_sample_sizes[idx_parent_pop]+cum_number_child_pops[idx_parent_pop];


       pi_father_pop = pi[start_idx:end_idx];

      head_cum_sum= head(cumulative_sum(pi_father_pop), pos_parent_child_pop_MRCA_in_parent_pop[j]);
 
      

        lower_coal_time = 0;
     
   


      upper_coal_time = head_cum_sum[pos_parent_child_pop_MRCA_in_parent_pop[j]]; 
      unif= positions_torigins_in_intervals[j];

      torigins_in_physical_time[j] =  torigins_in_physical_time[idx_parent_pop] *unif*upper_coal_time ;

    
      torigins_in_model_time[j] = torigins_in_physical_time[j] / pop_sizes_proportion[j];
    
   }


   for (j in 1:N) {
        int start_idx=1;
        int end_idx=1;
        int start_coal_idx=1;
        int end_coal_idx=1;
        if (j>=2){
          start_idx =cum_sample_sizes[j-1]+1+cum_number_child_pops[j-1];
          start_coal_idx= cum_internal_nodes[j-1]+1;
        }
        end_idx = cum_sample_sizes[j]+cum_number_child_pops[j];

        end_coal_idx= cum_internal_nodes[j];
       
       coal_event_times_model_time[start_coal_idx:end_coal_idx] = head(cumulative_sum(pi[start_idx:end_idx]), number_internal_nodes[j]) *  torigins_in_model_time[j];
    
       coal_event_times_physical_time[start_coal_idx:end_coal_idx] = head(cumulative_sum(pi[start_idx:end_idx]), number_internal_nodes[j]) *  torigins_in_physical_time[j]  ;
     
    }
}
model{
  // ............
    
  hyper_parameters_exponential~gamma(rep_vector(0.001,N), rep_vector(0.001,N));
  gammas ~ exponential(hyper_parameters_exponential);


  theta ~ exponential(1);


  pop_sizes_proportion~ dirichlet(to_vector(sample_sizes));

  
  torigin_oldest_population~conditionalDensityTOrigin(gammas[N], sample_sizes[N]);

   for (j in 1:N) {
     population_simplexes[j]~ dirichlet(append_row(rep_vector(1, number_internal_nodes[j] ), 
                     rep_vector(1.0/(total_sample_size - number_internal_nodes[j]), total_sample_size  - number_internal_nodes[j])));
     
   }

  // Jacobian from the simplexes to the coalescent event times.
  // rep_vector_times is a user-defined function (not shown here).
  target += sum(log(rep_vector_times(torigins_in_model_time, number_internal_nodes)));

 for (j in 1:(N-1)) {
    int index_father_pop = indexes_father_populations[j];
    int number_internal_nodes_parent_pop = number_internal_nodes[index_father_pop];
    vector[pos_parent_child_pop_MRCA_in_parent_pop[j]] head_cum_sum;

    
    vector[number_internal_nodes_parent_pop] father_pop_simplex = population_simplexes[index_father_pop][1:number_internal_nodes_parent_pop];
    head_cum_sum= head(cumulative_sum(father_pop_simplex), pos_parent_child_pop_MRCA_in_parent_pop[j]);

   
}


coal_event_times_model_time~structured_coalescent(gammas, torigins_in_model_time,
                    torigins_in_physical_time,
                    pop_sizes_proportion, sample_sizes, number_internal_nodes, indexes_father_populations, pos_parent_child_pop_MRCA_in_parent_pop, cum_internal_nodes, 
                    number_child_pops,
                    child_pop_indexes,
                    K);
//.............
}
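The Dirichlet prior above relies on the aggregation property: summing Dirichlet components sums their concentrations. With concentrations (1, ..., 1, 1/m, ..., 1/m), where m = total_sample_size - number_internal_nodes[j], the m tail concentrations add up to 1, so each amalgamated segment of `pi` should be Dirichlet(1, ..., 1), i.e. uniform. (Dropping the explicit prior would leave each `population_simplexes[j]` uniform, and its amalgamation would then be Dirichlet(1, ..., 1, m), which is not uniform.) A minimal Monte Carlo sketch in Python/NumPy, with illustrative sizes and a hypothetical `amalgamate` helper:

```python
import numpy as np


def amalgamate(draws, n_internal):
    """Keep the first n_internal components and sum the rest into one,
    like append_row(head(s, n), sum(tail(s, K - n))) in the model above."""
    head = draws[:, :n_internal]
    tail_sum = draws[:, n_internal:].sum(axis=1, keepdims=True)
    return np.hstack([head, tail_sum])


rng = np.random.default_rng(1)
K, n_internal = 8, 3  # illustrative: total_sample_size, number_internal_nodes[j]
m = K - n_internal
alpha = np.concatenate([np.ones(n_internal), np.full(m, 1.0 / m)])
draws = rng.dirichlet(alpha, size=200_000)
pi = amalgamate(draws, n_internal)

# If pi ~ Dirichlet(1, ..., 1) on n + 1 components, each component is
# Beta(1, n): mean 1/(n+1), variance n / ((n+1)^2 (n+2)).
n = n_internal
print(pi.mean(axis=0), 1 / (n + 1))
print(pi.var(axis=0), n / ((n + 1) ** 2 * (n + 2)))
```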

My question is: should I drop the explicit Dirichlet prior to avoid the divergent transitions? As @betanalpha mentioned, I don't want to sample directly from the Dirichlet distribution; I just want to use the simplex vectors to build positive ordered vectors between 0 and some value T that follow another distribution.

Hi @Fabian_Crespo, it's hard to say exactly what's causing the divergences based on the information given, and I don't have the bandwidth to understand your more complex model in detail. But as a general suggestion, I would try starting with the simpler one that works and building up complexity incrementally. For example, you might try constructing simplex1 by amalgamating a larger simplex. If that works, try two simplices of different sizes, etc. Another approach might be to strip the likelihood down to its bare bones and focus on getting the simplex amalgamation working first.

Dropping the explicit Dirichlet prior sounds like a bad idea insofar as it changes the prior on the amalgamated simplex to something you (presumably) didn't intend, but it might be useful for diagnostic purposes.

I want to use @aaronjg's concept but wanted to check the Jacobian after your comment. I think I figured it out. First, I think we should work in k-1 dimensions (see, for example, the simplex transform in the Stan manual [1]). The map from R^{k-1}_{\geq 0} to the first k-1 coordinates of the simplex is then:

(\gamma_1,...,\gamma_{k-1}) \mapsto (1+\Gamma)^{-1}(\gamma_1,\ldots,\gamma_{k-1})

where \Gamma = \sum_{i=1}^{k-1} \gamma_i (so the k-th coordinate of the simplex is (1+\Gamma)^{-1}). Then the ij entry of the Jacobian is

\frac{\partial}{\partial \gamma_j} \frac{\gamma_i}{1+\Gamma} = \frac{(1 + \Gamma)\delta_{ij} - \gamma_i}{(1+\Gamma)^2}.

Then we can write the Jacobian as

(1+\Gamma)^{-2}\left[(1+\Gamma)\textbf{I} - \vec{\gamma}\vec{1}^T \right]

where \textbf{I} is the (k-1)\times(k-1) identity matrix and \vec{1} is the (k-1)-vector of ones.

When k=2, this is a scalar and it is just (1+\gamma_1)^{-2}, the first hint that all is well. When k=3 we have a 2x2 matrix that is also easy to check by hand. More generally, notice that this is (a multiple of) the identity plus a rank-one update. There is a nice formula for computing such a determinant (the matrix determinant lemma), and, surprise, surprise, it comes out to (1 + \Gamma)^{-k}. I've probably made at least two mistakes here but they canceled out :)
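The determinant step can be written out with the matrix determinant lemma, \det(c\textbf{I} + \vec{u}\vec{v}^T) = c^{n-1}(c + \vec{v}^T\vec{u}) for an n\times n matrix. Here n = k-1, c = 1+\Gamma with \Gamma = \sum_{i=1}^{k-1}\gamma_i, and the rank-one term has (i,j) entry \gamma_i, i.e. \vec{u} = -\vec{\gamma} and \vec{v} = \vec{1}, so \vec{v}^T\vec{u} = -\Gamma:

```latex
\det\left[(1+\Gamma)\textbf{I} - \vec{\gamma}\vec{1}^T\right]
  = (1+\Gamma)^{k-2}\left[(1+\Gamma) - \Gamma\right]
  = (1+\Gamma)^{k-2},
\qquad
\det J = (1+\Gamma)^{-2(k-1)}\,(1+\Gamma)^{k-2} = (1+\Gamma)^{-k}.
```

For k=2 this reduces to (1+\gamma_1)^{-2}, matching the scalar case.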

I'd be curious to know why Stan prefers the stick-breaking parameterization over this one, as this one seems to have a simpler determinant.

[1] For those of us who don't do this every day, some justification may be in order. First, if instead of a simplex we were parameterizing a non-linear subspace of R^k, we would need to correct for its non-constant curvature. Second, if the distribution is on all of R^k (my application), we are conditioning on \sum \pi_i = 1. The reason we get away with this is that Stan doesn't care about the constant P(\sum \pi_i = 1) term in the log-likelihood. If the condition were not constant (e.g. a simplex that summed to Z instead of 1, with Z a parameter rather than data), we would be in trouble, I think.
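To compare the two log-Jacobian terms side by side, here is a Python/NumPy sketch (function names are hypothetical). The "ratio" map is the one derived above; the stick-breaking map follows the unit-simplex transform as described in the Stan Reference Manual around the time of this thread, z_i = inv_logit(y_i - log(k - i)) (newer Stan releases may use a different simplex transform). Both closed forms are checked against a finite-difference determinant. This says nothing about HMC geometry, which is what the ESS comparison below is about. Note that in Stan the positive \gamma would itself come from a log-scale parameter, which adds \sum \log\gamma_i to the ratio map's log-Jacobian.

```python
import numpy as np


def ratio_simplex(g):
    """(g_1..g_{k-1}) -> (g_1..g_{k-1}, 1) / (1 + Gamma), Gamma = sum(g), g > 0."""
    return np.append(g, 1.0) / (1.0 + g.sum())


def ratio_log_jacobian(g):
    """Closed form from the derivation: log|det J| = -k * log(1 + Gamma)."""
    k = g.size + 1
    return -k * np.log1p(g.sum())


def stick_breaking_simplex(y):
    """Stick-breaking map R^{k-1} -> k-simplex, z_i = inv_logit(y_i - log(k - i))."""
    k = y.size + 1
    x = np.empty(k)
    remaining = 1.0
    for i in range(1, k):  # 1-based i, as in the manual
        z = 1.0 / (1.0 + np.exp(-(y[i - 1] - np.log(k - i))))
        x[i - 1] = remaining * z
        remaining -= x[i - 1]
    x[k - 1] = remaining
    return x


def stick_breaking_log_jacobian(y):
    """Triangular Jacobian: sum_i log(z_i * (1 - z_i) * remaining stick)."""
    k = y.size + 1
    remaining, total = 1.0, 0.0
    for i in range(1, k):
        z = 1.0 / (1.0 + np.exp(-(y[i - 1] - np.log(k - i))))
        total += np.log(z) + np.log1p(-z) + np.log(remaining)
        remaining -= remaining * z
    return total


def numeric_log_jacobian(f, u, eps=1e-6):
    """Central-difference log|det| of u -> first k-1 coordinates of f(u)."""
    n = u.size
    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(u + step)[:n] - f(u - step)[:n]) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


g = np.array([0.5, 2.0, 1.3])  # must be positive
y = np.array([0.3, -1.2, 0.7])  # unconstrained
for name, f, logj, u in [
    ("ratio", ratio_simplex, ratio_log_jacobian, g),
    ("stick-breaking", stick_breaking_simplex, stick_breaking_log_jacobian, y),
]:
    print(name, f(u).sum(), logj(u), numeric_log_jacobian(f, u))
```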

I realize you likely already recognize this, but in case somebody comes along trying to figure out how to parameterize a positive vector that sums to Z, a simple option is to parameterize in terms of a simplex and Z, and then to express your prior as a prior over the simplex and a prior over Z (assuming it's acceptable to write down a prior in which Z is independent of the simplex). This avoids the need for any Jacobian adjustment beyond the one required to yield the simplex.
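A small Python/NumPy sketch of the x = Z * simplex construction (the helper names are hypothetical). It also shows what changes if, instead of priors on the simplex and Z, you wanted to put a density on x itself: the map from (first k-1 simplex coordinates, Z) to x has |det J| = Z^{k-1}, so that case would need an extra (k-1) log Z term.

```python
import numpy as np


def scaled_simplex(free, z):
    """Map (first k-1 simplex coordinates, total Z) to a positive k-vector summing to Z."""
    s = np.append(free, 1.0 - free.sum())
    return z * s


def numeric_log_jacobian(free, z, eps=1e-6):
    """Central-difference log|det| of (free, Z) -> x (a k x k Jacobian)."""
    u = np.append(free, z)
    n = u.size

    def f(v):
        return scaled_simplex(v[:-1], v[-1])

    jac = np.empty((n, n))
    for j in range(n):
        step = np.zeros(n)
        step[j] = eps
        jac[:, j] = (f(u + step) - f(u - step)) / (2 * eps)
    return np.linalg.slogdet(jac)[1]


free = np.array([0.2, 0.1, 0.4])  # k = 4, so the last coordinate is 0.3
z = 2.5
x = scaled_simplex(free, z)
print(x.sum())  # equals z up to rounding
print(numeric_log_jacobian(free, z), free.size * np.log(z))  # (k-1) log Z
```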

@jsocolar Yes, that is a good clarification! In my particular case \log(\pi) \sim MVN(\mu, \Sigma) and \sum \pi = Z, so I don't think such a factorization is possible. Luckily Z is data, not a parameter, so it didn't matter!

By the way, do you know why Stan prefers the stick-breaking transform? This one seems simpler... I'm guessing the geometry is worse for HMC or something?

Edit: this was easy enough to verify with a simple program comparing the two approaches. When k was small I didn't see a difference, but for large k (I tried 100), the stick-breaking approach ran 2x slower but gave 10x the ESS.