Hi,
I’m attempting to parallelize a set of hierarchical models using reduce_sum. As a simple, illustrative example, the code below extends the multilevel 2PL IRT model from the Stan user’s guide (pp. 29-31), the only difference being that my data are structured in long/melted format, analogous to the data structure in the hierarchical logistic regression example (p. 25). So, if J students are all administered K items (no missingness), the response vector y contains N = J*K observations.
In the first code example below, slicing is done over y (y_slice), mainly because that is what nearly every example I’ve come across slices over. Because y varies across the N observations, my understanding is that slicing this way partitions the work over each of the N = J*K observations. Let me know if I’m wrong about that.
Problem: I’m interested in slicing over the J students instead of the N observations to see how that affects run time. I’ve spent a good amount of time trying a variety of approaches to set the model up, but nothing has worked: either the estimates are wrong or the model fails for other reasons. As mentioned, I also haven’t seen any straightforward examples of slicing at a level above individual observations, so there isn’t much to adapt from. (The Stan user’s guide has an example of parallelizing a hierarchical model with map_rect, but I don’t believe there’s an analogous example using reduce_sum.) Where I’ve seen this discussed elsewhere online, the usual suggestion is to slice over observed predictor data x; however, a 2PL IRT model has no observed predictors, so I’m not sure that applies. I’ve tried various other approaches, including slicing over the alpha vector (as in the second example below), generating a sequence of indices to slice over in the way brms does (a rough sketch of that idea is just below), and so on. None has worked, so I’m clearly approaching this wrong somewhere.
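To clarify what I mean by the index-sequence idea, it looks roughly like the sketch below. The names seq_students, student_first, and student_last are hypothetical: seq_students would be the integer sequence 1:J built in transformed data, and student_first/student_last would be data arrays giving the first and last row of y belonging to each student (which already presumes the rows of y are grouped by student).

functions {
  // Sketch of the index-sequence idea: slice over a dummy array of student
  // indices (1:J); student_first and student_last are hypothetical data arrays
  // giving each student's first and last row in y.
  real partial_sum_seq(int[] student_slice,
                       int start,
                       int end,
                       int[] y,
                       vector gamma,
                       vector alpha,
                       vector beta,
                       int[] kk,
                       int[] student_first,
                       int[] student_last) {
    real lp = 0;
    for (s in 1:(end - start + 1)) {
      int j = student_slice[s];  // global student index
      for (n in student_first[j]:student_last[j]) {
        lp += bernoulli_logit_lpmf(y[n] | gamma[kk[n]] * (alpha[j] - beta[kk[n]]));
      }
    }
    return lp;
  }
}

with the call target += reduce_sum(partial_sum_seq, seq_students, grainsize, y, gamma, alpha, beta, kk, student_first, student_last); in the model block.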
Could you provide any thoughts as to how to approach this? Thanks for any ideas.
Original model (slicing over y)
functions {
  real partial_sum(int[] y_slice,
                   int start,
                   int end,
                   vector gamma,
                   vector alpha,
                   vector beta,
                   int[] kk,
                   int[] jj) {
    int N = end - start + 1;
    vector[N] mu;
    for (n in 1:N) {
      // kk and jj are the full-length index arrays, so offset by start
      // to line up with y_slice = y[start:end]
      mu[n] = gamma[kk[start + n - 1]]
              * (alpha[jj[start + n - 1]] - beta[kk[start + n - 1]]);
    }
    return bernoulli_logit_lpmf(y_slice | mu);
  }
}
data {
  int<lower=1> J;              // number of students
  int<lower=1> K;              // number of questions
  int<lower=1> N;              // number of observations
  int<lower=1, upper=J> jj[N]; // student associated with observation n
  int<lower=1, upper=K> kk[N]; // question associated with observation n
  int<lower=0, upper=1> y[N];  // correctness for observation n
}
parameters {
  real mu_beta;              // mean question difficulty
  vector[J] alpha;           // ability for j - mean
  vector[K] beta;            // difficulty for k
  vector<lower=0>[K] gamma;  // discrimination of k
  real<lower=0> sigma_beta;  // scale of difficulties
  real<lower=0> sigma_gamma; // scale of log discrimination
}
model {
  int grainsize = 1;
  alpha ~ std_normal();
  beta ~ normal(mu_beta, sigma_beta);
  gamma ~ lognormal(0, sigma_gamma);
  mu_beta ~ cauchy(0, 5);
  sigma_beta ~ cauchy(0, 5);
  sigma_gamma ~ cauchy(0, 5);
  // kk and jj must be passed along to match the partial_sum signature
  target += reduce_sum(partial_sum, y, grainsize, gamma, alpha, beta, kk, jj);
}
Slicing using alpha
functions {
  // when the sliced argument is a vector (alpha), the slice itself is a vector
  real partial_sum(vector a_slice,
                   int start,
                   int end,
                   int[] y,
                   vector gamma,
                   vector beta,
                   int[] kk) {
    return bernoulli_logit_lpmf(y[start:end] | gamma[kk] .*
                                (a_slice - beta[kk]));
  }
}
data {
  int<lower=1> J;              // number of students
  int<lower=1> K;              // number of questions
  int<lower=1> N;              // number of observations
  int<lower=1, upper=J> jj[N]; // student associated with observation n
  int<lower=1, upper=K> kk[N]; // question associated with observation n
  int<lower=0, upper=1> y[N];  // correctness for observation n
}
parameters {
  real mu_beta;              // mean question difficulty
  vector[J] alpha;           // ability for j - mean
  vector[K] beta;            // difficulty for k
  vector<lower=0>[K] gamma;  // discrimination of k
  real<lower=0> sigma_beta;  // scale of difficulties
  real<lower=0> sigma_gamma; // scale of log discrimination
}
model {
  int grainsize = 1;
  alpha ~ std_normal();
  beta ~ normal(mu_beta, sigma_beta);
  gamma ~ lognormal(0, sigma_gamma);
  mu_beta ~ cauchy(0, 5);
  sigma_beta ~ cauchy(0, 5);
  sigma_gamma ~ cauchy(0, 5);
  target += reduce_sum(partial_sum, alpha, grainsize, y, gamma, beta, kk);
}
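In case it helps to see the shape I’m aiming for, here is a rough sketch of how I imagine slicing over students directly (i.e., slicing alpha) could work if the data were sorted by student, so that student j’s K responses occupy rows ((j - 1)*K + 1) through j*K. The function name partial_sum_students is mine, and I haven’t verified that this recovers the right estimates; it’s only meant to illustrate the indexing I have in mind.

functions {
  // Sketch: slice over alpha (one element per student), assuming y and kk are
  // sorted by student so that student j's K rows are contiguous.
  // Nothing below checks that assumption.
  real partial_sum_students(vector a_slice,
                            int start,
                            int end,
                            int[] y,
                            vector gamma,
                            vector beta,
                            int[] kk,
                            int K) {
    int J_slice = end - start + 1;
    int obs_start = (start - 1) * K + 1;  // first row of y for this slice
    int obs_end = end * K;                // last row of y for this slice
    vector[J_slice * K] mu;
    int i = 1;
    for (j in 1:J_slice) {
      for (k in 1:K) {
        int n = obs_start + i - 1;        // global row index into y and kk
        mu[i] = gamma[kk[n]] * (a_slice[j] - beta[kk[n]]);
        i += 1;
      }
    }
    return bernoulli_logit_lpmf(y[obs_start:obs_end] | mu);
  }
}

with the call target += reduce_sum(partial_sum_students, alpha, grainsize, y, gamma, beta, kk, K); in the model block. Is something along these lines the right direction, or is there a cleaner way to slice at the student level?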