Gaussian quadrature inside Stan?

I don’t know much about JAGS, but I think I can translate your model to Stan.
First, let’s just write the type declarations and create some fake Stan code that closely resembles the original JAGS code.

// Step 1: a direct transcription of the JAGS model, kept close to the
// original for comparison.
// NOTE: intentionally NOT valid Stan (see the NOTE comments below).
data {
  int<lower=1> N;
  int<lower=1> I;
  int<lower=1> K;
  int<lower=1> C;
  vector<lower=0>[C] delta;
  real<lower=0> a_pai_star;
  real<lower=0> b_pai_star;
  real<lower=0> a_xr_star;
  real<lower=0> b_xr_star;
  int<lower=0,upper=1> Q[I,K];
  int<lower=0,upper=1> all_patterns[C,K];
  int<lower=0,upper=1> Y[N,I];
} parameters {
  simplex[C] pai;
  vector[N] th;
  real<lower=0,upper=1> pai_star[I];
  real<lower=-3,upper=3> easy[I];
  real<lower=0,upper=1> xr_star[I, K];
  // NOTE: integer (discrete) parameter -- Stan rejects this; it has to be
  // marginalized out.
  int<lower=1,upper=C> c[N];
} transformed parameters {
  // NOTE: `.*` is not defined for 2-D real/int arrays in Stan.
  real r_star[I,K] = xr_star .* Q;
  // Intended: alpha[n] = all_patterns[c[n]] for each n.
  // NOTE: integer transformed parameters are not allowed in Stan.
  int alpha[N,K] = all_patterns[c];
  // Intended: w[n,i,k] = (1 - alpha[n,k]) * Q[i,k]; the shapes of alpha and
  // Q do not line up as written (one of the "vectorization mistakes").
  int w[N,I,K] = (1 .- alpha) .* Q;
  real IRT[N,I];
  real p[N,I];
  for (n in 1:N) for (i in 1:I) {
    IRT[n,i] = inv_logit(1.7*(th[n] + easy[i]));
    p[n,i] = pai_star[i] * IRT[n,i];
    for (k in 1:K)
      // Factor is r_star[i,k] when w[n,i,k] == 1 and pow(r, 0) otherwise.
      p[n,i] *= pow(r_star[i,k],w[n,i,k]);
  }
} model {
  c ~ categorical(pai);
  pai ~ dirichlet(delta);
  th ~ std_normal();
  pai_star ~ beta(a_pai_star, b_pai_star);
  // NOTE: beta is not vectorized over a 2-D array.
  xr_star ~ beta(a_xr_star, b_xr_star);
  easy ~ uniform(-3, 3);
  // NOTE: bernoulli is not vectorized over 2-D arrays.
  Y ~ bernoulli(p);
}

This will not compile, because Stan does not allow discrete parameters and there are some vectorization mistakes, but it should be straightforward to compare with the JAGS code. The next step is to merge the model and transformed parameters blocks into a single model block and to move each variable into the innermost loop possible.

// Step 2: transformed parameters merged into the model block; each quantity
// is now computed inside the innermost loop that uses it.
// NOTE: still not valid Stan -- c is a discrete parameter, and `.*` on
// int/real arrays (r_star, w) is not supported.
model {
  for (n in 1:N) {
    for (i in 1:I) {
      real r_star[K] = xr_star[i] .* Q[i];
      int alpha[K] = all_patterns[c[n]];
      int w[K] = (1 .- alpha) .* Q[i];
      real IRT = inv_logit(1.7*(th[n] + easy[i]));
      real p = pai_star[i] * IRT;
      for (k in 1:K)
        p *= pow(r_star[k], w[k]);
      Y[n,i] ~ bernoulli(p);
    }
    c[n] ~ categorical(pai);
  }
  pai ~ dirichlet(delta);
  th ~ std_normal();
  pai_star ~ beta(a_pai_star, b_pai_star);
  // NOTE: beta is not vectorized over the 2-D array xr_star.
  xr_star ~ beta(a_xr_star, b_xr_star);
  easy ~ uniform(-3, 3);
}

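Next, simplify the code and eliminate some variables. When Q[i,k] = 0, both r_star[k] and w[k] are 0, so the factor is pow(0,0), which I assume evaluates to 1 in JAGS. When alpha[k] = 1, w[k] = 0 and the factor is again 1. Therefore p only needs to be multiplied by xr_star[i,k] when all_patterns[c[n],k] == 0 and Q[i,k] == 1. I also replace the sampling statements for Y and c with target += in preparation for the last step.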

// Step 3: factors equal to 1 are dropped, so p is multiplied by
// xr_star[i,k] only when all_patterns[c[n],k] == 0 and Q[i,k] == 1.
// The likelihood terms for Y and c use target += so that c can be
// marginalized in the next step.
// NOTE: still not valid Stan -- c is a discrete parameter.
model {
  for (n in 1:N) {
    for (i in 1:I) {
      real IRT = inv_logit(1.7*(th[n] + easy[i]));
      real p = pai_star[i] * IRT;
      for (k in 1:K)
        if (all_patterns[c[n],k] == 0 && Q[i,k] == 1)
          p *= xr_star[i,k];
      target += bernoulli_lpmf(Y[n,i]|p);
    }
    target += categorical_lpmf(c[n]|pai);
  }
  pai ~ dirichlet(delta);
  th ~ std_normal();
  pai_star ~ beta(a_pai_star, b_pai_star);
  // NOTE: beta is not vectorized over the 2-D array xr_star; loop over its
  // rows instead.
  xr_star ~ beta(a_xr_star, b_xr_star);
  easy ~ uniform(-3, 3);
}

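As I mentioned, Stan cannot sample discrete parameters, so we must marginalize out c. For each n, we compute the joint log density of Y[n] and c for every class c = 1, …, C and combine these terms with log_sum_exp. The marginalized model is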

// Marginalized model: the discrete class indicator c is summed out on the
// log scale, so every parameter is continuous and the program is valid Stan.
data {
  int<lower=1> N;                          // number of rows of Y
  int<lower=1> I;                          // number of columns of Y and rows of Q
  int<lower=1> K;                          // number of columns of Q and all_patterns
  int<lower=1> C;                          // number of classes (rows of all_patterns)
  vector<lower=0>[C] delta;                // Dirichlet concentration for pai
  real<lower=0> a_pai_star;                // beta prior parameters for pai_star
  real<lower=0> b_pai_star;
  real<lower=0> a_xr_star;                 // beta prior parameters for xr_star
  real<lower=0> b_xr_star;
  int<lower=0,upper=1> Q[I,K];
  int<lower=0,upper=1> all_patterns[C,K];  // one binary pattern per class
  int<lower=0,upper=1> Y[N,I];             // binary responses
} parameters {
  simplex[C] pai;                          // class probabilities
  vector[N] th;
  real<lower=0,upper=1> pai_star[I];
  real<lower=-3,upper=3> easy[I];
  real<lower=0,upper=1> xr_star[I,K];
} model {
  // log(pai[c]) equals categorical_lpmf(c | pai) and does not depend on n,
  // so it is computed once rather than N * C times.
  vector[C] log_pai = log(pai);
  for (n in 1:N) {
    // pai_star[i] * inv_logit(...) does not depend on the class c, so it is
    // computed once per (n, i) instead of once per (n, c, i).
    real base_p[I];
    // lp1[c] = log p(c) + sum_i log p(Y[n,i] | c, th[n], ...)
    real lp1[C];
    for (i in 1:I)
      base_p[i] = pai_star[i] * inv_logit(1.7*(th[n] + easy[i]));
    for (c in 1:C) {
      real lp = log_pai[c];
      for (i in 1:I) {
        real p = base_p[i];
        // Only factors with all_patterns[c,k] == 0 and Q[i,k] == 1 differ
        // from 1; all other factors of the original product are dropped.
        for (k in 1:K)
          if (all_patterns[c,k] == 0 && Q[i,k] == 1)
            p *= xr_star[i,k];
        lp += bernoulli_lpmf(Y[n,i] | p);
      }
      lp1[c] = lp;
    }
    // Marginalize c: log sum_c exp(lp1[c]), computed stably.
    target += log_sum_exp(lp1);
  }
  pai ~ dirichlet(delta);
  th ~ std_normal();
  pai_star ~ beta(a_pai_star, b_pai_star);
  // beta is not vectorized over a 2-D array, so loop over the rows.
  for (i in 1:I)
    xr_star[i] ~ beta(a_xr_star, b_xr_star);
  // Only adds a constant given the declared bounds on easy; kept to mirror
  // the original prior.
  easy ~ uniform(-3, 3);
}

The simplification removes alpha from the output, but I don’t think you can marginalize th.

2 Likes