I normalize each factor (each column of B) to unit length, but there are no additional constraints placed directly on b1.
There are a significant number of bells and whistles and transformed priors in the full model, so I am pretty sure the small snippet above is more informative, but here is the full model:
functions {
//sequential ordinal likelihood from BRMS
/* Category probabilities of the sequential-ratio ordinal model (probit link)
 * Args:
 *   mu: linear predictor
 *   thres: ordinal thresholds; num_elements(thres) + 1 categories result
 *   disc: discrimination parameter (multiplies thres[k] - mu)
 * Returns:
 *   vector of length num_elements(thres) + 1 with
 *     p[k]    = Phi(disc * (thres[k] - mu)) * prod_{j<k} (1 - Phi(disc * (thres[j] - mu)))
 *     p[ncat] = prod_j (1 - Phi(disc * (thres[j] - mu)))
 */
vector sratio_probit_vec(real mu, vector thres, real disc) {
  int ncat = num_elements(thres) + 1;
  vector[ncat] p;
  // probability of having passed every earlier threshold
  real carry = 1;
  for (k in 1:(ncat - 1)) {
    // probability of stopping at category k given that it was reached;
    // used directly instead of 1 - (1 - Phi(.)), which cancels to exactly 0
    // once Phi(.) falls below machine epsilon
    real stop = Phi(disc * (thres[k] - mu));
    p[k] = stop * carry;
    // running product replaces the inner loop that re-multiplied q[1..k-1]
    carry = carry * (1 - stop);
  }
  // last category: every threshold was passed (equals 1 when thres is empty)
  p[ncat] = carry;
  return p;
}
/* sratio-probit log-PMF for a single response
 * Args:
 *   y: response category, in 1..(num_elements(thres) + 1)
 *   mu: linear predictor
 *   thres: ordinal thresholds
 *   disc: discrimination parameter
 * Returns:
 *   a scalar to be added to the log posterior
 *
 * The value is accumulated on the log scale:
 *   log p[y] = sum_{k<y} log(1 - Phi(z_k)) + (y < ncat ? log Phi(z_y) : 0),
 *   z_k = disc * (thres[k] - mu).
 * Forming the probabilities with 1 - Phi(.) first and taking the log afterwards
 * returns -inf as soon as Phi saturates to 1; the log-CDF / log-CCDF functions
 * stay finite over a much wider range.
 */
real sratio_probit_lpmf(int y, real mu, vector thres, real disc) {
  int ncat = num_elements(thres) + 1;
  real lp = 0;
  if (y < 1 || y > ncat)
    reject("sratio_probit_lpmf: y must be in 1..", ncat, " but is ", y);
  // thresholds that were passed before reaching category y
  for (k in 1:(y - 1))
    lp = lp + normal_lccdf(disc * (thres[k] - mu) | 0, 1);
  // stopping at y; the last category has no threshold of its own
  if (y < ncat)
    lp = lp + normal_lcdf(disc * (thres[y] - mu) | 0, 1);
  return lp;
}
/* Random draw of a response category from the sratio-probit model
 * Args:
 *   mu: linear predictor
 *   thres: ordinal thresholds
 *   disc: discrimination parameter
 * Returns:
 *   a category drawn by categorical_rng from the probabilities given by
 *   sratio_probit_vec, returned as a real
 */
real sratio_probit_rng(real mu, vector thres, real disc) {
  return categorical_rng(sratio_probit_vec(mu, thres, disc));
}
/* Most probable response category under the sratio-probit model
 * Args:
 *   mu: linear predictor
 *   thres: ordinal thresholds
 *   disc: discrimination parameter
 * Returns:
 *   index of the first category whose probability is strictly larger than
 *   all earlier ones; 0 if no probability exceeds 0
 */
int sratio_probit_max(real mu, vector thres, real disc) {
  vector[num_elements(thres) + 1] probs = sratio_probit_vec(mu, thres, disc);
  int best = 0;
  real best_p = 0;
  for (k in 1:num_elements(probs)) {
    // strict comparison: ties keep the earlier category
    if (probs[k] > best_p) {
      best = k;
      best_p = probs[k];
    }
  }
  return best;
}
}
data {
  // --- dimensions ---
  int<lower=1> N;   // columns of panss / F / W; length of group
  int<lower=1> S;   // rows of panss / F / B
  int<lower=1> K;   // columns of B, rows of W
  int<lower=1> D;   // number of groups; one threshold vector per group

  // --- responses ---
  // Bounds enforced on every entry of panss.
  // NOTE(review): thres is declared with testmax-1 entries (testmax categories)
  // and panss is passed unshifted as the response category, while
  // ncat = testmax - testmin + 1. These agree only when testmin == 1 — confirm
  // that responses are coded starting at 1.
  int testmin;
  int testmax;
  int<lower=testmin, upper=testmax> panss[S,N];
  int<lower=1,upper=D> group[N];   // selects thres[group[n]] for column n

  // --- fixed settings ---
  real<lower=0> disc;    // discrimination passed to the sratio_probit_* functions
  real<lower=0> alpha;   // shape of the gamma prior on sigma_tilde
  int<lower=0,upper=1> include_likelihood;   // 0 skips the likelihood loop in the model block
}
transformed data {
  // S, kept under a descriptive name
  int factor_elements = S;
  // K; used as the upper bound of load_total and in the scale of its prior
  int load_elements = K;
  // number of integer values in [testmin, testmax]
  int ncat = testmax - testmin + 1;
}
// Sampled parameters. The declaration order is left untouched because it
// determines the layout of the parameter vector and of the output columns.
parameters {
// first row of B before normalization (see append_row in transformed parameters); non-negative
row_vector<lower=0>[K] b1;
// remaining S-1 rows of B before normalization
matrix[S-1,K] B_tilde;
// raw loadings; multiplied elementwise by the shrinkage factors to give W
matrix[K,N] W_tilde;
// one threshold vector per group, testmax-1 thresholds each
vector[testmax-1] thres[D];
// per-column and per-row offsets that are added together to form F_bias
row_vector[N] person_bias;
vector[S] symptom_bias;
// unnormalized factor strengths; divided by their sum to give sigma
positive_ordered[K] sigma_tilde;
// overall multiplier applied to the latent matrix F
real<lower=0> scale;
// shrinkage scales: one per column of W, and one per entry of W
row_vector<lower=0>[N] load_lambda;
matrix<lower=0>[K,N] local_load_lambda;
// weight of the factor part of F; the bias part gets 1 - split
real<lower=0,upper=1> split;
}
transformed parameters {
  matrix[K,N] W;                                    // loadings/weights
  matrix[S,K] B = append_row(b1, B_tilde);          // factors (normalized below)
  vector[K] sigma = sigma_tilde/sum(sigma_tilde);   // factor strength, sums to 1
  matrix[S,N] F;                                    // latent matrix
  matrix[S,N] F_bias;                               // bias
  matrix<lower=0,upper=1>[K,N] load_shrinkage;      // sparsity weight
  row_vector<lower=0, upper=load_elements>[N] load_total;   // column sums of load_shrinkage
  row_vector[N] load_jac_vec;                       // log-Jacobian terms added to target in the model block
  vector[testmax] p[D];                             // category probabilities at mu = 0, per group

  // category probabilities implied by each group's thresholds
  for (d in 1:D) p[d] = sratio_probit_vec(0., thres[d], disc);

  // rescale every column of B to unit Euclidean length
  {
    vector[K] inv_norm;
    for (k in 1:K) inv_norm[k] = 1./(sqrt(sum(square(B[,k]))));
    B = diag_post_multiply(B, inv_norm);
  }

  // additive bias: symptom_bias varies over rows, person_bias over columns
  F_bias = (rep_matrix(symptom_bias, N) + rep_matrix(person_bias, S));

  // loadings with a regularized-horseshoe style scale (tau and c2 fixed at 1)
  {
    real tau = 1.;
    real c2 = 1.;
    // combined scale: entry-wise scale times the column-wise scale
    matrix[K,N] lam = local_load_lambda .* rep_matrix(load_lambda, K);
    matrix[K,N] lam_reg = sqrt( c2 * square(lam) ./ (c2 + square(tau) * square(lam) ));
    load_shrinkage = 1. - 1. ./ (1. + square(tau) * square(lam));
    load_total = rep_row_vector(1,K) * load_shrinkage;
    // log of d(load_total[n]) / d(load_lambda[n]):
    // each entry contributes 2 * kappa * (1 - kappa) / load_lambda[n]
    load_jac_vec = log(rep_row_vector(1,K) * (2. * load_shrinkage .* (1. - load_shrinkage))) - log(load_lambda);
    W = tau * lam_reg .* W_tilde;
  }

  // latent matrix: weighted mix of the factor part and the bias part
  F = scale * ( split * B * diag_pre_multiply(sigma, W) + (1.-split) * F_bias );
}
model {
  // priors on the raw factor entries, raw loadings and bias terms
  b1 ~ normal(0,1);
  to_vector(B_tilde) ~ normal(0.,1.);
  to_vector(W_tilde) ~ normal(0.,1.);
  split ~ beta(1.,1.);
  sigma_tilde ~ gamma(alpha, 1.);
  person_bias ~ normal(0, 1.);
  symptom_bias ~ normal(0., 1.);

  //special sparsity prior
  // The prior is placed on the transformed quantity load_total, so the log
  // Jacobian of load_total with respect to load_lambda (load_jac_vec) is added.
  to_vector(local_load_lambda) ~ cauchy(0., 1.);
  load_total ~ normal(0., load_elements/3.);
  target += sum((load_jac_vec));
  scale ~ normal(0., 1.);

  //special threshold prior
  // Dirichlet prior on the category probabilities p[d], which are a transform
  // of thres[d]; the added term is the log Jacobian of that transform.
  for (d in 1:D) {
    // Size the prior and the Jacobian slice from p[d] itself. p[d] is declared
    // with testmax entries, whereas ncat = testmax - testmin + 1 has that value
    // only when testmin == 1; sizing by ncat fails with a dimension mismatch
    // for any other testmin.
    int n_p = num_elements(p[d]);
    p[d] ~ dirichlet(rep_vector(1., n_p));
    target += sum(log(p[d,1:(n_p-1)])) + normal_lpdf(thres[d] | 0., 1./disc) - normal_lcdf(disc * thres[d] | 0., 1.);
  }

  //ordinal likelihood
  // skipped entirely when include_likelihood == 0
  if (include_likelihood) {
    for (n in 1:N) {
      for (s in 1:S) {
        target += sratio_probit_lpmf(panss[s,n] | F[s,n], thres[group[n]], disc);
      }
    }
  }
}
generated quantities {
  matrix[S,N] mode;   // most probable category for each cell
  matrix[S,N] ppc;    // one random draw of a category for each cell
  // columns outermost, rows innermost: this fixes the order of the random draws
  for (n in 1:N) {
    // thresholds of the group that column n belongs to
    vector[testmax-1] cuts = thres[group[n]];
    for (s in 1:S) {
      real eta = F[s,n];
      mode[s,n] = sratio_probit_max(eta, cuts, disc);
      ppc[s,n] = sratio_probit_rng(eta, cuts, disc);
    }
  }
}