Optimizing Stan code to estimate variable transformation parameters

I am trying to fit a hierarchical regression model using “rstan”. In addition to estimating all the regression parameters, I also need to estimate transformation parameters for some of the variables.
These transformations are non-linear, and it is difficult to express them as a formula for “brms”. Currently, I use make_stancode() to generate a code skeleton, which I then edit to include the variable transformation parameters. Finally, I run the model in “rstan”.
The formula used for “make_stancode()”:

"Y ~ 1 + Var1 + Var2 + Var3 + Var4 + Var5 + Var6 + Var7 + Var8 + Var9 + Var10 + Var11 + Var12 + Var13 + Var14 + Var15 + Var16 + Var17 + Var18 + Var19 + Var20 + ( 1 + Var7 | GROUP2 ) + ( -1 + Var14 | GROUP2 ) + ( -1 + Var17 | GROUP2 ) + ( -1 + Var19 | GROUP2 ) + ( -1 + Var20 | GROUP2 ) + ( 1 + Var1 | GROUP1 ) + ( -1 + Var6 | GROUP1 ) + ( -1 + Var8 | GROUP1 ) + ( -1 + Var9 | GROUP1 ) + ( -1 + Var10 | GROUP1 ) + ( -1 + Var11 | GROUP1 ) + ( -1 + Var12 | GROUP1 ) + ( -1 + Var13 | GROUP1 )"

The process flow for the transformation of a variable (Vi), summarized in notation below:

  1. A weighted cumulative sum over the lagged values of the variable (Vi), with the weights as parameters.
  2. Addition of two or more transformed variables to arrive at a final set of variables (Ci).
  3. Application of a non-linear function (“gamma_cdf()”) to each variable (Ci), with alpha and slope as parameters.
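In notation (writing eta_i for eng_factor[i], L for max_lag, and S_k for the set of variables Vi summed into Ci):

$$w_{i,l} = 0.5^{(l-1)/\eta_i}, \qquad \tilde{V}_{i,t} = \frac{\sum_{l=1}^{L} w_{i,l}\, x_{i,t,l}}{\sum_{l=1}^{L} w_{i,l}}$$

$$C_{k,t} = \sum_{i \in S_k} \tilde{V}_{i,t}, \qquad T_{k,t} = F_{\mathrm{Gamma}}\!\left(C_{k,t} \mid \alpha_k,\ \beta_k = 1/\mathrm{slope}_k\right)$$

The edited Stan code: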
// generated with brms 2.13.0
functions {
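  // Both helper functions follow the signature that map_rect() requires:
  //   vector f(vector phi, vector theta, real[] x_r, int[] x_i)
  // where phi holds parameters shared across shards, theta per-shard
  // parameters, and x_r / x_i per-shard real and integer data.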
  
  // Generating the geometric decay weights for the First Transformation (one lag per shard)
  vector Ad_weights(vector params, vector theta, real[] x, int[] y) {
    real z = pow(0.5, ((y[1]-1)*(1/params[1])));
    return [z]';
  }
  
  // First Transformation for each variable (Vi): weighted average of its lagged values (one observation per shard)
  vector First(vector params, vector theta, real[] x, int[] y) {
    real z = dot_product(to_vector(x), params)/sum(params);
    return [z]';
  }
}
data {
  //Similar to "make_stancode()"
  int<lower=1> N;  // number of observations
  vector[N] Y;  // response variable
  int<lower=1> K;  // number of population-level effects
  matrix[N, K] X;  // population-level design matrix
  // data for group-level effects of ID 1
  int<lower=1> N_1;  // number of grouping levels
  int<lower=1> M_1;  // number of coefficients per level
  int<lower=1> J_1[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_1_1;
  vector[N] Z_1_2;
  int<lower=1> NC_1;  // number of group-level correlations
  // data for group-level effects of ID 2
  int<lower=1> N_2;  // number of grouping levels
  int<lower=1> M_2;  // number of coefficients per level
  int<lower=1> J_2[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_2_1;
  // data for group-level effects of ID 3
  int<lower=1> N_3;  // number of grouping levels
  int<lower=1> M_3;  // number of coefficients per level
  int<lower=1> J_3[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_3_1;
  // data for group-level effects of ID 4
  int<lower=1> N_4;  // number of grouping levels
  int<lower=1> M_4;  // number of coefficients per level
  int<lower=1> J_4[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_4_1;
  // data for group-level effects of ID 5
  int<lower=1> N_5;  // number of grouping levels
  int<lower=1> M_5;  // number of coefficients per level
  int<lower=1> J_5[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_5_1;
  // data for group-level effects of ID 6
  int<lower=1> N_6;  // number of grouping levels
  int<lower=1> M_6;  // number of coefficients per level
  int<lower=1> J_6[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_6_1;
  // data for group-level effects of ID 7
  int<lower=1> N_7;  // number of grouping levels
  int<lower=1> M_7;  // number of coefficients per level
  int<lower=1> J_7[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_7_1;
  // data for group-level effects of ID 8
  int<lower=1> N_8;  // number of grouping levels
  int<lower=1> M_8;  // number of coefficients per level
  int<lower=1> J_8[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_8_1;
  // data for group-level effects of ID 9
  int<lower=1> N_9;  // number of grouping levels
  int<lower=1> M_9;  // number of coefficients per level
  int<lower=1> J_9[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_9_1;
  vector[N] Z_9_2;
  int<lower=1> NC_9;  // number of group-level correlations
  // data for group-level effects of ID 10
  int<lower=1> N_10;  // number of grouping levels
  int<lower=1> M_10;  // number of coefficients per level
  int<lower=1> J_10[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_10_1;
  // data for group-level effects of ID 11
  int<lower=1> N_11;  // number of grouping levels
  int<lower=1> M_11;  // number of coefficients per level
  int<lower=1> J_11[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_11_1;
  // data for group-level effects of ID 12
  int<lower=1> N_12;  // number of grouping levels
  int<lower=1> M_12;  // number of coefficients per level
  int<lower=1> J_12[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_12_1;
  // data for group-level effects of ID 13
  int<lower=1> N_13;  // number of grouping levels
  int<lower=1> M_13;  // number of coefficients per level
  int<lower=1> J_13[N];  // grouping indicator per observation
  // group-level predictor values
  vector[N] Z_13_1;
  int prior_only;  // should the likelihood be ignored?

  //customised
  int<lower=1> num_media;  // no. of individual variables (Vi) in the First Transformation
  int<lower=1> max_lag;  // maximum lag effect (delay) to be considered
  real X_media[num_media, N, max_lag];  // 3D array of lagged values for the complete set of variables (Vi)
  int<lower=1> custom_num_media;  // no. of variables (Ci) used in the Second Transformation
  int<lower=1> arange[max_lag, 1];  // lag indices (1:max_lag) for the map function in the First Transformation
}
transformed data {
  //All dummies required by the map function in the First Transformation
  vector[0] w_dummy1[max_lag];
  real w_dummy2[max_lag, 1];
  vector[0] a_dummy1[N];
  int a_dummy2[N,1];
  
  
  // mean centering for the predictors that remain constant across iterations
  int B = K - 1 - custom_num_media;
  matrix[N, B] Xb;
  vector[B] means_Xb;
  int start_index = K - B + 1;


  for (i in 1:B) {
    // the column index must advance with i; a fixed start_index would
    // centre the same column B times
    means_Xb[i] = mean(X[, start_index + i - 1]);
    Xb[, i] = X[, start_index + i - 1] - means_Xb[i];
  }
  
}
parameters {
  vector[K-1] b;  // population-level effects
  real Intercept;  // temporary intercept for centered predictors
  real<lower=0> sigma;  // residual SD
  vector<lower=0>[M_1] sd_1;  // group-level standard deviations
  matrix[M_1, N_1] z_1;  // standardized group-level effects
  cholesky_factor_corr[M_1] L_1;  // cholesky factor of correlation matrix
  vector<lower=0>[M_2] sd_2;  // group-level standard deviations
  vector[N_2] z_2[M_2];  // standardized group-level effects
  vector<lower=0>[M_3] sd_3;  // group-level standard deviations
  vector[N_3] z_3[M_3];  // standardized group-level effects
  vector<lower=0>[M_4] sd_4;  // group-level standard deviations
  vector[N_4] z_4[M_4];  // standardized group-level effects
  vector<lower=0>[M_5] sd_5;  // group-level standard deviations
  vector[N_5] z_5[M_5];  // standardized group-level effects
  vector<lower=0>[M_6] sd_6;  // group-level standard deviations
  vector[N_6] z_6[M_6];  // standardized group-level effects
  vector<lower=0>[M_7] sd_7;  // group-level standard deviations
  vector[N_7] z_7[M_7];  // standardized group-level effects
  vector<lower=0>[M_8] sd_8;  // group-level standard deviations
  vector[N_8] z_8[M_8];  // standardized group-level effects
  vector<lower=0>[M_9] sd_9;  // group-level standard deviations
  matrix[M_9, N_9] z_9;  // standardized group-level effects
  cholesky_factor_corr[M_9] L_9;  // cholesky factor of correlation matrix
  vector<lower=0>[M_10] sd_10;  // group-level standard deviations
  vector[N_10] z_10[M_10];  // standardized group-level effects
  vector<lower=0>[M_11] sd_11;  // group-level standard deviations
  vector[N_11] z_11[M_11];  // standardized group-level effects
  vector<lower=0>[M_12] sd_12;  // group-level standard deviations
  vector[N_12] z_12[M_12];  // standardized group-level effects
  vector<lower=0>[M_13] sd_13;  // group-level standard deviations
  vector[N_13] z_13[M_13];  // standardized group-level effects
  
  //First Transformation parameter
  vector<lower=0,upper=2>[num_media] eng_factor;
  //Second Transformation parameters
  matrix<lower=0>[2,custom_num_media] alpha_slope;
  
}
transformed parameters {
  matrix[N_1, M_1] r_1;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_1] r_1_1;
  vector[N_1] r_1_2;
  vector[N_2] r_2_1;  // actual group-level effects
  vector[N_3] r_3_1;  // actual group-level effects
  vector[N_4] r_4_1;  // actual group-level effects
  vector[N_5] r_5_1;  // actual group-level effects
  vector[N_6] r_6_1;  // actual group-level effects
  vector[N_7] r_7_1;  // actual group-level effects
  vector[N_8] r_8_1;  // actual group-level effects
  matrix[N_9, M_9] r_9;  // actual group-level effects
  // using vectors speeds up indexing in loops
  vector[N_9] r_9_1;
  vector[N_9] r_9_2;
  vector[N_10] r_10_1;  // actual group-level effects
  vector[N_11] r_11_1;  // actual group-level effects
  vector[N_12] r_12_1;  // actual group-level effects
  vector[N_13] r_13_1;  // actual group-level effects
  
  vector[max_lag] weights[num_media];
  vector[N] Adstocked[num_media];
  matrix[N, custom_num_media] Ad_Gamma;
  matrix[N, custom_num_media] Xc;
  vector[N] mu;
  
  // precompute the constant lccdf correction shared by sigma and the 15 sd
  // parameters (16 identical half-Student-t terms)
  real minus_lccdf = 16 * student_t_lccdf(0 | 3, 0, 2.5);
  
  
  // compute actual group-level effects
  r_1 = (diag_pre_multiply(sd_1, L_1) * z_1)';
  r_1_1 = r_1[, 1];
  r_1_2 = r_1[, 2];
  r_2_1 = (sd_2[1] * (z_2[1]));
  r_3_1 = (sd_3[1] * (z_3[1]));
  r_4_1 = (sd_4[1] * (z_4[1]));
  r_5_1 = (sd_5[1] * (z_5[1]));
  r_6_1 = (sd_6[1] * (z_6[1]));
  r_7_1 = (sd_7[1] * (z_7[1]));
  r_8_1 = (sd_8[1] * (z_8[1]));
  // compute actual group-level effects
  r_9 = (diag_pre_multiply(sd_9, L_9) * z_9)';
  r_9_1 = r_9[, 1];
  r_9_2 = r_9[, 2];
  r_10_1 = (sd_10[1] * (z_10[1]));
  r_11_1 = (sd_11[1] * (z_11[1]));
  r_12_1 = (sd_12[1] * (z_12[1]));
  r_13_1 = (sd_13[1] * (z_13[1]));
  
  //First Transformation using map functions 
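  // weights[i]: one map_rect shard per lag (arange supplies the lag index);
  // Adstocked[i]: one shard per observation, each computing the weighted
  // average of that observation's max_lag lagged values.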
  for (i in 1:num_media) {
    weights[i] = map_rect(Ad_weights, to_vector([eng_factor[i]]), w_dummy1, w_dummy2, arange);
    Adstocked[i] = map_rect(First, weights[i], a_dummy1, X_media[i,], a_dummy2);
  }
  
  //Adding two or more variables to get final set of variables (Ci)
  Ad_Gamma[,1] = Adstocked[1] + Adstocked[2];
  Ad_Gamma[,2] = Adstocked[3] + Adstocked[4] + Adstocked[5] + Adstocked[6] + Adstocked[7];
  Ad_Gamma[,3] = Adstocked[8];
  Ad_Gamma[,4] = Adstocked[9];
  Ad_Gamma[,5] = Adstocked[10];
  
  
  //Second Transformation for each variable (Ci)
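  // Note: Stan's gamma_cdf(y, alpha, beta) takes beta as an inverse scale
  // (rate), so passing 1/alpha_slope[2,k] treats alpha_slope[2,k] as a scale.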
  for (k in 1:custom_num_media) {
    for (j in 1:N){
      Ad_Gamma[j,k] = gamma_cdf(Ad_Gamma[j,k], alpha_slope[1,k], 1/alpha_slope[2,k]);
    }
  }
  
  //Mean centering for transformed variables
  for (i in 1:custom_num_media) {
    Xc[, i] = Ad_Gamma[,i] - mean(Ad_Gamma[,i]);
  }
  
  //Accumulating linear effects for both fixed and random parts by vectorising
  mu = Intercept + Xb * b[(custom_num_media + 1):(K-1)] + Xc * b[1:custom_num_media] + r_1_1[J_1] .* Z_1_1 + r_1_2[J_1] .* Ad_Gamma[,1] + r_2_1[J_2] .* Z_2_1 + r_3_1[J_3] .* Z_3_1 + r_4_1[J_4] .* Z_4_1 + r_5_1[J_5] .* Z_5_1 + r_6_1[J_6] .* Z_6_1 + r_7_1[J_7] .* Z_7_1 + r_8_1[J_8] .* Z_8_1 + r_9_1[J_9] .* Z_9_1 + r_9_2[J_9] .* Z_9_2 + r_10_1[J_10] .* Z_10_1 + r_11_1[J_11] .* Z_11_1 + r_12_1[J_12] .* Z_12_1 + r_13_1[J_13] .* Z_13_1;
  
}


model {
  
  // priors including all constants (vectorising wherever possible)
  
  b ~ std_normal();
  
  
  to_vector(z_1) ~ std_normal();
  z_2[1] ~ std_normal();
  z_3[1] ~ std_normal();
  z_4[1] ~ std_normal();
  z_5[1] ~ std_normal();
  z_6[1] ~ std_normal();
  z_7[1] ~ std_normal();
  z_8[1] ~ std_normal();
  to_vector(z_9) ~ std_normal();
  z_10[1] ~ std_normal();
  z_11[1] ~ std_normal();
  z_12[1] ~ std_normal();
  z_13[1] ~ std_normal();
  
  //setting priors for variable transformation parameters
  to_vector(alpha_slope) ~ normal(1.4, 0.3);
  eng_factor ~ gamma(4, 4);
  
  
  target += student_t_lpdf(Intercept | 3, 5, 2.5);
  target += student_t_lpdf(sigma | 3, 0, 2.5);
  target += student_t_lpdf(sd_1 | 3, 0, 2.5);
  target += lkj_corr_cholesky_lpdf(L_1 | 1);
  target += student_t_lpdf(sd_2 | 3, 0, 2.5);
  target += student_t_lpdf(sd_3 | 3, 0, 2.5);
  target += student_t_lpdf(sd_4 | 3, 0, 2.5);
  target += student_t_lpdf(sd_5 | 3, 0, 2.5);
  target += student_t_lpdf(sd_6 | 3, 0, 2.5);
  target += student_t_lpdf(sd_7 | 3, 0, 2.5);
  target += student_t_lpdf(sd_8 | 3, 0, 2.5);
  target += student_t_lpdf(sd_9 | 3, 0, 2.5);
  target += lkj_corr_cholesky_lpdf(L_9 | 1);
  target += student_t_lpdf(sd_10 | 3, 0, 2.5);
  target += student_t_lpdf(sd_11 | 3, 0, 2.5);
  target += student_t_lpdf(sd_12 | 3, 0, 2.5);
  target += student_t_lpdf(sd_13 | 3, 0, 2.5);
  
  //subtracting constant lccdf to normalise according to "make_stancode()"
  target += -minus_lccdf;
    
  
  // likelihood including all constants
  if (!prior_only) {
    target += normal_lpdf(Y | mu, sigma);
  }
}

The model took around 6 hours using the conventional “make_stancode()” code on a VM with 10 cores.
Further, I have tried the suggestions made by @JLC and @andrjohns in the post to reduce the run-time a bit.
The warnings reported divergent transitions and transitions exceeding the maximum treedepth.

I have two main questions:
1. How can I reduce the run-time further?
2. How can I perform posterior predictive checks and get the results to converge? (my rough attempt below)
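For question 2, my rough idea (not part of the “make_stancode()” output; “y_rep” is just my own name) is to add a generated quantities block and compare the draws against Y, e.g. with the “bayesplot” package:

generated quantities {
  // posterior predictive draws for checking the model against Y
  real y_rep[N];
  for (n in 1:N) {
    y_rep[n] = normal_rng(mu[n], sigma);
  }
}

But I am not sure whether this is the recommended approach.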

Thanks,
@msk98

Sorry for the late reply, your question fell through the cracks a bit.

The divergences/treedepth warnings are an issue that needs to be resolved first; it is quite likely that once the cause of the divergences is resolved, the model will also be much faster.

A first step would be to separate your new functions and the linear model and test them separately - debugging such large models in their full complexity is extremely challenging. Would just fitting the brms formula converge?

Second, build a simple pure Stan model where you compute Ad_Gamma with a minimal number (does 1 make sense?) of num_media and custom_num_media variables. Treat Ad_Gamma as directly observed (e.g. Y ~ normal(Ad_Gamma, 1);). Create a simulated dataset Y that follows the model exactly. Can you fit this model?
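As a rough sketch of what I mean (map_rect dropped for simplicity, a single Vi mapping directly to a single Ci; all names here are only illustrative):

data {
  int<lower=1> N;
  int<lower=1> max_lag;
  real X_media[N, max_lag];  // lagged values of the single variable
  vector[N] Y;               // simulated response
}
parameters {
  real<lower=0,upper=2> eng_factor;
  real<lower=0> alpha;
  real<lower=0> slope;
}
transformed parameters {
  vector[max_lag] w;
  vector[N] Ad_Gamma;
  // geometric decay weights, as in your Ad_weights() function
  for (l in 1:max_lag) {
    w[l] = pow(0.5, (l - 1) / eng_factor);
  }
  // weighted average of the lags, then the gamma_cdf saturation
  for (n in 1:N) {
    Ad_Gamma[n] = gamma_cdf(dot_product(to_vector(X_media[n]), w) / sum(w),
                            alpha, 1 / slope);
  }
}
model {
  eng_factor ~ gamma(4, 4);
  alpha ~ normal(1.4, 0.3);
  slope ~ normal(1.4, 0.3);
  Y ~ normal(Ad_Gamma, 1);  // Ad_Gamma treated as directly observed
}

If this small model recovers the parameters you used to simulate the data, add back the remaining components one at a time until something breaks.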

Note also that you can make this somewhat more comfortable by using stanvar to inject Stan code into the brms model - this way you can alter the brms formula while keeping your modifications, and still take advantage of some of the brms helper functions.

Thanks @martinmodrak,
Sorry for the late reply, missed the notification.
As you suggested testing the linear model first, I tried fitting it with two approaches: one with “rstanarm” and the other with “brms”.

But the results from rstanarm’s “stan_lmer()” converged with fewer divergent transitions (9) and in less time than “brm()”. The “brm()” fit was around 4 times slower and resulted in many divergent transitions.

This could be because the group-level standard deviations are modeled differently in the two approaches, as pointed out by @bgoodri in the post. I am new to brms/Stan and not confident about how to go about the modeling process (setting priors or re-parameterizing). Any help on how to approach this would be much appreciated. Please let me know if any further information is required.

Short on time so just quick notes:

  • Remove terms from your formula to find the smallest model that diverges and the largest model that converges. You have a big formula, so it is quite possible you simply don’t have enough data to inform all the parameters.
  • Try not modelling correlations in the varying effects, i.e. use ( 1 + Var7 || GROUP2 ) instead of ( 1 + Var7 | GROUP2 ).
  • Check out our new guide: Divergent transitions - a primer.