There are two observed variables, `kk[n]` and `y[n]`, which have a shared unobserved predictor `u[n]`.
In another question this answer showed how marginalization works. I want to marginalize out `u` in this new model. I'm still trying to get the hang of it, and this one seems much more complicated because of the categorical and binary variables. Is this a straightforward thing to do, or should I abandon the marginalization idea here? How would I go about it?
I generate data with this model:
data {
  // Sizes that drive the simulation below.
  int<lower=0> N;  // number of simulated observations (length of x, u, y, jj, kk)
  int<lower=1> J;  // number of levels jj can take (size of the simplex theta)
  int<lower=1> K;  // number of levels kk can take (length of beta, gamma, b0, g0)
}
generated quantities {
  // Covariate x and latent predictor u: independent standard-normal draws,
  // taken alternately for each observation.
  real x[N];
  real u[N];
  for (i in 1:N) {
    x[i] = normal_rng(0, 1);
    u[i] = normal_rng(0, 1);
  }

  // Per-category coefficients:
  //   beta, gamma -> slopes on x and u in the binary-outcome predictor
  //   b0, g0      -> slopes on x and u in the categorical-outcome predictor
  real beta[K];
  real gamma[K];
  vector[K] b0;
  vector[K] g0;
  for (c in 1:K) {
    beta[c] = normal_rng(1, 1);
    gamma[c] = normal_rng(3, 1);
    b0[c] = normal_rng(0, 1);
    g0[c] = normal_rng(0, 1);
  }

  // sigma is drawn and reported but does not enter any quantity below.
  real<lower=0> sigma = uniform_rng(.6, .8);
  real mu[N];
  int y[N];
  simplex[J] theta = dirichlet_rng(rep_vector(1, J));
  int<lower=1,upper=J> jj[N];
  int<lower=1,upper=K> kk[N];

  // One K-vector of intercepts for each level of jj.
  vector[K] h0[J];
  for (g in 1:J) {
    for (c in 1:K) {
      h0[g, c] = normal_rng(0, 1);
    }
  }

  // Simulate each observation: group jj, then category kk given (jj, x, u),
  // then the binary outcome y given (kk, x, u).
  for (i in 1:N) {
    jj[i] = categorical_rng(theta);
    {
      // Unnormalised log-odds over the K categories for this observation.
      vector[K] eta = h0[jj[i]] + x[i] * b0 + u[i] * g0;
      kk[i] = categorical_logit_rng(eta);
    }
    mu[i] = x[i] * beta[kk[i]] + u[i] * gamma[kk[i]];
    y[i] = bernoulli_logit_rng(mu[i]);
  }
}
and infer with this model:
data {
  int<lower=0> N;               // number of observations
  int<lower=1> K;               // number of levels of kk
  int<lower=1> J;               // number of levels of jj
  // Bounds make Stan reject malformed input when the data are read, instead of
  // failing later inside a sampling statement or on an out-of-range index.
  int<lower=0,upper=1> y[N];    // binary outcome (modelled with bernoulli_logit)
  real x[N];                    // observed covariate
  int<lower=1,upper=J> jj[N];   // group label, outcome of categorical(theta)
  int<lower=1,upper=K> kk[N];   // category label; also indexes beta[] and gamma[]
}
parameters {
  // Location, scale and raw offset from which the transformed parameters
  // beta[k] and gamma[k] are assembled (offset * sigma + mu).
  // mu_gamma and offset_gamma are constrained non-negative, so with
  // sigma_gamma >= 0 every gamma[k] is non-negative as well.
  real mu_beta[K];
  real<lower=0> mu_gamma[K];
  real<lower=0> sigma_beta;
  real<lower=0> sigma_gamma;
  real offset_beta[K];
  real<lower=0> offset_gamma[K];

  vector[N] u;          // one latent predictor value per observation
  real<lower=0> sigma;  // receives a prior in the model block; not used in the likelihood

  // Coefficient vectors of the categorical-logit predictor for kk.
  vector[K] b0;         // multiplies x[n]
  vector[K] g0;         // multiplies u[n]
  vector[K] j0;         // multiplies jj[n]

  simplex[J] theta;     // probabilities of the J levels of jj
}
transformed parameters {
  // Per-category slopes, each built as location + scale * raw offset.
  real beta[K];
  real gamma[K];
  for (c in 1:K) {
    beta[c] = mu_beta[c] + sigma_beta * offset_beta[c];
    gamma[c] = mu_gamma[c] + sigma_gamma * offset_gamma[c];
  }
}
model {
  // Prior on the latent predictor. u, b0, g0 and j0 are already declared as
  // vectors, so no to_vector() conversion is needed.
  u ~ normal(0, 1);

  // Priors on the pieces that make up beta and gamma.
  mu_beta ~ normal(1, 3);
  mu_gamma ~ normal(1, 3);
  sigma_beta ~ normal(1, 1);
  sigma_gamma ~ normal(1, 1);
  offset_beta ~ normal(0, 1);   // vectorised over the K elements
  offset_gamma ~ normal(0, 1);  // vectorised over the K elements

  // sigma only receives this prior; it does not appear in the likelihood below.
  sigma ~ normal(.8, .1);

  // Priors on the categorical-logit coefficients.
  b0 ~ normal(0, 3);
  g0 ~ normal(0, 3);
  j0 ~ normal(0, 3);
  theta ~ dirichlet(rep_vector(1, J));

  // Group labels: a single vectorised statement replaces the per-observation loop.
  jj ~ categorical(theta);

  real mu[N];
  for (n in 1:N) {
    // NOTE(review): jj[n] enters here as a numeric multiplier of j0, whereas the
    // simulation program uses a separate intercept vector per level (h0[jj[n]]).
    // Confirm this difference is intended.
    kk[n] ~ categorical_logit(jj[n] * j0 + x[n] * b0 + u[n] * g0);
    // Linear predictor for y uses the slopes of the observed category kk[n].
    mu[n] = x[n] * beta[kk[n]] + u[n] * gamma[kk[n]];
  }
  // Binary outcome, vectorised over all observations.
  y ~ bernoulli_logit(mu);
}
generated quantities {
  // p[k]: probability that y = 1 for category k, evaluated at the sample mean of
  // x and the mean of the latent u.
  real p[K];
  {
    // Loop-invariant means, computed once; the local scope keeps them out of the output.
    real x_bar = mean(x);
    real u_bar = mean(u);
    for (k in 1:K) {
      // Use the category-specific slope gamma[k], matching beta[k] here and the
      // likelihood's x[n]*beta[kk[n]] + u[n]*gamma[kk[n]]. Averaging gamma over
      // categories would mix the slopes of different categories.
      p[k] = inv_logit(x_bar * beta[k] + u_bar * gamma[k]);
    }
  }
}