Hi all,

I have a fairly straightforward hierarchical logistic regression model. The dataset is fairly large (160k observations, 5k participants, 150 items), so I want to make use of reduce_sum. The question is: what is the best way to do so?

Intuitively, it makes sense to move as much work as possible into reduce_sum, i.e. the computation of the regression predictions rather than just the precomputed predictions. This, however, requires passing quite a few variables into partial_sum. Is that the way to go? (See below.)

Second, the biggest variable going into partial_sum is the model matrix x. Should I pass it as the first argument, the one that gets cut into pieces? If so, reduce_sum only accepts arrays as the sliced argument. Should I pass it in as an array and convert it back to a matrix for the linear algebra within partial_sum?
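For what it's worth, the conversion itself is easy to write; the one subtlety is that the sliced argument arrives with only `end - start + 1` rows, so the local matrix must be declared with the slice size rather than N. A minimal illustrative sketch (the body is a placeholder, not the real log-density):

```
real partial_sum(real[,] x_slice, int start, int end, vector beta) {
  // x_slice contains only rows start..end of the full data array
  matrix[end - start + 1, num_elements(beta)] mx = to_matrix(x_slice);
  return sum(mx * beta); // placeholder partial log-density term
}
```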

Models below,

Thanks!

Original model:

```
data {
int<lower=0> N; // num observations
int<lower=1> K; // num predictors
int<lower=1> J1; // num participants
int<lower=1> J2; // num questions
int<lower=1> k1; // num of predictors that vary by participant
int<lower=1> k2; // num of predictors that vary by question
int<lower=1,upper=J1> jj1[N]; // participant number
int<lower=1,upper=J2> jj2[N]; // question number
int<lower=1,upper=K> V1[k1]; // Indices of predictors that vary by participant
int<lower=1,upper=K> V2[k2]; // Indices of predictors that vary by question
matrix[N, K] x; // model matrix
int y[N]; // Bernoulli outcomes
}
parameters {
matrix[k1, J1] z1; // Unscaled deviations for participant
matrix[k2, J2] z2; // Unscaled deviations for question
cholesky_factor_corr[k1] L_Omega1; // Correlation matrix for participant deviations
cholesky_factor_corr[k2] L_Omega2; // Correlation matrix for question deviations
vector<lower=0>[k1] tau1; // SDs of participant deviations
vector<lower=0>[k2] tau2; // SDs of question deviations
vector[K] beta; // Group effects
}
model{
matrix[J1, k1] w1 = (diag_pre_multiply(tau1, L_Omega1) * z1)'; // scaled participant deviations
matrix[J2, k2] w2 = (diag_pre_multiply(tau2, L_Omega2) * z2)'; // scaled question deviations
vector[N] p;
to_vector(z1) ~ std_normal();
to_vector(z2) ~ std_normal();
L_Omega1 ~ lkj_corr_cholesky(2);
L_Omega2 ~ lkj_corr_cholesky(2);
tau1 ~ std_normal();
tau2 ~ std_normal();
beta ~ std_normal();
p = x * beta; // Group effects
p += (x[:, V1] .* w1[jj1, :]) * rep_vector(1, k1); // By participant effects
p += (x[:, V2] .* w2[jj2, :]) * rep_vector(1, k2); // By question effects
y ~ bernoulli_logit(p);
}
```

attempt at reduce_sum:

```
functions {
real partial_sum(real[,] x_slice,
                 int start,
                 int end,
                 int[] y,
                 vector beta,
                 int[] V1,
                 int[] V2,
                 int k1,
                 int k2,
                 matrix w1,
                 matrix w2,
                 int[] jj1,
                 int[] jj2,
                 int N,
                 int K) {
  // The sliced argument arrives with only end - start + 1 rows,
  // so the matrix must be declared with the slice size, not N
  matrix[end - start + 1, K] mx = to_matrix(x_slice);
  // y, jj1, and jj2 are passed whole, so they must be subset with start:end here
  return bernoulli_logit_lpmf(y[start:end] |
           mx * beta
           + (mx[:, V1] .* w1[jj1[start:end], :]) * rep_vector(1, k1)
           + (mx[:, V2] .* w2[jj2[start:end], :]) * rep_vector(1, k2));
}
}
data {
int<lower=0> N; // num observations
int<lower=1> K; // num predictors
int<lower=1> J1; // num participants
int<lower=1> J2; // num questions
int<lower=1> k1; // num of predictors that vary by participant
int<lower=1> k2; // num of predictors that vary by question
int<lower=1,upper=J1> jj1[N]; // participant number
int<lower=1,upper=J2> jj2[N]; // question number
int<lower=1,upper=K> V1[k1]; // Indices of predictors that vary by participant
int<lower=1,upper=K> V2[k2]; // Indices of predictors that vary by question
real x[N, K]; // model matrix
int y[N]; // Bernoulli outcomes
}
parameters {
matrix[k1, J1] z1; // Unscaled deviations for participant
matrix[k2, J2] z2; // Unscaled deviations for question
cholesky_factor_corr[k1] L_Omega1; // Correlation matrix for participant deviations
cholesky_factor_corr[k2] L_Omega2; // Correlation matrix for question deviations
vector<lower=0>[k1] tau1; // SDs of participant deviations
vector<lower=0>[k2] tau2; // SDs of question deviations
vector[K] beta; // Group effects
}
model{
matrix[J1, k1] w1 = (diag_pre_multiply(tau1, L_Omega1) * z1)'; // scaled participant deviations
matrix[J2, k2] w2 = (diag_pre_multiply(tau2, L_Omega2) * z2)'; // scaled question deviations
to_vector(z1) ~ std_normal();
to_vector(z2) ~ std_normal();
L_Omega1 ~ lkj_corr_cholesky(2);
L_Omega2 ~ lkj_corr_cholesky(2);
tau1 ~ std_normal();
tau2 ~ std_normal();
beta ~ std_normal();
target += reduce_sum(partial_sum, x, 1, y, beta, V1, V2, k1, k2, w1, w2, jj1, jj2, N, K); // grainsize 1 leaves partitioning to the scheduler
}
```
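One alternative worth considering (a sketch only, not benchmarked; the index array `seq` is an addition not in the original model): since reduce_sum only requires that the *sliced* first argument be an array, you can slice over a dummy array of observation indices, keep `x` declared as `matrix[N, K] x;` in `data`, and subset it with `start:end` inside partial_sum, avoiding the real[,] round-trip and the to_matrix call entirely:

```
functions {
  real partial_sum(int[] seq_slice, int start, int end,
                   int[] y, matrix x, vector beta,
                   int[] V1, int[] V2, int k1, int k2,
                   matrix w1, matrix w2,
                   int[] jj1, int[] jj2) {
    // seq_slice itself is never used; it only determines start and end.
    // All data are subset directly with start:end.
    return bernoulli_logit_lpmf(y[start:end] |
             x[start:end] * beta
             + (x[start:end, V1] .* w1[jj1[start:end], :]) * rep_vector(1, k1)
             + (x[start:end, V2] .* w2[jj2[start:end], :]) * rep_vector(1, k2));
  }
}
transformed data {
  int seq[N];
  for (n in 1:N) seq[n] = n; // dummy array to slice over
}
```

The call in the model block would then become `target += reduce_sum(partial_sum, seq, 1, y, x, beta, V1, V2, k1, k2, w1, w2, jj1, jj2);`, with everything else unchanged from the original model.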