How to most efficiently reduce_sum in a hierarchical logistic model

I’ve rewritten my code, with the partial_sum function now being:

functions {
  real partial_sum(vector[] z1, // Unscaled participant deviations, sliced
  int start,
  int end,
  matrix x, // Design matrix
  int[] y, // Response
  vector beta, // Group effect parameters
  int[] V1, // Effects that vary by participant
  int[] V2, // Effects that vary by question
  int k1, // Length of V1
  int k2, // Length of V2
  matrix Sigma1, // Covariance matrix for by-participant deviations
  matrix w2, // Scaled question deviations
  int[] jj1, // Participant index for each trial
  int[] js1, // First trial for each participant
  int[] jj2) { // Question index for each trial
    
    int dstart = js1[start]; // Index of first trial computed over in this partial sum
    int dend = js1[end+1] - 1; // Index of last trial computed over in this partial sum
    int n = end - start + 1; // n participants in this partial sum
    matrix[n, k1] w1;  // Matrix of scaled participant deviations for this partial sum
    int indx1[dend - dstart + 1] = jj1[dstart:dend]; // Participant index for each trial in this slice (length = n trials, not n participants)
    
    // Adjust indices of participant deviations to start at 1
    for (i in 1:(dend - dstart + 1)){
      indx1[i] -= (start - 1);
    }
    
    // Multiply unscaled deviations by covariance matrix
    for (i in 1:n){
      w1[i,] = (Sigma1 * z1[i])';
    }
    
    // Return the log likelihood for this slice of trials
    return bernoulli_logit_lpmf(y[dstart:dend] | x[dstart:dend,] * beta +  // Group effects
      (x[dstart:dend, V1] .* w1[indx1,:]) * rep_vector(1, k1) +  // By-participant deviations
      (x[dstart:dend, V2] .* w2[jj2[dstart:dend], :]) * rep_vector(1, k2)); // By-question deviations
  }
}
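
For reference, the model block call that pairs with this partial_sum would look something like the following. The `grainsize` data variable and the declaration style are my assumptions, matching the argument types above; everything after `grainsize` is passed through unchanged to the shared arguments of partial_sum:

```stan
model {
  // Priors on beta, z1, etc. go here

  // Parallelize the likelihood over participants; z1 is the sliced argument
  target += reduce_sum(partial_sum, z1, grainsize,
                       x, y, beta, V1, V2, k1, k2,
                       Sigma1, w2, jj1, js1, jj2);
}
```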

I get a 3.5x improvement in runtime with 4 threads on a 100-participant subset of the data (grainsize = 10). The full dataset is now running with grainsize = 100 and seems to be moving along at a good clip.

I’ll try transposing the design matrix x and subsetting it by column to see if that makes a difference.
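
In case it helps anyone following along, a sketch of what I have in mind (hypothetical; Stan matrices are stored column-major, so column slices of the transposed matrix are contiguous in memory):

```stan
transformed data {
  matrix[cols(x), rows(x)] x_t = x'; // trials are now columns
}
```

Inside partial_sum, the group-effect term `x[dstart:dend, ] * beta` would then become `(beta' * x_t[, dstart:dend])'`, slicing columns instead of rows.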

Thanks!
