How to most efficiently reduce_sum in a hierarchical logistic model

Hi all,

I have a pretty straightforward hierarchical logistic regression model. The dataset is fairly large (160k observations, 5k participants, 150 items), so I want to make use of reduce_sum. The question is: what is the best way to do so?

Intuitively it makes sense to push as much as possible into reduce_sum - i.e., the computation of the regression predictions, rather than just the already-computed predictions. This, however, requires passing quite a few variables into partial_sum. Is that the way to go? (see below)

Second - the biggest variable going into partial_sum is the model matrix x. Should I pass it as the first argument, the one that gets cut into pieces? If so, reduce_sum only accepts arrays as the first argument. Should I pass it in as an array and convert it back to a matrix for the algebra inside partial_sum?

Model below.

Thanks!

Original model:

data {
  int<lower=0> N;              // num observations
  int<lower=1> K;              // num predictors
  int<lower=1> J1;              // num participants
  int<lower=1> J2;              // num questions
  int<lower=1> k1;              // num of predictors that vary by participant
  int<lower=1> k2;              // num of predictors that vary by question
  
  int<lower=1,upper=J1> jj1[N];  // participant number
  int<lower=1,upper=J2> jj2[N];  // question number
  
  int<lower=1,upper=K> V1[k1]; // Indices of predictors that vary by participant
  int<lower=1,upper=K> V2[k2]; // Indices of predictors that vary by question
  
  matrix[N, K] x;               // model matrix
  int y[N];                 // Bernoulli outcomes
}

parameters {
  matrix[k1, J1] z1; // Unscaled deviations for participant
  matrix[k2, J2] z2; // Unscaled deviations for question
  cholesky_factor_corr[k1] L_Omega1; // Correlation matrix for participant deviations
  cholesky_factor_corr[k2] L_Omega2; // Correlation matrix for question deviations
  vector<lower=0>[k1] tau1; // SDs of participant deviations
  vector<lower=0>[k2] tau2; // SDs of question deviations
  vector[K] beta; // Group effects
}

model{
  matrix[J1, k1] w1 = (diag_pre_multiply(tau1, L_Omega1) * z1)'; // scaled participant deviations
  matrix[J2, k2] w2 = (diag_pre_multiply(tau2, L_Omega2) * z2)'; // scaled question deviations
  vector[N] p;


  to_vector(z1) ~ std_normal();
  to_vector(z2) ~ std_normal();
  
  L_Omega1 ~ lkj_corr_cholesky(2);
  L_Omega2 ~ lkj_corr_cholesky(2);
  
  tau1 ~ std_normal();
  tau2 ~ std_normal();
  
  beta ~ std_normal();
  
  p = x * beta;  // Group effects
  p += (x[:, V1] .* w1[jj1, :]) * rep_vector(1, k1); // By-participant effects
  p += (x[:, V2] .* w2[jj2, :]) * rep_vector(1, k2); // By-question effects
  
  y ~ bernoulli_logit(p);
}

Attempt at reduce_sum:

functions {
  real partial_sum(real[,] x, // slice x[start:end] of the model matrix, as a 2-d array
    int start,
    int end,
    int[] y,
    vector beta,
    int[] V1,
    int[] V2,
    int k1,
    int k2,
    matrix w1,
    matrix w2,
    int[] jj1,
    int[] jj2,
    int N,
    int K) {
        matrix[end - start + 1, K] mx = to_matrix(x); // convert the slice back to a matrix
        return bernoulli_logit_lpmf(y[start:end] | mx * beta
          + (mx[:, V1] .* w1[jj1[start:end], :]) * rep_vector(1, k1)
          + (mx[:, V2] .* w2[jj2[start:end], :]) * rep_vector(1, k2));
    }
}

data {
  int<lower=0> N;              // num observations
  int<lower=1> K;              // num predictors
  int<lower=1> J1;              // num participants
  int<lower=1> J2;              // num questions
  int<lower=1> k1;              // num of predictors that vary by participant
  int<lower=1> k2;              // num of predictors that vary by question
  
  int<lower=1,upper=J1> jj1[N];  // participant number
  int<lower=1,upper=J2> jj2[N];  // question number
  
  int<lower=1,upper=K> V1[k1]; // Indices of predictors that vary by participant
  int<lower=1,upper=K> V2[k2]; // Indices of predictors that vary by question
  
  real x[N, K];               // model matrix
  int y[N];                 // Bernoulli outcomes
}

parameters {
  matrix[k1, J1] z1; // Unscaled deviations for participant
  matrix[k2, J2] z2; // Unscaled deviations for question
  cholesky_factor_corr[k1] L_Omega1; // Correlation matrix for participant deviations
  cholesky_factor_corr[k2] L_Omega2; // Correlation matrix for question deviations
  vector<lower=0>[k1] tau1; // SDs of participant deviations
  vector<lower=0>[k2] tau2; // SDs of question deviations
  vector[K] beta; // Group effects
}

model{
  matrix[J1, k1] w1 = (diag_pre_multiply(tau1, L_Omega1) * z1)'; // scaled participant deviations
  matrix[J2, k2] w2 = (diag_pre_multiply(tau2, L_Omega2) * z2)'; // scaled question deviations


  to_vector(z1) ~ std_normal();
  to_vector(z2) ~ std_normal();
  
  L_Omega1 ~ lkj_corr_cholesky(2);
  L_Omega2 ~ lkj_corr_cholesky(2);
  
  tau1 ~ std_normal();
  tau2 ~ std_normal();
  
  beta ~ std_normal();
  
  target += reduce_sum(partial_sum, x, 1, y, beta, V1, V2, k1, k2, w1, w2, jj1, jj2, N, K);
}

Sorry to say, but the way you do it right now is very inefficient: the to_matrix statement inside the reducer casts the data into a parameter, because you assign it to a local variable.

In Stan it is key to distinguish between what is data and what is a parameter (things with respect to which we do not calculate gradients vs. things for which we do). Moreover, any data you pass into the partial_sum function is not copied, so that is cheap and you do not need to worry about it.

In your case I would probably slice over the participants - and also sort your data by participant. When you do that you can use z1 as the variable to slice over (which you then need to declare as vector[k1] z1[J1]). Furthermore, it could be beneficial to transpose the design matrix x, save it as Xt, and pass that into the reducer. The reason is that you can then slice Xt column-wise for the sub-slices (obviously you then need to write the product as beta' * Xt[:, sub-slice]) - matrices are stored in column-major order, so column slices are cheap.
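For concreteness, a minimal sketch of those two changes, showing only the pieces that differ from the original model (the data block is unchanged, with x declared as matrix[N, K], so N, K, J1 and k1 come from there; Xt would then be passed to the reducer as a shared argument):

transformed data {
  matrix[K, N] Xt = x'; // transpose once: column slices of Xt are cheap
                        // because matrices are stored in column-major order
}

parameters {
  vector[k1] z1[J1]; // array of vectors instead of matrix[k1, J1], so that
                     // reduce_sum can slice z1 over participants
                     // (with the data sorted by participant)
}

model {
  for (j in 1:J1)
    z1[j] ~ std_normal(); // prior restated for the array-of-vectors z1
}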

What could help… though I am really not sure… is to move the product of the scaled random effects into the partial sum function as well (which you would evaluate only in the slices as needed in the reducer)… but as I said, I am not sure on that.

Finally, play a little bit with the grainsize here. You have 5k participants - i.e. 5k units to sum over if you slice by participant - so I would set the grainsize to around 100 to start with.
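One convenient pattern (a sketch, not something from the model above) is to declare the grainsize as data, so it can be tuned without recompiling:

data {
  int<lower=1> grainsize; // e.g. start around 100 when slicing over ~5000 participants
}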

I hope this helps. I would hope that this model can benefit from reduce_sum - but it will not be easy! The problem is that bernoulli_logit_lpmf is computationally very cheap, so the overhead of parallelisation relative to the work being parallelised is not negligible.


You might also take a look here as I suspect you might have a lot of redundant computations in this model that could be eliminated for performance gains in addition to using reduce_sum.


Thank you! It is good to know about data not being copied.
I’ll try working with the participant matrix and changing the grainsize.

About the design matrix - is the point that slicing a matrix by column is more efficient than slicing it by row?

I had a look at the code - interesting stuff! I’d be glad to read the case study when it’s out.
Am I correct that the idea works for repeated measures of the same condition within a subject? If so, that unfortunately won’t save me - each subject sees each item only once…

I’ve rewritten my code, with the partial_sum function now being:

functions {
  real partial_sum(vector[] z1, // Unscaled participant deviations, sliced
  int start,
  int end,
  matrix x, // Design matrix
  int[] y, // Response
  vector beta, // Group effect parameters
  int[] V1, // Effects that vary by participant
  int[] V2, // Effects that vary by question
  int k1, // Length of V1
  int k2, // Length of V2
  matrix Sigma1, // Cholesky factor of the by-participant covariance, i.e. diag_pre_multiply(tau1, L_Omega1)
  matrix w2, // Scaled question deviations
  int [] jj1, // Participant index for each trial
  int [] js1, // First trial for each participant
  int [] jj2) { // Questions index for each trial
    
    int dstart = js1[start]; // Index of first trial computed over in this partial sum
    int dend = js1[end+1] - 1; // Index of last trial computed over in this partial sum
    int n = end - start + 1; // n participants in this partial sum
    int nt = dend - dstart + 1; // n trials in this partial sum
    matrix[n, k1] w1;  // Matrix of scaled participant deviations for this partial sum
    int indx1[nt] = jj1[dstart:dend]; // Participant index for each trial in this partial sum
    
    // Adjust indices of participant deviations to start at 1
    for (i in 1:nt){
      indx1[i] -= (start - 1);
    }
    
    // Multiply unscaled deviations by the Cholesky factor to get scaled deviations
    for (i in 1:n){
      w1[i,] = (Sigma1 * z1[i])';
    }
    
    // Return the log likelihood for this slice
    return bernoulli_logit_lpmf(y[dstart:dend] | x[dstart:dend,] * beta +  // Group effects
      (x[dstart:dend, V1] .* w1[indx1,:]) * rep_vector(1, k1) +  // By-participant deviations
      (x[dstart:dend, V2] .* w2[jj2[dstart:dend], :]) * rep_vector(1, k2)); // By-question deviations
  }
}

I get a 3.5x improvement in runtime on a subset of the data with 100 participants (grainsize 10) with 4 threads. Now running the full dataset with grainsize = 100; it seems to be moving along at a good clip.

I’ll try transposing the design matrix x and subsetting by column and see if it makes a difference.

Thanks!


The critical thing is to identify redundancy in the design matrix, x in your example. Even with what I usually call “crossed” random effects (as you have with subject and question), you should be able to at least compute each separately in the more efficient manner I show in that case study. Indeed, you already compute each separately:

But I suspect that (x[:, V1] .* w1[jj1, :]) has a lot of redundant computations, ditto (x[:, V2] .* w2[jj2, :]). Also, why are you bothering to multiply by a vector of 1s at the end?

First - about redundancy: I have specific columns in my design matrix that have repeating values - e.g. the intercept, or a contrast for a categorical predictor. Are you saying that it is worth computing their contributions separately, once for each level, and then adding them? That would mean breaking the design matrix apart, and would make my code less flexible (e.g. for adding a predictor to the model), but if it results in a meaningful speedup, I’ll give it a shot.
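Something like this is what I have in mind (a hypothetical sketch - x_levels, level_id, C and L are made-up names, not variables in my model):

// A categorical predictor occupies columns C of x and its rows take only
// L distinct patterns; x_levels is the L x size(C) matrix of those patterns.
vector[L] level_contrib = x_levels * beta[C]; // one value per level
p += level_contrib[level_id];                 // level_id[n] in 1:L maps each trial to its level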

As for multiplying by a vector of 1s: I have an NxK design matrix x, and an NxK matrix w with the by-participant parameters. I multiply them element-wise, and then take the product with a ones vector in order to sum across the columns of the resulting matrix (i.e., take row sums), giving an Nx1 vector.

This sounds like you could possibly use one of the specialized dot products? Have a look at the Stan manual for specialized products. If you can express this operation in a single call, it should speed things up a lot.
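For instance (a sketch; rows_dot_product is documented in the Stan functions reference), the element-wise product followed by the ones-vector sum can be written as a single row-wise dot product:

// equivalent to (x[:, V1] .* w1[jj1, :]) * rep_vector(1, k1)
vector[N] by_participant = rows_dot_product(x[:, V1], w1[jj1, :]);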

But seeing 3.5x on 4 cores is already nice… how does the single-thread reduce_sum runtime compare against the old code?

Not sure if this helps here, but if there is redundancy in your design matrix, it could be worth checking if you can exploit sufficient statistics and use a binomial instead of a Bernoulli model.
See my last response in this thread for some more explanation: Weighted logistic regression
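Schematically (a sketch with made-up names - trials, successes, x_unique and M are not variables from the model above): if several observations share the same design row, collapse them and model the counts with a binomial instead:

// With M unique design rows, trials[m] observations sharing row m,
// successes[m] of them equal to 1, and x_unique the M x K matrix of those rows:
successes ~ binomial_logit(trials, x_unique * beta);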

Since the data have two crossed/additive random effects, I don’t think there’s the redundancy at the trial level necessary to use the sufficient statistics trick.

> Since the data have two crossed/additive random effects, I don’t think there’s the redundancy at the trial level necessary to use the sufficient statistics trick.

Right, there is no redundancy at the trial level, only on certain columns of the design matrix.

> This sounds like you could possibly use one of the specialized dot products? Have a look at the Stan manual for specialized products.

I originally tried this with rows_dot_product, as suggested in the manual, but had the impression it was slower. Perhaps I didn’t check this rigorously enough (or at all).

Edit: I tried again. Indeed rows_dot_product is slower.


I’m confused how data becomes a parameter when assigned to a local variable.

Every variable declared outside of data or transformed data is always treated as a parameter in Stan. The only exceptions are temporary expressions formed solely from objects coming from data or transformed data.

The issue with the local variable is that you are allowed to munge it within the function block together with other parameters, hence it needs to be treated as a parameter. The stanc3 compiler could be more clever about this, but that is not currently the case (tagging @nhuurre, who may comment). For now we cast many things to parameters, but this case - I agree - could (and actually should) be detected automatically by the compiler and treated accordingly; we are just not there yet.
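To illustrate with a sketch (not code from the posts above):

// x is data, beta is a parameter
matrix[N, K] mx = to_matrix(x); // mx is a local variable, so it is promoted to the
                                // autodiff (parameter) type even though x is data
// In contrast, to_matrix(x) used as a temporary expression built only from data stays data:
target += bernoulli_logit_lpmf(y | to_matrix(x) * beta);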
