Rewriting brms-generated Stan code to speed up a model with multiple measurement-error predictors

tl;dr: Is there anything I can do to speed up this model, either with arguments to brm() or by tweaking the Stan code? And am I specifying the horseshoe priors correctly?

I am trying to fit a model with about 3000 observations and 20 predictors, 13 of which have known measurement error that I’d like to account for. On the workstation I have access to, with 16 physical cores (32 logical), the model is prohibitively slow when running 2 chains with 2 cores and 7 threads, plus GPU support. I looked at the generated Stan code and wondered whether it could be rewritten to speed things up.
I’m also wondering whether I’m specifying the horseshoe prior correctly: should I also place a horseshoe prior on the class = meanme parameters? I want shrinkage across all of the parameters, not just the ones without measurement error.
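For reference, here is what the `horseshoe()` function in the brms-generated code below computes for each coefficient (the regularized horseshoe from Appendix C.1 of the linked paper). Symbol mapping, read off the generated code: τ = `hs_global`, λ_k = `hs_local` (or `hs_localsp`), ξ = `hs_slab`, s_slab = `hs_scale_slab`, ν = `hs_df`, ν_g = `hs_df_global`, s_g = `hs_scale_global`, ν_s = `hs_df_slab`. As far as the generated code shows, `b` and `bsp` are both built from the same `hs_global` and `hs_slab` (only the local scales differ), so the measurement-error slopes already receive horseshoe shrinkage; `meanme_1` and `sdme_1` are the means and SDs of the latent predictor values, not regression coefficients.

$$
\beta_k = \tau\, z_k\, \tilde{\lambda}_k,
\qquad
\tilde{\lambda}_k^2 = \frac{c^2 \lambda_k^2}{c^2 + \tau^2 \lambda_k^2},
\qquad
c^2 = s_{\mathrm{slab}}^2\, \xi
$$

$$
z_k \sim \mathcal{N}(0, 1), \quad
\lambda_k \sim t^{+}_{\nu}(0, 1), \quad
\tau \sim t^{+}_{\nu_g}(0, s_g), \quad
\xi \sim \operatorname{Inv\text{-}Gamma}\!\left(\tfrac{\nu_s}{2}, \tfrac{\nu_s}{2}\right)
$$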

Here’s some code to generate an example dataset and the Stan code for a reduced model with only 5 predictors, 3 of which have measurement error.

library(dplyr)
library(brms)
set.seed(123)  # make the simulated example reproducible
n <- 500
model_data <- tibble(
  y = rpois(n, lambda = c(3, 5, 7, 10, 12)),
  x1 = rnorm(n),
  x2 = rnorm(n),
  x3 = rnorm(n),
  x4 = rnorm(n),
  x5 = rnorm(n),
  x1_se = runif(n),
  x2_se = runif(n),
  x3_se = runif(n)
)

auto_stan_code <- make_stancode(
  y ~ me(x1, x1_se) + me(x2, x2_se) + me(x3, x3_se) + x4 + x5,
  data = model_data,
  backend = 'cmdstanr',
  prior = prior(horseshoe(par_ratio = 0.33), class = b) +
    prior(normal(0, 1), class = meanme) +
    prior(normal(0, 2), class = sdme, lb = 0),
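  threads = 2,  # threading changes the generated code (reduce_sum + grainsize)
  cores = 2,    # cores and chains only matter when sampling, e.g. in brm()
  chains = 2,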
  family = 'poisson'
)

And the generated Stan code:

// generated with brms 2.19.0
functions {
  /* Efficient computation of the horseshoe prior
   * see Appendix C.1 in https://projecteuclid.org/euclid.ejs/1513306866
   * Args:
   *   z: standardized population-level coefficients
   *   lambda: local shrinkage parameters
   *   tau: global shrinkage parameter
   *   c2: slap regularization parameter
   * Returns:
   *   population-level coefficients following the horseshoe prior
   */
  vector horseshoe(vector z, vector lambda, real tau, real c2) {
    int K = rows(z);
    vector[K] lambda2 = square(lambda);
    vector[K] lambda_tilde = sqrt(c2 * lambda2 ./ (c2 + tau^2 * lambda2));
    return z .* lambda_tilde * tau;
  }
  /* integer sequence of values
   * Args:
   *   start: starting integer
   *   end: ending integer
   * Returns:
   *   an integer sequence from start to end
   */
  int[] sequence(int start, int end) {
    int seq[end - start + 1];
    for (n in 1:num_elements(seq)) {
      seq[n] = n + start - 1;
    }
    return seq;
  }
  // compute partial sums of the log-likelihood
  real partial_log_lik_lpmf(int[] seq, int start, int end, data int[] Y, data matrix Xc, vector b, real Intercept, vector bsp, vector Xme_1, vector Xme_2, vector Xme_3) {
    real ptarget = 0;
    int N = end - start + 1;
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept;
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      mu[n] += (bsp[1]) * Xme_1[nn] + (bsp[2]) * Xme_2[nn] + (bsp[3]) * Xme_3[nn];
    }
    ptarget += poisson_log_glm_lpmf(Y[start:end] | Xc[start:end], mu, b);
    return ptarget;
  }
}
data {
  int<lower=1> N;  // total number of observations
  int Y[N];  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  int<lower=1> Ksp;  // number of special effects terms
  // data for the horseshoe prior
  real<lower=0> hs_df;  // local degrees of freedom
  real<lower=0> hs_df_global;  // global degrees of freedom
  real<lower=0> hs_df_slab;  // slab degrees of freedom
  real<lower=0> hs_scale_global;  // global prior scale
  real<lower=0> hs_scale_slab;  // slab prior scale
  int grainsize;  // grainsize for threading
  // data for noise-free variables
  int<lower=1> Mme_1;  // number of groups
  vector[N] Xn_1;  // noisy values
  vector<lower=0>[N] noise_1;  // measurement noise
  vector[N] Xn_2;  // noisy values
  vector<lower=0>[N] noise_2;  // measurement noise
  vector[N] Xn_3;  // noisy values
  vector<lower=0>[N] noise_3;  // measurement noise
  int<lower=1> NCme_1;  // number of latent correlations
  int prior_only;  // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc;  // centered version of X without an intercept
  vector[Kc] means_X;  // column means of X before centering
  int seq[N] = sequence(1, N);
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  // local parameters for the horseshoe prior
  vector[Kc] zb;
  vector<lower=0>[Kc] hs_local;
  real Intercept;  // temporary intercept for centered predictors
  // local parameters for the horseshoe prior
  vector[Ksp] zbsp;
  vector<lower=0>[Ksp] hs_localsp;
  // horseshoe shrinkage parameters
  real<lower=0> hs_global;  // global shrinkage parameter
  real<lower=0> hs_slab;  // slab regularization parameter
  // parameters for noise free variables
  vector[Mme_1] meanme_1;  // latent means
  vector<lower=0>[Mme_1] sdme_1;  // latent SDs
  matrix[Mme_1, N] zme_1;  // standardized latent values
  cholesky_factor_corr[Mme_1] Lme_1;  // cholesky factor of the latent correlation matrix
}
transformed parameters {
  vector[Kc] b;  // population-level effects
  // special effects coefficients
  vector[Ksp] bsp;
  matrix[N, Mme_1] Xme1;  // actual latent values
  // using separate vectors increases efficiency
  vector[N] Xme_1;
  // using separate vectors increases efficiency
  vector[N] Xme_2;
  // using separate vectors increases efficiency
  vector[N] Xme_3;
  real lprior = 0;  // prior contributions to the log posterior
  // compute the actual regression coefficients
  b = horseshoe(zb, hs_local, hs_global, hs_scale_slab^2 * hs_slab);
  // compute the actual regression coefficients
  bsp = horseshoe(zbsp, hs_localsp, hs_global, hs_scale_slab^2 * hs_slab);
  // compute actual latent values
  Xme1 = rep_matrix(transpose(meanme_1), N) + transpose(diag_pre_multiply(sdme_1, Lme_1) * zme_1);
  Xme_1 = Xme1[, 1];
  Xme_2 = Xme1[, 2];
  Xme_3 = Xme1[, 3];
  lprior += student_t_lpdf(Intercept | 3, 1.9, 2.5);
  lprior += student_t_lpdf(hs_global | hs_df_global, 0, hs_scale_global)
    - 1 * log(0.5);
  lprior += inv_gamma_lpdf(hs_slab | 0.5 * hs_df_slab, 0.5 * hs_df_slab);
  lprior += normal_lpdf(meanme_1 | 0, 1);
  lprior += normal_lpdf(sdme_1 | 0, 2)
    - 3 * normal_lccdf(0 | 0, 2);
  lprior += lkj_corr_cholesky_lpdf(Lme_1 | 1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y, Xc, b, Intercept, bsp, Xme_1, Xme_2, Xme_3);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(zb);
  target += student_t_lpdf(hs_local | hs_df, 0, 1)
    - rows(hs_local) * log(0.5);
  target += std_normal_lpdf(zbsp);
  target += student_t_lpdf(hs_localsp | hs_df, 0, 1)
    - rows(hs_localsp) * log(0.5);
  target += normal_lpdf(Xn_1 | Xme_1, noise_1);
  target += normal_lpdf(Xn_2 | Xme_2, noise_2);
  target += normal_lpdf(Xn_3 | Xme_3, noise_3);
  target += std_normal_lpdf(to_vector(zme_1));
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // obtain latent correlation matrix
  corr_matrix[Mme_1] Corme_1 = multiply_lower_tri_self_transpose(Lme_1);
  vector<lower=-1,upper=1>[NCme_1] corme_1;
  // extract upper diagonal of correlation matrix
  for (k in 1:Mme_1) {
    for (j in 1:(k - 1)) {
      corme_1[choose(k - 1, 2) + j] = Corme_1[j, k];
    }
  }
}
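A minimal sketch of the smallest possible change to the generated code, assuming a Stan version that accepts the `array[]` syntax (2.26 or later): pass the latent matrix `Xme1`, which `transformed parameters` already computes, to `reduce_sum`, and replace the per-observation loop in `partial_log_lik_lpmf` with one matrix-vector product. This is a hand-written variant, not brms output; whether it is faster in practice would need to be checked on the real model.

```stan
functions {
  // Hypothetical drop-in replacement for the brms-generated partial_log_lik_lpmf:
  // takes the N x Mme latent matrix Xme1 instead of one vector per predictor.
  real partial_log_lik_lpmf(array[] int seq, int start, int end,
                            data array[] int Y, data matrix Xc,
                            vector b, real Intercept,
                            vector bsp, matrix Xme1) {
    // latent-predictor part of the linear predictor for rows start..end,
    // computed with one matrix-vector product instead of a loop over rows
    vector[end - start + 1] mu = Intercept + Xme1[start:end] * bsp;
    // error-free predictors enter through the GLM's design matrix Xc and b
    return poisson_log_glm_lpmf(Y[start:end] | Xc[start:end], mu, b);
  }
}
```

The call in the `model` block would then become `target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y, Xc, b, Intercept, bsp, Xme1);`. The separate `Xme_1`, `Xme_2`, `Xme_3` vectors are then only used by the measurement-error likelihood terms.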

I see that each measurement-error variable gets its own vector, rather than all of them being placed in one matrix. I’m admittedly a novice at writing custom Stan code, but it seems to me that using a separate vector for each, rather than a matrix like the one used for the error-free predictors, might be a place to speed things up. Would setting up an Xme matrix and a corresponding Xme_noise matrix be one way to do that? Any other thoughts on how I might make this more efficient?
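A minimal sketch of the matrix version asked about, written by hand rather than generated by brms, and assuming the `array[]` syntax (Stan 2.26 or later). The names `M`, `Xn`, `noise`, `z`, `beta` and `Xme` are illustrative, not brms names: `Xn` and `noise` hold the M noisy predictors and their known SDs as N × M matrices, and `X` holds the error-free predictors without an intercept column. All K + M slopes share one regularized horseshoe vector, which matches brms's split into `hs_local`/`hs_localsp` in distribution because both use the same `hs_df`, `hs_global` and `hs_slab`. The likelihood is a single `poisson_log_glm_lpmf` call on `append_col(Xc, Xme)`, and the per-predictor measurement-error terms collapse into one vectorized `normal_lpdf`. The sketch drops `reduce_sum` for brevity, so it is worth timing against the threaded version.

```stan
functions {
  // same computation as the brms-generated horseshoe() function
  vector horseshoe(vector z, vector lambda, real tau, real c2) {
    vector[rows(z)] lambda2 = square(lambda);
    vector[rows(z)] lambda_tilde = sqrt(c2 * lambda2 ./ (c2 + tau^2 * lambda2));
    return z .* lambda_tilde * tau;
  }
}
data {
  int<lower=1> N;               // number of observations
  array[N] int Y;               // Poisson response
  int<lower=0> K;               // predictors without measurement error
  matrix[N, K] X;               // error-free predictors (no intercept column)
  int<lower=1> M;               // predictors with measurement error
  matrix[N, M] Xn;              // noisy observed values (hypothetical name)
  matrix<lower=0>[N, M] noise;  // known measurement SDs (hypothetical name)
  real<lower=0> hs_df;
  real<lower=0> hs_df_global;
  real<lower=0> hs_df_slab;
  real<lower=0> hs_scale_global;
  real<lower=0> hs_scale_slab;
}
transformed data {
  vector[K] means_X;
  matrix[N, K] Xc;                             // centered error-free predictors
  vector[N * M] xn_vec = to_vector(Xn);        // column-major stacking
  vector[N * M] noise_vec = to_vector(noise);  // same ordering as xn_vec
  for (k in 1:K) {
    means_X[k] = mean(X[, k]);
    Xc[, k] = X[, k] - means_X[k];
  }
}
parameters {
  real Intercept;                   // intercept for centered predictors
  vector[K + M] z;                  // standardized slopes: error-free first, then latent
  vector<lower=0>[K + M] hs_local;  // local shrinkage, one per slope
  real<lower=0> hs_global;          // global shrinkage
  real<lower=0> hs_slab;            // slab regularization
  vector[M] meanme;                 // latent means
  vector<lower=0>[M] sdme;          // latent SDs
  cholesky_factor_corr[M] Lme;      // latent correlation (Cholesky factor)
  matrix[M, N] zme;                 // standardized latent values
}
transformed parameters {
  vector[K + M] beta = horseshoe(z, hs_local, hs_global, hs_scale_slab^2 * hs_slab);
  // latent (noise-free) predictor values as one N x M matrix
  matrix[N, M] Xme = rep_matrix(meanme', N) + (diag_pre_multiply(sdme, Lme) * zme)';
}
model {
  // one GLM call over the combined design matrix
  target += poisson_log_glm_lpmf(Y | append_col(Xc, Xme), Intercept, beta);
  // one vectorized measurement model for all M noisy predictors
  target += normal_lpdf(xn_vec | to_vector(Xme), noise_vec);
  target += std_normal_lpdf(to_vector(zme));
  // priors copied from the generated code; constant half-distribution
  // corrections (the log(0.5) terms) are omitted since they do not change the posterior
  target += student_t_lpdf(Intercept | 3, 1.9, 2.5);
  target += std_normal_lpdf(z);
  target += student_t_lpdf(hs_local | hs_df, 0, 1);
  target += student_t_lpdf(hs_global | hs_df_global, 0, hs_scale_global);
  target += inv_gamma_lpdf(hs_slab | 0.5 * hs_df_slab, 0.5 * hs_df_slab);
  target += normal_lpdf(meanme | 0, 1);
  target += normal_lpdf(sdme | 0, 2);
  target += lkj_corr_cholesky_lpdf(Lme | 1);
}
generated quantities {
  // intercept on the original (uncentered) scale
  real b_Intercept = Intercept - dot_product(means_X, beta[1:K]);
}
```

The `hs_*` values can be copied from `make_standata()` output for the original brms model. Note that with 13 noisy predictors and 3000 observations, `zme` alone holds 39,000 parameters, so the per-gradient savings from vectorizing may be modest compared with the cost of exploring that latent space.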

Session info if relevant:

R version 4.2.1 (2022-06-23 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)

Matrix products: default

locale:
[1] LC_COLLATE=English_United States.utf8  LC_CTYPE=English_United States.utf8    LC_MONETARY=English_United States.utf8 LC_NUMERIC=C                          
[5] LC_TIME=English_United States.utf8    

attached base packages:
[1] stats     graphics  grDevices datasets  utils     methods   base     

other attached packages:
[1] dplyr_1.0.10 brms_2.19.0  Rcpp_1.0.9  

loaded via a namespace (and not attached):
  [1] nlme_3.1-157         matrixStats_0.63.0   xts_0.13.0           lubridate_1.9.0      RColorBrewer_1.1-3   threejs_0.3.3        rprojroot_2.0.3      rstan_2.26.22       
  [9] tensorA_0.36.2       tools_4.2.1          backports_1.4.1      utf8_1.2.2           R6_2.5.1             DT_0.27              DBI_1.1.3            colorspace_2.0-3    
 [17] withr_2.5.0          tidyselect_1.2.0     gridExtra_2.3        prettyunits_1.1.1    processx_3.8.0       Brobdingnag_1.2-9    curl_4.3.3           compiler_4.2.1      
 [25] cli_3.4.1            shinyjs_2.1.0        colourpicker_1.2.0   posterior_1.4.1      scales_1.2.1         dygraphs_1.1.1.6     checkmate_2.1.0      mvtnorm_1.1-3       
 [33] callr_3.7.3          StanHeaders_2.26.22  stringr_1.4.1        digest_0.6.30        minqa_1.2.5          base64enc_0.1-3      pkgconfig_2.0.3      htmltools_0.5.4     
 [41] lme4_1.1-33          fastmap_1.1.0        htmlwidgets_1.5.4    rlang_1.0.6          rstudioapi_0.14      shiny_1.7.4          farver_2.1.1         generics_0.1.3      
 [49] jsonlite_1.8.3       zoo_1.8-11           crosstalk_1.2.0      gtools_3.9.4         distributional_0.3.2 inline_0.3.19        magrittr_2.0.3       loo_2.6.0           
 [57] bayesplot_1.10.0     Matrix_1.4-1         munsell_0.5.0        fansi_1.0.3          abind_1.4-5          lifecycle_1.0.3      stringi_1.7.8        snakecase_0.11.0    
 [65] MASS_7.3-57          pkgbuild_1.4.0       plyr_1.8.8           grid_4.2.1           parallel_4.2.1       promises_1.2.0.1     forcats_0.5.2        crayon_1.5.2        
 [73] miniUI_0.1.1.1       lattice_0.20-45      splines_4.2.1        ps_1.7.2             pillar_1.8.1         igraph_1.3.5         boot_1.3-28          markdown_1.4        
 [81] shinystan_2.6.0      codetools_0.2-18     reshape2_1.4.4       stats4_4.2.1         rstantools_2.3.1     glue_1.6.2           V8_4.2.2             renv_0.16.0         
 [89] RcppParallel_5.1.7   nloptr_2.0.3         vctrs_0.5.0          httpuv_1.6.8         gtable_0.3.1         assertthat_0.2.1     ggplot2_3.4.0        mime_0.12           
 [97] janitor_2.1.0        xtable_1.8-4         coda_0.19-4          later_1.3.0          tibble_3.1.8         shinythemes_1.2.0    timechange_0.1.1     ellipsis_0.3.2      
[105] bridgesampling_1.1-2 here_1.0.1  