Modifying brms code to incorporate reduce_sum

I’ve set up a nonlinear model in brms that I would like to run with the reduce_sum function. I don’t have much experience coding directly in Stan, so please let me know if I should provide any additional information that would be helpful.

The setup for this model in brms is:

bf_1 <- bf(Accuracy ~ ppe(repetition, model_time_z, stability_z, b, m, bl, tau),
           b + m + bl + tau ~ 1 + (1|User),
           nl = TRUE)

prior_1 <- 
  prior(normal(0, 1), nlpar = 'b') +
  prior(normal(0, 1), nlpar = 'm') +
  prior(normal(0, 1), nlpar = 'bl') +
  prior(normal(0, 1), nlpar = 'tau')

fit_1 <- brm(bf_1,
             prior = prior_1,
             stanvars = stanvars,
             data = dat_prepped)

The nonlinear function (called “ppe”) is supplied through the stanvars argument.
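For completeness, the stanvars object just wraps the Stan code for ppe, roughly along these lines (ppe_code is a placeholder for a string holding the function shown below):

# sketch: pass the Stan code for ppe (stored in a string, here called ppe_code)
# so that brms injects it into the functions block of the generated model
stanvars <- stanvar(scode = ppe_code, block = "functions")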

The Stan code produced by brms is:

// generated with brms 2.12.0
functions {
real ppe(real N,
           real T,
           real St,
           real b,
           real m,
           real bl,
           real tau) {
    real s = .1;
    real c = .1;
    real d;
    real M;
    real P;
    real forget;
    real learn;

    d = b + m * St;
    
    forget = T^-d;
    learn = (bl + N) ^ c;
    M = learn * forget;
    P = (tau - M)/s;
    return P;
  }

}
data {
  int<lower=1> N;  // number of observations
  int Y[N];  // response variable
  int<lower=1> K_b;  // number of population-level effects
  matrix[N, K_b] X_b;  // population-level design matrix
  int<lower=1> K_m;  // number of population-level effects
  matrix[N, K_m] X_m;  // population-level design matrix
  int<lower=1> K_bl;  // number of population-level effects
  matrix[N, K_bl] X_bl;  // population-level design matrix
  int<lower=1> K_tau;  // number of population-level effects
  matrix[N, K_tau] X_tau;  // population-level design matrix
  // covariate vectors for non-linear functions
  int C_1[N];
  vector[N] C_2;
  vector[N] C_3;
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_b_1;
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_m_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_bl_1;
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_tau_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
}
parameters {
  vector[K_b] b_b;  // population-level effects
  vector[K_m] b_m;  // population-level effects
  vector[K_bl] b_bl;  // population-level effects
  vector[K_tau] b_tau;  // population-level effects
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  vector[N_1] z_1[M_1];  // standardized group-level effects
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  vector[N_3] z_3[M_3];  // standardized group-level effects
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
}
transformed parameters {
  vector[N_1] r_1_b_1;  // actual group-level effects
  vector[N_2] r_2_m_1;  // actual group-level effects
  vector[N_3] r_3_bl_1;  // actual group-level effects
  vector[N_4] r_4_tau_1;  // actual group-level effects
  r_1_b_1 = (sd_1[1] * (z_1[1]));
  r_2_m_1 = (sd_2[1] * (z_2[1]));
  r_3_bl_1 = (sd_3[1] * (z_3[1]));
  r_4_tau_1 = (sd_4[1] * (z_4[1]));
}
model {
  // initialize linear predictor term
  vector[N] nlp_b = X_b * b_b;
  // initialize linear predictor term
  vector[N] nlp_m = X_m * b_m;
  // initialize linear predictor term
  vector[N] nlp_bl = X_bl * b_bl;
  // initialize linear predictor term
  vector[N] nlp_tau = X_tau * b_tau;
  // initialize non-linear predictor term
  vector[N] mu;
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_b[n] += r_1_b_1[J_1[n]] * Z_1_b_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_m[n] += r_2_m_1[J_2[n]] * Z_2_m_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_bl[n] += r_3_bl_1[J_3[n]] * Z_3_bl_1[n];
  }
  for (n in 1:N) {
    // add more terms to the linear predictor
    nlp_tau[n] += r_4_tau_1[J_4[n]] * Z_4_tau_1[n];
  }
  for (n in 1:N) {
    // compute non-linear predictor values
    mu[n] = ppe(C_1[n] , C_2[n] , C_3[n] , nlp_b[n] , nlp_m[n] , nlp_bl[n] , nlp_tau[n]);
  }
  // priors including all constants
  target += normal_lpdf(b_b | 0, 1);
  target += normal_lpdf(b_m | 0, 1);
  target += normal_lpdf(b_bl | 0, 1);
  target += normal_lpdf(b_tau | 0, 1);
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_1[1] | 0, 1);
  target += student_t_lpdf(sd_2 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_2[1] | 0, 1);
  target += student_t_lpdf(sd_3 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_3[1] | 0, 1);
  target += student_t_lpdf(sd_4 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_4[1] | 0, 1);
  // likelihood including all constants
  if (!prior_only) {
    target += bernoulli_logit_lpmf(Y | mu);
  }
}
generated quantities {
}

And I’ve tried to incorporate reduce_sum by moving most of the model block (everything but the priors) into a partial_sum function:

// generated with brms 2.12.0
functions {
real ppe(real N,
           real T,
           real St,
           real b,
           real m,
           real bl,
           real tau) {
    real s = .1;
    real c = .1;
    real d;
    real M;
    real P;
    real forget;
    real learn;

    d = b + m * St;
    
    forget = T^-d;
    learn = (bl + N) ^ c;
    M = learn * forget;
    P = (tau - M)/s;
    return P;
  }
  
  real partial_sum(int[] slice_n_Y, int start, int end, 
                   int[] C_1, vector C_2, vector C_3,
                   int[] J_1, int[] J_2, int[] J_3, int[] J_4,
                   matrix X_b, matrix X_m, matrix X_bl, matrix X_tau,
                   vector b_b, vector b_m, vector b_bl, vector b_tau,
                   vector r_1_b_1, vector r_2_m_1, vector r_3_bl_1, vector r_4_tau_1,
                   vector Z_1_b_1, vector Z_2_m_1, vector Z_3_bl_1, vector Z_4_tau_1) {
                      
                      // initialize linear predictor term
                      vector[size(slice_n_Y)] nlp_b = X_b * b_b;
                      // initialize linear predictor term
                      vector[size(slice_n_Y)] nlp_m = X_m * b_m;
                      // initialize linear predictor term
                      vector[size(slice_n_Y)] nlp_bl = X_bl * b_bl;
                      // initialize linear predictor term
                      vector[size(slice_n_Y)] nlp_tau = X_tau * b_tau;
                      vector[size(slice_n_Y)] mu;
                     
                     for (n in 1:((end-start) + 1)){
                        // add more terms to the linear predictor
                        nlp_b[n] += r_1_b_1[J_1[n]] * Z_1_b_1[n];
                      }
                      for (n in 1:((end-start) + 1)){
                        // add more terms to the linear predictor
                        nlp_m[n] += r_2_m_1[J_2[n]] * Z_2_m_1[n];
                      }
                      for (n in 1:((end-start) + 1)){
                        // add more terms to the linear predictor
                        nlp_bl[n] += r_3_bl_1[J_3[n]] * Z_3_bl_1[n];
                      }
                      for (n in 1:((end-start) + 1)){
                        // add more terms to the linear predictor
                        nlp_tau[n] += r_4_tau_1[J_4[n]] * Z_4_tau_1[n];
                      }
                     for (n in 1:((end-start) + 1)){
                       mu[n] = ppe(C_1[n], C_2[n] , C_3[n] , nlp_b[n] , nlp_m[n] , nlp_bl[n] , nlp_tau[n]);
                     }
    return bernoulli_logit_lpmf(slice_n_Y |
                               mu[1:((end-start) + 1)]);
  }

}
data {
  int<lower=1> N;  // number of observations
  int Y[N];  // response variable
  int<lower=1> K_b;  // number of population-level effects
  matrix[N, K_b] X_b;  // population-level design matrix
  int<lower=1> K_m;  // number of population-level effects
  matrix[N, K_m] X_m;  // population-level design matrix
  int<lower=1> K_bl;  // number of population-level effects
  matrix[N, K_bl] X_bl;  // population-level design matrix
  int<lower=1> K_tau;  // number of population-level effects
  matrix[N, K_tau] X_tau;  // population-level design matrix
  // covariate vectors for non-linear functions
  int C_1[N];
  vector[N] C_2;
  vector[N] C_3;
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_b_1;
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_m_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_bl_1;
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_tau_1;
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
}
parameters {
  vector[K_b] b_b;  // population-level effects
  vector[K_m] b_m;  // population-level effects
  vector[K_bl] b_bl;  // population-level effects
  vector[K_tau] b_tau;  // population-level effects
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  vector[N_1] z_1[M_1];  // standardized group-level effects
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  vector[N_3] z_3[M_3];  // standardized group-level effects
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
}
transformed parameters {
  vector[N_1] r_1_b_1;  // actual group-level effects
  vector[N_2] r_2_m_1;  // actual group-level effects
  vector[N_3] r_3_bl_1;  // actual group-level effects
  vector[N_4] r_4_tau_1;  // actual group-level effects
  r_1_b_1 = (sd_1[1] * (z_1[1]));
  r_2_m_1 = (sd_2[1] * (z_2[1]));
  r_3_bl_1 = (sd_3[1] * (z_3[1]));
  r_4_tau_1 = (sd_4[1] * (z_4[1]));
}
model {
  // priors including all constants
  target += normal_lpdf(b_b | 0, 1);
  target += normal_lpdf(b_m | 0, 1);
  target += normal_lpdf(b_bl | 0, 1);
  target += normal_lpdf(b_tau | 0, 1);
  target += student_t_lpdf(sd_1 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_1[1] | 0, 1);
  target += student_t_lpdf(sd_2 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_2[1] | 0, 1);
  target += student_t_lpdf(sd_3 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_3[1] | 0, 1);
  target += student_t_lpdf(sd_4 | 3, 0, 10)
    - 1 * student_t_lccdf(0 | 3, 0, 10);
  target += normal_lpdf(z_4[1] | 0, 1);
  // likelihood including all constants
  if (!prior_only) {
    target += reduce_sum(partial_sum, Y, 1, C_1, C_2, C_3,
                          J_1, J_2, J_3, J_4,
                          X_b, X_m, X_bl, X_tau,
                          b_b, b_m, b_bl, b_tau,
                          r_1_b_1, r_2_m_1, r_3_bl_1, r_4_tau_1, 
                          Z_1_b_1, Z_2_m_1, Z_3_bl_1, Z_4_tau_1);
  }
}
generated quantities {
}

This script runs in cmdstanr, but fitting takes about twice as long even though I’m using all 8 of my logical cores. I’m almost certain I haven’t made sound use of reduce_sum and was hoping for some guidance on how to improve the model. Thanks!

I think the slicing is not set up correctly. You multiply X_b * b_b and so on, but that refers to the entire data set, not just the slice, no?
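Inside partial_sum, only the first argument (the slice of Y) is already cut down; everything else arrives in full, so you have to index it with start and end yourself. For the nlp_b part that would look roughly like this (untested sketch; the other parameters follow the same pattern):

int n_slice = end - start + 1;
// use only the rows of the design matrix that belong to this slice
vector[n_slice] nlp_b = X_b[start:end] * b_b;
for (n in 1:n_slice) {
  // map the slice index n back to the full-data index start + n - 1
  nlp_b[n] += r_1_b_1[J_1[start + n - 1]] * Z_1_b_1[start + n - 1];
}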

I would recommend that you first get this to work without reduce_sum, but with the partial_sum function being called twice. So get this to work:

target += partial_sum(Y[1: N/2], 1, N/2, …);
target += partial_sum(Y[N/2 + 1 : N], N/2 + 1, N, …);
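Spelled out with your current argument list, the first of those calls would look roughly like this (untested); note that only Y is sliced at the call site, while everything else is passed whole and indexed inside partial_sum:

target += partial_sum(Y[1:(N / 2)], 1, N / 2,
                      C_1, C_2, C_3,
                      J_1, J_2, J_3, J_4,
                      X_b, X_m, X_bl, X_tau,
                      b_b, b_m, b_bl, b_tau,
                      r_1_b_1, r_2_m_1, r_3_bl_1, r_4_tau_1,
                      Z_1_b_1, Z_2_m_1, Z_3_bl_1, Z_4_tau_1);

The second half is the same with Y[(N/2 + 1):N], N/2 + 1, and N.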

Once that works and gives you the same results, switch to reduce_sum.

Right, I need to slice up X_b, b_b, etc. Don’t know why I didn’t catch that earlier. Thanks for the quick response!

NP.

I really recommend developing the partial_sum function as described. If that works, you should be safe: it’s easy to mess these things up, and using the partial_sum function for its intended purpose (computing partial sums) should catch these issues quickly.