tl;dr: Is there anything I can do to speed up this model, either with arguments to brm or by tweaking the Stan code? And am I specifying the horseshoe priors correctly?
I am trying to fit a model with about 3000 observations and 20 predictors, 13 of which have known measurement error I'd like to account for. On the workstation I have access to (16 physical cores, 32 logical), the model is prohibitively slow even running 2 chains with 2 cores, 7 threads per chain, and GPU support. I took a look at the generated Stan code and was wondering if there might be a way to rewrite it to speed things up.
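For reference, those settings map onto a brm() call along the following lines. This is only a minimal sketch using the 5-predictor example data generated below; the opencl() device ids are placeholders and depend on what the GPU driver reports.

fit <- brm(
  y ~ me(x1, x1_se) + me(x2, x2_se) + me(x3, x3_se) + x4 + x5,
  data = model_data,
  family = 'poisson',
  prior = prior(horseshoe(par_ratio = 0.33), class = b) +
    prior(normal(0, 1), class = meanme) +
    prior(normal(0, 2), class = sdme, lb = 0),
  backend = 'cmdstanr',
  chains = 2,
  cores = 2,
  threads = threading(7),   # 7 threads per chain via reduce_sum
  opencl = opencl(c(0, 0))  # placeholder platform/device ids for the GPU
)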
I'm also wondering whether I'm specifying the horseshoe prior correctly: should I also be placing a horseshoe prior on the class = meanme parameters? I want shrinkage across all of the parameters, not just the ones without measurement error.
Here's some code to generate an example dataset and the corresponding Stan code, with only 5 predictors, 3 of which have measurement error.
library(dplyr)
library(brms)

n <- 500
model_data <- tibble(
  y = rpois(n, lambda = c(3, 5, 7, 10, 12)),
  x1 = rnorm(n),
  x2 = rnorm(n),
  x3 = rnorm(n),
  x4 = rnorm(n),
  x5 = rnorm(n),
  x1_se = runif(n),
  x2_se = runif(n),
  x3_se = runif(n)
)

auto_stan_code <- make_stancode(
  y ~ me(x1, x1_se) + me(x2, x2_se) + me(x3, x3_se) + x4 + x5,
  data = model_data,
  backend = 'cmdstanr',
  prior = prior(horseshoe(par_ratio = 0.33), class = b) +
    prior(normal(0, 1), class = meanme) +
    prior(normal(0, 2), class = sdme, lb = 0),
  threads = 2,
  cores = 2,
  chains = 2,
  family = 'poisson'
)
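Related to the prior question above: get_prior() lists the prior classes brms exposes for this formula (including meanme and sdme), which makes it easier to see what a shrinkage prior could even be assigned to. A quick sketch using the example data:

get_prior(
  y ~ me(x1, x1_se) + me(x2, x2_se) + me(x3, x3_se) + x4 + x5,
  data = model_data,
  family = 'poisson'
)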
And here is the generated Stan code:
// generated with brms 2.19.0
functions {
  /* Efficient computation of the horseshoe prior
   * see Appendix C.1 in https://projecteuclid.org/euclid.ejs/1513306866
   * Args:
   *   z: standardized population-level coefficients
   *   lambda: local shrinkage parameters
   *   tau: global shrinkage parameter
   *   c2: slab regularization parameter
   * Returns:
   *   population-level coefficients following the horseshoe prior
   */
  vector horseshoe(vector z, vector lambda, real tau, real c2) {
    int K = rows(z);
    vector[K] lambda2 = square(lambda);
    vector[K] lambda_tilde = sqrt(c2 * lambda2 ./ (c2 + tau^2 * lambda2));
    return z .* lambda_tilde * tau;
  }
  /* integer sequence of values
   * Args:
   *   start: starting integer
   *   end: ending integer
   * Returns:
   *   an integer sequence from start to end
   */
  int[] sequence(int start, int end) {
    int seq[end - start + 1];
    for (n in 1:num_elements(seq)) {
      seq[n] = n + start - 1;
    }
    return seq;
  }
  // compute partial sums of the log-likelihood
  real partial_log_lik_lpmf(int[] seq, int start, int end, data int[] Y, data matrix Xc, vector b, real Intercept, vector bsp, vector Xme_1, vector Xme_2, vector Xme_3) {
    real ptarget = 0;
    int N = end - start + 1;
    // initialize linear predictor term
    vector[N] mu = rep_vector(0.0, N);
    mu += Intercept;
    for (n in 1:N) {
      // add more terms to the linear predictor
      int nn = n + start - 1;
      mu[n] += (bsp[1]) * Xme_1[nn] + (bsp[2]) * Xme_2[nn] + (bsp[3]) * Xme_3[nn];
    }
    ptarget += poisson_log_glm_lpmf(Y[start:end] | Xc[start:end], mu, b);
    return ptarget;
  }
}
data {
  int<lower=1> N; // total number of observations
  int Y[N]; // response variable
  int<lower=1> K; // number of population-level effects
  matrix[N, K] X; // population-level design matrix
  int<lower=1> Ksp; // number of special effects terms
  // data for the horseshoe prior
  real<lower=0> hs_df; // local degrees of freedom
  real<lower=0> hs_df_global; // global degrees of freedom
  real<lower=0> hs_df_slab; // slab degrees of freedom
  real<lower=0> hs_scale_global; // global prior scale
  real<lower=0> hs_scale_slab; // slab prior scale
  int grainsize; // grainsize for threading
  // data for noise-free variables
  int<lower=1> Mme_1; // number of groups
  vector[N] Xn_1; // noisy values
  vector<lower=0>[N] noise_1; // measurement noise
  vector[N] Xn_2; // noisy values
  vector<lower=0>[N] noise_2; // measurement noise
  vector[N] Xn_3; // noisy values
  vector<lower=0>[N] noise_3; // measurement noise
  int<lower=1> NCme_1; // number of latent correlations
  int prior_only; // should the likelihood be ignored?
}
transformed data {
  int Kc = K - 1;
  matrix[N, Kc] Xc; // centered version of X without an intercept
  vector[Kc] means_X; // column means of X before centering
  int seq[N] = sequence(1, N);
  for (i in 2:K) {
    means_X[i - 1] = mean(X[, i]);
    Xc[, i - 1] = X[, i] - means_X[i - 1];
  }
}
parameters {
  // local parameters for the horseshoe prior
  vector[Kc] zb;
  vector<lower=0>[Kc] hs_local;
  real Intercept; // temporary intercept for centered predictors
  // local parameters for the horseshoe prior
  vector[Ksp] zbsp;
  vector<lower=0>[Ksp] hs_localsp;
  // horseshoe shrinkage parameters
  real<lower=0> hs_global; // global shrinkage parameter
  real<lower=0> hs_slab; // slab regularization parameter
  // parameters for noise free variables
  vector[Mme_1] meanme_1; // latent means
  vector<lower=0>[Mme_1] sdme_1; // latent SDs
  matrix[Mme_1, N] zme_1; // standardized latent values
  cholesky_factor_corr[Mme_1] Lme_1; // cholesky factor of the latent correlation matrix
}
transformed parameters {
  vector[Kc] b; // population-level effects
  // special effects coefficients
  vector[Ksp] bsp;
  matrix[N, Mme_1] Xme1; // actual latent values
  // using separate vectors increases efficiency
  vector[N] Xme_1;
  // using separate vectors increases efficiency
  vector[N] Xme_2;
  // using separate vectors increases efficiency
  vector[N] Xme_3;
  real lprior = 0; // prior contributions to the log posterior
  // compute the actual regression coefficients
  b = horseshoe(zb, hs_local, hs_global, hs_scale_slab^2 * hs_slab);
  // compute the actual regression coefficients
  bsp = horseshoe(zbsp, hs_localsp, hs_global, hs_scale_slab^2 * hs_slab);
  // compute actual latent values
  Xme1 = rep_matrix(transpose(meanme_1), N) + transpose(diag_pre_multiply(sdme_1, Lme_1) * zme_1);
  Xme_1 = Xme1[, 1];
  Xme_2 = Xme1[, 2];
  Xme_3 = Xme1[, 3];
  lprior += student_t_lpdf(Intercept | 3, 1.9, 2.5);
  lprior += student_t_lpdf(hs_global | hs_df_global, 0, hs_scale_global)
    - 1 * log(0.5);
  lprior += inv_gamma_lpdf(hs_slab | 0.5 * hs_df_slab, 0.5 * hs_df_slab);
  lprior += normal_lpdf(meanme_1 | 0, 1);
  lprior += normal_lpdf(sdme_1 | 0, 2)
    - 3 * normal_lccdf(0 | 0, 2);
  lprior += lkj_corr_cholesky_lpdf(Lme_1 | 1);
}
model {
  // likelihood including constants
  if (!prior_only) {
    target += reduce_sum(partial_log_lik_lpmf, seq, grainsize, Y, Xc, b, Intercept, bsp, Xme_1, Xme_2, Xme_3);
  }
  // priors including constants
  target += lprior;
  target += std_normal_lpdf(zb);
  target += student_t_lpdf(hs_local | hs_df, 0, 1)
    - rows(hs_local) * log(0.5);
  target += std_normal_lpdf(zbsp);
  target += student_t_lpdf(hs_localsp | hs_df, 0, 1)
    - rows(hs_localsp) * log(0.5);
  target += normal_lpdf(Xn_1 | Xme_1, noise_1);
  target += normal_lpdf(Xn_2 | Xme_2, noise_2);
  target += normal_lpdf(Xn_3 | Xme_3, noise_3);
  target += std_normal_lpdf(to_vector(zme_1));
}
generated quantities {
  // actual population-level intercept
  real b_Intercept = Intercept - dot_product(means_X, b);
  // obtain latent correlation matrix
  corr_matrix[Mme_1] Corme_1 = multiply_lower_tri_self_transpose(Lme_1);
  vector<lower=-1,upper=1>[NCme_1] corme_1;
  // extract upper diagonal of correlation matrix
  for (k in 1:Mme_1) {
    for (j in 1:(k - 1)) {
      corme_1[choose(k - 1, 2) + j] = Corme_1[j, k];
    }
  }
}
I see that each measurement-error variable gets assigned its own vector, rather than all of the measurement-error variables being placed into one matrix. I'm admittedly a bit of a novice at writing custom Stan code, but it seems to me that having a separate vector for each, rather than a matrix like the one used for the non-measurement-error predictors, might be a place to gain some efficiency. Would setting up an Xme matrix and a corresponding Xme_noise matrix be one way to speed things up? Any other thoughts on how I might make this more efficient?
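For what it's worth, the same per-variable layout shows up on the data side as well; make_standata() is the data counterpart of make_stancode(), and a quick look (a sketch using the example above) shows each noisy predictor and its SE arriving as a separate vector rather than as columns of a single matrix:

stan_data <- make_standata(
  y ~ me(x1, x1_se) + me(x2, x2_se) + me(x3, x3_se) + x4 + x5,
  data = model_data,
  family = 'poisson'
)
# Xn_* hold the noisy values and noise_* the measurement SDs,
# matching the data block of the generated Stan code above
str(stan_data[c("Xn_1", "noise_1", "Xn_2", "noise_2", "Xn_3", "noise_3")])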
Session info if relevant:
R version 4.2.1 (2022-06-23 ucrt)
Platform: x86_64-w64-mingw32/x64 (64-bit)
Running under: Windows 10 x64 (build 19044)
Matrix products: default
locale:
[1] LC_COLLATE=English_United States.utf8 LC_CTYPE=English_United States.utf8 LC_MONETARY=English_United States.utf8 LC_NUMERIC=C
[5] LC_TIME=English_United States.utf8
attached base packages:
[1] stats graphics grDevices datasets utils methods base
other attached packages:
[1] dplyr_1.0.10 brms_2.19.0 Rcpp_1.0.9
loaded via a namespace (and not attached):
[1] nlme_3.1-157 matrixStats_0.63.0 xts_0.13.0 lubridate_1.9.0 RColorBrewer_1.1-3 threejs_0.3.3 rprojroot_2.0.3 rstan_2.26.22
[9] tensorA_0.36.2 tools_4.2.1 backports_1.4.1 utf8_1.2.2 R6_2.5.1 DT_0.27 DBI_1.1.3 colorspace_2.0-3
[17] withr_2.5.0 tidyselect_1.2.0 gridExtra_2.3 prettyunits_1.1.1 processx_3.8.0 Brobdingnag_1.2-9 curl_4.3.3 compiler_4.2.1
[25] cli_3.4.1 shinyjs_2.1.0 colourpicker_1.2.0 posterior_1.4.1 scales_1.2.1 dygraphs_1.1.1.6 checkmate_2.1.0 mvtnorm_1.1-3
[33] callr_3.7.3 StanHeaders_2.26.22 stringr_1.4.1 digest_0.6.30 minqa_1.2.5 base64enc_0.1-3 pkgconfig_2.0.3 htmltools_0.5.4
[41] lme4_1.1-33 fastmap_1.1.0 htmlwidgets_1.5.4 rlang_1.0.6 rstudioapi_0.14 shiny_1.7.4 farver_2.1.1 generics_0.1.3
[49] jsonlite_1.8.3 zoo_1.8-11 crosstalk_1.2.0 gtools_3.9.4 distributional_0.3.2 inline_0.3.19 magrittr_2.0.3 loo_2.6.0
[57] bayesplot_1.10.0 Matrix_1.4-1 munsell_0.5.0 fansi_1.0.3 abind_1.4-5 lifecycle_1.0.3 stringi_1.7.8 snakecase_0.11.0
[65] MASS_7.3-57 pkgbuild_1.4.0 plyr_1.8.8 grid_4.2.1 parallel_4.2.1 promises_1.2.0.1 forcats_0.5.2 crayon_1.5.2
[73] miniUI_0.1.1.1 lattice_0.20-45 splines_4.2.1 ps_1.7.2 pillar_1.8.1 igraph_1.3.5 boot_1.3-28 markdown_1.4
[81] shinystan_2.6.0 codetools_0.2-18 reshape2_1.4.4 stats4_4.2.1 rstantools_2.3.1 glue_1.6.2 V8_4.2.2 renv_0.16.0
[89] RcppParallel_5.1.7 nloptr_2.0.3 vctrs_0.5.0 httpuv_1.6.8 gtable_0.3.1 assertthat_0.2.1 ggplot2_3.4.0 mime_0.12
[97] janitor_2.1.0 xtable_1.8-4 coda_0.19-4 later_1.3.0 tibble_3.1.8 shinythemes_1.2.0 timechange_0.1.1 ellipsis_0.3.2
[105] bridgesampling_1.1-2 here_1.0.1