Hello,
I have run some tests comparing a soft versus a hard sum-to-zero constraint for a simplex (Dirichlet) regression. The simulated data are deliberately generated with very low noise.
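For context, the data are simulated along roughly these lines; the exact script is not included here, so the sizes, the seed, and the use of gtools::rdirichlet below are illustrative only:

library(gtools)   # for rdirichlet()

set.seed(1)
N <- 500; ncat <- 5; input_dim <- 2
X <- matrix(rnorm(N * input_dim), N, input_dim)
weights_true <- matrix(rnorm(input_dim * ncat), input_dim, ncat)
weights_true <- weights_true - rowMeans(weights_true)   # sum-to-zero per predictor row
v_true <- 1000                                          # large precision = very low error
softmax <- function(x) exp(x) / sum(exp(x))
Y <- t(apply(X %*% weights_true, 1,
             function(l) gtools::rdirichlet(1, softmax(l) * v_true)))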
HARD
148.82 seconds (Total)
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
weights[1,1] -1.1458828 0.0005821308 0.007623225 -1.1599872 -1.1514000 -1.1455387 -1.1406438 -1.1312168 171.4891 0.9993008
weights[1,2] -0.4548668 0.0003671708 0.005209648 -0.4658918 -0.4585181 -0.4546926 -0.4514349 -0.4447093 201.3171 0.9986171
weights[1,3] 0.2378516 0.0004638308 0.005466593 0.2277121 0.2344816 0.2379877 0.2412291 0.2504590 138.9039 1.0068702
weights[1,4] 1.3904648 0.0003544976 0.004300609 1.3823475 1.3877333 1.3905326 1.3930715 1.3992900 147.1747 1.0046746
weights[1,5] 0.9263817 0.0002775686 0.003620667 0.9190944 0.9238915 0.9263782 0.9289837 0.9336622 170.1518 1.0000263
weights[2,1] 0.2195011 0.0008506887 0.010793117 0.1986463 0.2123535 0.2191074 0.2271519 0.2411544 160.9728 0.9996466
weights[2,2] -0.7747972 0.0002818525 0.005537706 -0.7853681 -0.7783540 -0.7751138 -0.7715074 -0.7628066 386.0255 0.9979987
weights[2,3] -2.7699840 0.0006716145 0.009954892 -2.7892922 -2.7763711 -2.7700132 -2.7634324 -2.7494017 219.7015 1.0006988
weights[2,4] -0.7760190 0.0003541681 0.005065385 -0.7863347 -0.7793836 -0.7757947 -0.7726554 -0.7662422 204.5530 1.0045940
weights[2,5] 1.2296639 0.0002174585 0.003480065 1.2226467 1.2274545 1.2297332 1.2320280 1.2366653 256.1072 0.9994862
SOFT | sum(weights[i]) ~ normal(0, 0.0001 * ncat)
274.33 seconds (Total)
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
weights[1,1] -1.3391406 2.293569e-04 0.004116341 -1.34689154 -1.3419479 -1.33901422 -1.33620911 -1.33157701 322.1062 1.0021981
weights[1,2] -0.6460621 1.665210e-04 0.003204347 -0.65266495 -0.6483706 -0.64596065 -0.64375407 -0.63989639 370.2892 0.9981509
weights[1,3] 0.0487602 1.747026e-04 0.003143532 0.04266207 0.0467480 0.04875449 0.05087039 0.05509691 323.7701 0.9989046
weights[1,4] 1.2004083 9.242863e-05 0.001855101 1.19674888 1.1992482 1.20052803 1.20162172 1.20403749 402.8302 0.9983294
weights[1,5] 0.7360301 9.226777e-05 0.001954185 0.73247408 0.7346718 0.73600215 0.73736124 0.73968258 448.5710 0.9988598
weights[2,1] 0.7972074 4.435837e-04 0.007603421 0.78291908 0.7918101 0.79735841 0.80248565 0.81130783 293.8102 0.9998841
weights[2,2] -0.2007952 3.901609e-04 0.006597414 -0.21296538 -0.2057596 -0.20103038 -0.19646600 -0.18663437 285.9302 0.9980394
weights[2,3] -2.2042454 4.989408e-04 0.008751689 -2.22082484 -2.2106547 -2.20442567 -2.19815306 -2.18722835 307.6703 0.9980393
weights[2,4] -0.1993981 2.050149e-04 0.003841183 -0.20697032 -0.2019425 -0.19949765 -0.19673615 -0.19200305 351.0421 0.9990791
weights[2,5] 1.8071866 1.857058e-04 0.003765082 1.79999651 1.8046481 1.80760927 1.80970585 1.81433016 411.0529 0.9987926
SOFT | sum(weights[i]) ~ normal(0, 0.001 * ncat)
79.38 seconds (Total)
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
weights[1,1] -1.33906679 0.0002355011 0.004312584 -1.34714862 -1.34190566 -1.33907841 -1.33642318 -1.33011286 335.3429 1.0079406
weights[1,2] -0.64618862 0.0001428191 0.003193533 -0.65236610 -0.64832035 -0.64620616 -0.64413122 -0.64008974 500.0000 0.9983157
weights[1,3] 0.04869102 0.0001630507 0.003289347 0.04241653 0.04658683 0.04871578 0.05059692 0.05539219 406.9810 0.9984773
weights[1,4] 1.20037635 0.0001118742 0.002209674 1.19604522 1.19900284 1.20043974 1.20183060 1.20452307 390.1188 1.0010383
weights[1,5] 0.73600669 0.0001068532 0.002141574 0.73174139 0.73448537 0.73612169 0.73737938 0.74006021 401.6900 1.0019666
weights[2,1] 0.79718099 0.0004392946 0.007917254 0.78134531 0.79165344 0.79744736 0.80230182 0.81249897 324.8159 1.0049828
weights[2,2] -0.20085779 0.0003315689 0.005952465 -0.21337641 -0.20493411 -0.20081094 -0.19664387 -0.18971984 322.2895 0.9980259
weights[2,3] -2.20422130 0.0004126981 0.008191094 -2.21970960 -2.20939100 -2.20411915 -2.19900825 -2.18731448 393.9298 0.9982997
weights[2,4] -0.19941701 0.0002299476 0.004042518 -0.20745008 -0.20198728 -0.19939631 -0.19665669 -0.19174443 309.0624 1.0004030
weights[2,5] 1.80712068 0.0001937383 0.003534819 1.80038185 1.80474004 1.80740252 1.80927082 1.81423255 332.8918 1.0021593
SOFT | sum(weights[i]) ~ normal(0, 0.01 * ncat)
140.97 seconds (Total)
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
weights[1,1] -1.33955063 0.0008641711 0.01110507 -1.36052415 -1.34727240 -1.33957167 -1.33235902 -1.31665335 165.1364 1.0017526
weights[1,2] -0.64652730 0.0008300391 0.01070266 -0.66796289 -0.65353296 -0.64644981 -0.64006590 -0.62298819 166.2593 1.0066967
weights[1,3] 0.04793137 0.0008624810 0.01088247 0.02781052 0.04031121 0.04761735 0.05478967 0.07136967 159.2047 1.0052243
weights[1,4] 1.19978972 0.0008158316 0.01044942 1.17947424 1.19230257 1.19978662 1.20652287 1.22099937 164.0527 1.0063044
weights[1,5] 0.73554133 0.0008443594 0.01048765 0.71558505 0.72851332 0.73577050 0.74265520 0.75798619 154.2771 1.0061033
weights[2,1] 0.79638481 0.0009657271 0.01240122 0.77180348 0.78893469 0.79593330 0.80412887 0.81997343 164.8996 1.0064321
weights[2,2] -0.20158482 0.0009524633 0.01252942 -0.22578230 -0.21016327 -0.20207057 -0.19368810 -0.17667237 173.0474 0.9994586
weights[2,3] -2.20447889 0.0008962483 0.01329836 -2.23148403 -2.21309907 -2.20474485 -2.19592306 -2.17677002 220.1604 0.9995204
weights[2,4] -0.19983153 0.0009872462 0.01170564 -0.22269266 -0.20701587 -0.20031971 -0.19219750 -0.17605516 140.5851 1.0014370
weights[2,5] 1.80641418 0.0009663525 0.01159318 1.78408569 1.79912203 1.80546777 1.81424173 1.82954490 143.9243 1.0020987
SOFT | sum(weights[i]) ~ normal(0, 0.1 * ncat)
858.51 seconds (Total)
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
weights[1,1] -1.32546963 0.008589777 0.09236797 -1.5237648 -1.3866248889 -1.32259656 -1.2604053 -1.161136991 115.6323 0.9986507
weights[1,2] -0.63236712 0.008551023 0.09215380 -0.8263866 -0.6930839905 -0.62824401 -0.5679193 -0.467462170 116.1423 0.9985147
weights[1,3] 0.06198528 0.008606207 0.09239113 -0.1359249 -0.0001658668 0.06443147 0.1283943 0.229030744 115.2489 0.9985782
weights[1,4] 1.21397869 0.008559155 0.09211254 1.0158682 1.1538334071 1.21748272 1.2796228 1.379548700 115.8179 0.9986014
weights[1,5] 0.74969554 0.008550612 0.09208181 0.5501582 0.6877385938 0.75286649 0.8144900 0.916024662 115.9720 0.9985180
weights[2,1] 0.78904096 0.008271534 0.10159131 0.5726654 0.7349224799 0.78935339 0.8587340 0.984562498 150.8485 1.0138684
weights[2,2] -0.20872677 0.008256790 0.10111547 -0.4265115 -0.2657348554 -0.20996668 -0.1389566 -0.016155017 149.9729 1.0122870
weights[2,3] -2.21062558 0.008298003 0.10127022 -2.4275077 -2.2657670090 -2.21437824 -2.1449375 -2.007571547 148.9417 1.0131808
weights[2,4] -0.20691604 0.008281673 0.10113992 -0.4194727 -0.2619759795 -0.20620171 -0.1383121 -0.008027167 149.1451 1.0133139
weights[2,5] 1.79936795 0.008289542 0.10118807 1.5851504 1.7427574705 1.79884520 1.8712273 1.994083503 149.0038 1.0127490
HARD model
data {
  int<lower=1> N;                      // total number of observations
  int<lower=2> ncat;                   // number of categories
  int<lower=2> input_dim;              // number of predictor levels
  matrix[N, input_dim] X;              // predictor design matrix
  vector[ncat] Y[N];                   // response variable (each Y[n] a simplex)
}
parameters {
  matrix[input_dim, ncat - 1] weights_raw;  // free coefficients
  real<lower=1> v;                          // Dirichlet precision
}
transformed parameters {
  matrix[input_dim, ncat] weights;
  for (n in 1:(ncat - 1))
    for (m in 1:input_dim)
      weights[m, n] = weights_raw[m, n];
  // last category absorbs the rest, so each row of weights sums to zero
  for (m in 1:input_dim)
    weights[m, ncat] = -sum(weights_raw[m]);
}
model {
  matrix[N, ncat] logits;
  for (i in 1:input_dim)
    weights[i] ~ normal(0, 2);
  logits = X * weights;
  for (n in 1:N)
    Y[n] ~ dirichlet(softmax(to_vector(logits[n])) * v);
}
SOFT model
data {
  int<lower=1> N;                      // total number of observations
  int<lower=2> ncat;                   // number of categories
  int<lower=2> input_dim;              // number of predictor levels
  matrix[N, input_dim] X;              // predictor design matrix
  vector[ncat] Y[N];                   // response variable (each Y[n] a simplex)
  real pr;                             // scale of the soft sum-to-zero prior
}
parameters {
  matrix[input_dim, ncat] weights;     // coefficients
  real<lower=ncat> v;                  // Dirichlet precision
}
model {
  matrix[N, ncat] logits;
  for (i in 1:input_dim)
    sum(weights[i]) ~ normal(0, pr * ncat);  // soft sum-to-zero constraint
  for (i in 1:input_dim)
    weights[i] ~ normal(0, 2);
  v ~ cauchy(0, 2.5);
  logits = X * weights;
  for (n in 1:N)
    Y[n] ~ dirichlet(softmax(to_vector(logits[n])) * v + 1);
}
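For completeness, the fits are run with rstan along these lines; the file names, chain counts, and iteration numbers below are illustrative, not the exact settings behind the timings above:

library(rstan)

stan_data <- list(N = N, ncat = ncat, input_dim = input_dim, X = X, Y = Y)
fit_hard <- stan("hard.stan", data = stan_data, chains = 4, iter = 500)
fit_soft <- stan("soft.stan", data = c(stan_data, pr = 0.001),
                 chains = 4, iter = 500)
# sanity check: the posterior-mean weights of each predictor row should
# sum to approximately zero under either constraint
m <- get_posterior_mean(fit_soft, pars = "weights")[, "mean-all chains"]
rowSums(matrix(m, input_dim, ncat))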
I hope this helps. I am interested in finding the best settings for bigger models and in updating the reference manual with useful information.
Best wishes.