Brief Description of the Problem
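Hello, I need to fit a finite mixture of Gaussians in which the mixing weights depend on predictors through a polytomous (multinomial) generalization of the logit link, i.e. the weight vector for observation n is:
pi_n = softmax(x_n * beta)
where x_n is the row of predictors for observation n, beta is a K x Q coefficient matrix, and each entry of pi_n is the weight of one class (Gaussian component). I have coded the model; the Stan code is below. N is the total number of observations, Q is the number of normal components, and K is the number of predictors.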
```r
model = "
data {
  int N;
  int K;
  int Q;
  vector[N] y;
  matrix[N,K] X;
  vector[Q] means;
  vector<lower=0>[Q] sds;
}
parameters {
  matrix[K,Q] beta;
}
transformed parameters {
  simplex[Q] pi[N];
  vector[Q] pi_logit[N];
  for (n in 1:N) {
    pi[n] = softmax(pi_logit[n]);
  }
  for (n in 1:N) {
    for (q in 1:Q) {
      pi_logit[n,q] = X[n]*beta[1:K,q];
    }
  }
}
model {
  vector[Q] lp_nq[N];
  vector[N] lp_n;
  for (q in 1:Q) {
    beta[1:K,q] ~ normal(0,3);
  }
  for (n in 1:N) {
    for (q in 1:Q) {
      lp_nq[n,q] = log(pi[n,q]) + normal_lpdf(y[n] | means[q], sds[q]);
    }
    lp_n[n] = log_sum_exp(lp_nq[n]);
  }
  target += sum(lp_n);
}"
```
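The link itself can be checked outside Stan. Entry (n, q) of pi_logit above, X[n]*beta[1:K,q], is just entry (n, q) of the matrix product X %*% beta, and the weights are a row-wise softmax of that matrix. The small R sketch below uses made-up dimensions and random values; the helper row_softmax and all variable names are hypothetical, not from the post. It checks both facts and subtracts each row's maximum before exponentiating, which gives the same weights while avoiding overflow:

```r
# Hypothetical helper: numerically stable softmax applied to each row of a matrix
row_softmax <- function(eta) {
  eta <- eta - apply(eta, 1, max)  # subtract each row's max (recycled down the columns)
  p <- exp(eta)
  p / rowSums(p)                   # divide each row by its own sum
}

set.seed(123)
N <- 5; K <- 3; Q <- 4
X <- matrix(rnorm(N * K), N, K)     # made-up predictors
beta <- matrix(rnorm(K * Q), K, Q)  # made-up K x Q coefficients

# pi_logit[n, q] = X[n, ] %*% beta[, q] is entry (n, q) of X %*% beta
eta <- X %*% beta
w <- row_softmax(eta)               # mixing weights, one row per observation

rowSums(w)                              # every row sums to 1 (up to rounding)
all.equal(w, exp(eta) / rowSums(exp(eta)))  # same as the naive exp / row-sum version
```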
The error is:
```
SAMPLING FOR MODEL ‘a1c9203faf153761790eac1af7adc304’ NOW (CHAIN 1).
Rejecting initial value:
Error evaluating the log probability at the initial value.
validate transformed params: pi[k0__] is not a valid simplex. sum(pi[k0__]) = nan, but should be 1
Rejecting initial value:
Error evaluating the log probability at the initial value.
validate transformed params: pi[k0__] is not a valid simplex. sum(pi[k0__]) = nan, but should be 1
Rejecting initial value:
…
```
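One reading of this message, as far as I understand Stan's behavior: variables declared in transformed parameters start out as NaN until they are assigned, and the simplex constraint on pi is only validated when the block finishes. In the model above, the first loop calls softmax(pi_logit[n]) before the second loop has assigned any element of pi_logit, so each pi[n] would be all NaN. The R sketch below only mimics that mechanism (NaN in, NaN out); softmax here is a hypothetical helper, not Stan's function:

```r
# Hypothetical helper: numerically stable softmax of a single vector
softmax <- function(z) {
  e <- exp(z - max(z))
  e / sum(e)
}

# Stand-in for pi_logit[n] before any of its elements has been assigned
unassigned <- rep(NaN, 4)
sum(softmax(unassigned))   # NaN, like "sum(pi[k0__]) = nan"

# The same call once the linear predictor holds actual values
assigned <- c(0.4, -1.2, 2.0, 0)
sum(softmax(assigned))     # 1, up to floating-point rounding
```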
I am fairly confident that the likelihood is set up correctly and that the error comes from the weights, `pi`.
I have also included the code I used to simulate the data below.
Can anyone spot a mistake in the model?
```r
# True coefficients: 21 predictors (rows) x 4 classes (columns); class 4 is the reference (all zeros)
beta1 = c(0.8, 1.2, 0.5, rep(0, 18))
beta2 = c(0.3, rep(0, 17), -1, 1.7, -2)
beta3 = c(0.3, 1, -2, 0.8, 0.9, rep(0, 16))
beta4 = rep(0, 21)
beta = cbind(beta1, beta2, beta3, beta4)
t(beta)
n = 3000; Q = 4
set.seed(1)
x = rnorm(21*n)
xs = matrix(x, n)                            # n x 21 predictor matrix
pi.num = exp(xs %*% beta)                    # numerator of pi
pi.den = apply(pi.num, 1, sum)               # denominator of pi: row sums
pis = pi.num/pi.den; colnames(pis) = c(1:4)  # pi via the multinomial logit link
# Draw each observation's class. Seed once, outside the loop: calling set.seed(1)
# inside the loop would restart the random stream for every observation.
set.seed(1)
ind = rep(NA, n)
for (i in 1:n) {
  poo = rmultinom(1, 1, pis[i,])
  ind[i] = row.names(poo)[poo == 1]
}
ind = as.numeric(ind)
mu = c(0.58, 0.44, 0.38, 0.2)   # mu = rev(mu)
prec = c(43.49, 78.32, 101.95, 45.71)
sigma = sqrt(1/prec)            # sigma = rev(sigma)
set.seed(1)
y = rnorm(n, mu[ind], sigma[ind])
head(y)
plot(density(y))
dat = list(y = y, X = xs, Q = 4, K = 21, N = 3000, means = mu, sds = sigma)
```
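One possible restructuring, offered as an untested sketch rather than a verified fix. In the transformed parameters block, pi is computed from pi_logit before pi_logit has been filled in. The version below keeps the same data block, so dat can be passed as is. It computes the linear predictor eta = X * beta as a local matrix before anything reads it, and it uses log_softmax() instead of log(softmax()). The names model_fixed and eta are my own:

```r
# Hypothetical rewrite of the Stan model, kept as an R string like the original
model_fixed = "
data {
  int<lower=1> N;
  int<lower=1> K;
  int<lower=1> Q;
  vector[N] y;
  matrix[N, K] X;
  vector[Q] means;
  vector<lower=0>[Q] sds;
}
parameters {
  matrix[K, Q] beta;
}
model {
  // Linear predictor for all observations: eta[n, q] = X[n] * beta[, q].
  // It is assigned here, before anything reads it.
  matrix[N, Q] eta = X * beta;
  to_vector(beta) ~ normal(0, 3);
  for (n in 1:N) {
    // Log mixing weights computed directly on the log scale
    vector[Q] lp = log_softmax(eta[n]');
    for (q in 1:Q)
      lp[q] += normal_lpdf(y[n] | means[q], sds[q]);
    target += log_sum_exp(lp);
  }
}"
```

With rstan this could be fitted with something like `rstan::stan(model_code = model_fixed, data = dat)`. Separately, all Q columns of beta are free here, while the simulation fixes beta4 = 0. Adding the same K-vector to every column of beta leaves the softmax unchanged, so only the normal(0, 3) prior pins beta down. Fixing one column to zero, as in the simulation, would remove that redundancy.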