Hi everyone:
It seems this question is becoming a classic. It concerns log_mix, log_sum_exp and the imputation of a three-category covariate used to model a binary outcome. I have checked several similar questions and references:
https://andrewgelman.com/2017/08/21/mixture-models-stan-can-use-log_mix/
But it is still not clear enough to me. I have built the model up in stages, comparing the output from JAGS and Stan step by step. Imputing the categorical covariate using only the complete data works like a charm (same coefficients and same posterior predictive distribution of the categories in both). So I know the problem appears when I move to the data that includes the unobserved categories.
In the following code:
y is the binary outcome
cov_cat_1 and cov_cat_3 are dummy variables for categories 1 and 3 (-1 when NA)
cov_cat_miss is 1 when the category is missing and 0 otherwise
day and lat are predictors of the log odds of being in categories 1-3
theta is a vector of length 3 holding the log odds of each category relative to category 1 (so theta[1] = 0)
Once the category is imputed, the idea is to use this imputation to estimate a0, beta_bp, and beta_m.
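If I understand the marginalization correctly, this is the contribution I am aiming for from a row i with a missing category. The symbols \pi_{ik} and \eta_k are my own shorthand and do not appear in the code; \pi_{ik} is the softmax of theta and \eta_k is the linear predictor of the outcome under category k:

```latex
\begin{aligned}
\theta_i &= \bigl(0,\; a_2 + b_{1,2}\,\mathrm{day}_i + b_{2,2}\,\mathrm{lat}_i,\; a_3 + b_{1,3}\,\mathrm{day}_i + b_{2,3}\,\mathrm{lat}_i\bigr) \\
\pi_{ik} &= \frac{\exp(\theta_{ik})}{\sum_{j=1}^{3} \exp(\theta_{ij})} \\
\eta &= \bigl(a_0 + \beta_{bp},\; a_0,\; a_0 + \beta_m\bigr) \\
p(y_i) &= \sum_{k=1}^{3} \pi_{ik}\, \mathrm{Bernoulli}\bigl(y_i \mid \mathrm{logit}^{-1}(\eta_k)\bigr)
\end{aligned}
```

Here a_k, b_{1,k} and b_{2,k} stand for a_imp[k], b1_imp[k] and b2_imp[k].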
My questions are: what is wrong, and what do you suggest? I would really appreciate answers that are as explicit as possible, please.
model {
  a0 ~ normal(0, 30);
  beta_bp ~ normal(0, 30);
  beta_m ~ normal(0, 30);
  a_imp ~ normal(0, 30); // explained above
  b1_imp ~ normal(0, 30); // explained above
  b2_imp ~ normal(0, 30); // explained above
  for (i in 1:n_obs) {
    if (cov_cat_miss[i] == 0) {
      y[i] ~ bernoulli_logit(a0 +
                             beta_bp * cov_cat_1[i] +
                             beta_m * cov_cat_3[i]);
    } else {
      vector[n_cat] theta;
      vector[n_cat] log_prob_theta;
      matrix[n_cat, n_cat] lp;
      real p2 = a_imp[2] + b1_imp[2] * day[i] + b2_imp[2] * lat[i];
      real p3 = a_imp[3] + b1_imp[3] * day[i] + b2_imp[3] * lat[i];
      theta[1] = 0;
      theta[2] = p2;
      theta[3] = p3;
      log_prob_theta = log_softmax(theta);
      lp[1,1] = log_prob_theta[1] + bernoulli_logit_lpmf(y[i] | a0 + beta_bp); // cat 1
      lp[2,1] = log_prob_theta[1] + bernoulli_logit_lpmf(y[i] | a0); // cat 2 (baseline)
      lp[3,1] = log_prob_theta[1] + bernoulli_logit_lpmf(y[i] | a0 + beta_m); // cat 3
      lp[1,2] = log_prob_theta[2] + bernoulli_logit_lpmf(y[i] | a0 + beta_bp);
      lp[2,2] = log_prob_theta[2] + bernoulli_logit_lpmf(y[i] | a0);
      lp[3,2] = log_prob_theta[2] + bernoulli_logit_lpmf(y[i] | a0 + beta_m);
      lp[1,3] = log_prob_theta[3] + bernoulli_logit_lpmf(y[i] | a0 + beta_bp);
      lp[2,3] = log_prob_theta[3] + bernoulli_logit_lpmf(y[i] | a0);
      lp[3,3] = log_prob_theta[3] + bernoulli_logit_lpmf(y[i] | a0 + beta_m);
      target += log_sum_exp(lp);
    }
  }
}
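One check that can be done outside Stan is sketched below in plain Python. It re-implements log_softmax, log_sum_exp and bernoulli_logit_lpmf by hand and compares two ways of combining the terms: summing over the three categories only, and summing over all nine (category probability, outcome likelihood) pairs as the lp matrix above does. All numeric values are invented for illustration, and the helper names marginal_over_categories and sum_over_all_pairs are hypothetical. Under these assumptions the nine-term sum does not change when theta changes, which may be part of the problem.

```python
import math

def log_sum_exp(xs):
    # Numerically stable log(sum(exp(x))) over a list of floats.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_softmax(theta):
    # Log of the softmax probabilities of theta.
    lse = log_sum_exp(theta)
    return [t - lse for t in theta]

def bernoulli_logit_lpmf(y, eta):
    # log p(y | eta) for y in {0, 1}, with success probability inv_logit(eta).
    return -math.log1p(math.exp(-eta)) if y == 1 else -math.log1p(math.exp(eta))

# Invented values, for illustration only.
a0, beta_bp, beta_m = -0.5, 1.2, 0.7
eta = [a0 + beta_bp, a0, a0 + beta_m]  # outcome predictor under cat 1, 2, 3
y = 1

def marginal_over_categories(theta):
    # Three terms: category k is paired with its own outcome likelihood.
    lpt = log_softmax(theta)
    return log_sum_exp([lpt[k] + bernoulli_logit_lpmf(y, eta[k]) for k in range(3)])

def sum_over_all_pairs(theta):
    # Nine terms: every category probability with every outcome likelihood.
    lpt = log_softmax(theta)
    return log_sum_exp([lpt[k] + bernoulli_logit_lpmf(y, eta[j])
                        for k in range(3) for j in range(3)])

theta_a = [0.0, 0.3, -1.0]
theta_b = [0.0, 2.0, 1.5]

print("3-term:", marginal_over_categories(theta_a), marginal_over_categories(theta_b))
print("9-term:", sum_over_all_pairs(theta_a), sum_over_all_pairs(theta_b))
```

The nine-term sum factorizes into (sum of the category probabilities) times (sum of the three outcome likelihoods), and the first factor is always 1, so theta drops out of it entirely.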