QR decomposition for mixed effects model

Hi
I’m having to code up quite a complex model in Stan for a large dataset, and it is running incredibly slowly. I was advised to reparameterise the model to speed up fitting, so I’m trying to work out the fastest way to express a much simpler three-level hierarchical longitudinal mixed effects model: if I know how to do that, I’m confident I can apply it to the much more complex model. (I’m aware that you can fit this particular model in other packages, but this is an exercise to help me understand what I need to do with the more complex model!) I have asked a separate question about scaling parameters (Example for reparametrisation of multivariate normal parameters in a mixed effects model - #2 by Bob_Carpenter).

I have been reading the section of the handbook on the QR decomposition, and was wondering how it should be implemented for a mixed effects model.

Specifically, suppose I am modelling three-level continuous longitudinal data: measurements (level 1) nested within individuals (level 2), nested within, say, studies (level 3). Then I might have

$$y(t_{kij}) = X(t_{kij})\,\beta + Z^{(2)}(t_{kij})\, b^{(2)}_{ki} + Z^{(3)}(t_{kij})\, b^{(3)}_{k} + \varepsilon_{kij}$$

where $k = 1, \dots, K$ indexes the level-3 studies, $i = 1, \dots, N$ indexes the individuals across all studies, and $j$ indexes the measurements within individuals ($M$ measurements in total). The longitudinal outcome is $y$; the population-level (fixed) effects are $\beta$, with design matrix $X$; the individual-specific random effects are $b^{(2)}$, with design matrix $Z^{(2)}$; and the study-specific random effects are $b^{(3)}$, with design matrix $Z^{(3)}$. The random effects are assumed to have zero mean and to be independent across levels. The columns of the $Z$ matrices are a subset of the columns of $X$.
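Written out in full, and assuming normal random effects and errors as in the `multi_normal` and `normal` statements of the code below, the distributional assumptions are as follows. The covariance matrices are built from standard deviations $\tau$ and Cholesky factors $L_\Omega$ of correlation matrices, which is how the code constructs `Dl` and `Al`:

$$
b^{(2)}_{ki} \sim \mathrm{N}(0, D), \qquad b^{(3)}_{k} \sim \mathrm{N}(0, A), \qquad \varepsilon_{kij} \sim \mathrm{N}(0, \sigma_e^2),
$$

$$
D = \operatorname{diag}(\tau_D)\, L_{\Omega_D} L_{\Omega_D}^{\top} \operatorname{diag}(\tau_D), \qquad A = \operatorname{diag}(\tau_A)\, L_{\Omega_A} L_{\Omega_A}^{\top} \operatorname{diag}(\tau_A),
$$

with $b^{(2)}$, $b^{(3)}$ and $\varepsilon$ mutually independent.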

In that case, since the random-effects design matrices $Z$ are subsets of the fixed-effects design matrix $X$, is the QR decomposition applied once to $X$, with the random-effects design matrices then obtained by subsetting the columns of $Q$?
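For reference, the algebra behind subsetting $Q$ can be sketched as follows. It uses the scaled factors from the Stan code ($Q^* = Q\sqrt{M-1}$, $R^* = R/\sqrt{M-1}$) and assumes, for illustration, that the $q$ individual-level random-effect columns are the first $q$ columns of $X$; $Q^*(t_{kij})$ denotes the row of $Q^*$ for measurement $kij$:

$$
\begin{aligned}
X &= Q^* R^*, \qquad R^* \text{ upper triangular}, \\
Z^{(2)} = X_{[\cdot,\,1:q]} &= Q^* R^*_{[\cdot,\,1:q]} = Q^*_{[\cdot,\,1:q]}\, R^*_{[1:q,\,1:q]}, \\
Z^{(2)}(t_{kij})\, b^{(2)}_{ki} &= Q^*(t_{kij})_{[1:q]}\, \tilde{b}^{(2)}_{ki}, \qquad \tilde{b}^{(2)}_{ki} = R^*_{[1:q,\,1:q]}\, b^{(2)}_{ki}.
\end{aligned}
$$

The second line holds because column $j$ of an upper-triangular matrix is zero below row $j$, so using $Q^*_{[\cdot,\,1:q]}$ as the design matrix reparameterises the random effects as $\tilde{b}^{(2)}_{ki}$ rather than changing the model. If the selected column set $S$ is not the leading columns of $X$, then $R^*_{[\cdot,\,S]}$ generally has non-zero entries in rows outside $S$, and subsetting $Q^*$ alone gives a different random-effects design. The same argument applies to $Z^{(3)}$.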

I have tried to code this up in the Stan model below. Would someone more experienced and knowledgeable than me be able to tell me whether it is along the right lines? As I mentioned, I’m trying to get this model running fast, as this will teach me what I need to do to get the much more complex model running fast. It needs to run as fast as possible because it has to handle big(ish) datasets: for example, in the model below M ≈ 300,000, N ≈ 30,000 and K ≈ 5.

Any advice gratefully received!

functions{
    //calculate longitudinal linear predictor
    vector calc_Linpred_y(vector beta_L_all,
                      matrix b2_L,
                      matrix b3_L,
                      matrix X_L_Q_ast,
                      matrix Z2_L_Q_ast,
                      matrix Z3_L_Q_ast,
                      int M,  
                      array[] int indnum,
                      array[] int studnum){ 
        //define vector linpred_y1 to output the linear predictor into
        vector[M] linpred_y1;
        vector[M] term1 = (X_L_Q_ast*beta_L_all);
        vector[M] term2;
        vector[M] term3;
        
        for(i in 1:M){
          int indnum_temp = indnum[i];
          int studnum_temp = studnum[i];
          //row_vector * row_vector is not defined in Stan, so use dot_product
          term2[i] = dot_product(Z2_L_Q_ast[i,], b2_L[indnum_temp,]);
          term3[i] = dot_product(Z3_L_Q_ast[i,], b3_L[studnum_temp,]);
        }
        
        //calc long linpred
        linpred_y1 = term1 + term2 + term3;

        //return longitudinal linear predictor
        return(linpred_y1);
    }
}
data {
  //dimensions of data
  int<lower=1> M;      //number of observations
  int<lower=1> N;      //number of individuals
  int<lower=1> K;      //number of studies
  int<lower=1> pl; //max number long fixed = pl_nontreat + (ntreat_minus1 * ntreatsets_Long_pl)
  int<lower=1> ql;     //dimension of master D matrix
  int<lower=1> rl;     //dimension of master A matrix
  
  //int array id vector and study membership vector 
  //for selection of correct row from b2_L and b3_L when multiplying against 
  //z matrices
  array[M] int<lower=1, upper=N> indnum;     //longform idnum
  array[M] int<lower=1, upper=K> studnum;    //longform studnum
  
  //longitudinal outcome information
  vector[M] y1;         //longitudinal outcome

  //longitudinal fixed information
  matrix[M,pl] X_L;
  
  //real<lower=0> beta_L_sig;    //sd for fixed effects prior
  vector[pl] beta_L_mu;    //mean for fixed effects prior
  vector<lower=0>[pl] beta_L_sig;    //sd for fixed effects prior
  real<lower=0> sigma_e_sig;   //involved in prior for sigma_e - sig for longerror
  real<lower=0> sigma_e_nu;    //involved in prior for sigma_e - sig for longerror

  //longitudinal indrand information
  array[ql] int<lower=1, upper=pl> Z2_L_col;   //columns of X_L that form Z2
  vector<lower=0>[ql] tau_Dl_sig;    //scale of half-normal prior on tau_Dl
  real Omega_Dl_eta;          //involved in prior for Omega_Dl - corr mat for indrand
  
  //longitudinal studrand information
  array[rl] int<lower=1, upper=pl> Z3_L_col;   //columns of X_L that form Z3
  vector<lower=0>[rl] tau_Al_sig;    //scale of half-normal prior on tau_Al
  real Omega_Al_eta;          //involved in prior for Omega_Al - corr mat for studrand
  
  int do_likelihood;
  int do_prior;
  
}
transformed data{
  vector[ql] zeros_b2_L = rep_vector(0,ql);  //used for matrix multiplication for long indrand
  vector[rl] zeros_b3_L = rep_vector(0,rl);  //used for matrix multiplication for long studrand
  
  matrix[M,pl]  X_L_Q_ast = qr_thin_Q(X_L) * sqrt(M-1);
  matrix[pl,pl] X_L_R_ast = qr_thin_R(X_L) / sqrt(M-1);
  matrix[pl,pl] X_L_R_ast_inverse = inverse(X_L_R_ast);
  
  matrix[M,ql] Z2_L_Q_ast = X_L_Q_ast[,Z2_L_col];
  matrix[M,rl] Z3_L_Q_ast = X_L_Q_ast[,Z3_L_col];
}
parameters{
  //longitudinal parameters
  vector[pl] beta_L_all;
  real<lower=0> sigma_e;  //longitudinal error term (sd)

  matrix[N,ql] b2_L;               //longitudinal indrand effects.  First so many columns relate to fixed effects not treat related, later ones will be treat related. Which treatment will vary by study; label in R afterwards
  cholesky_factor_corr[ql] Omega_Dl;    //Correlation Matrix master long indrand
  vector<lower=0>[ql] tau_Dl;           //standard deviations master long indrand

  matrix[K,rl] b3_L;                  //longitudinal studrand effects
  cholesky_factor_corr[rl] Omega_Al;   //Correlation Matrix master long studrand
  vector<lower=0>[rl] tau_Al;          //standard deviations master long studrand
  
}
transformed parameters{
  cov_matrix[ql] Dl;     //covariance matrix for long indrand
  cov_matrix[rl] Al;     //covariance matrix for long studrand
  Dl = diag_pre_multiply(tau_Dl, Omega_Dl) * diag_pre_multiply(tau_Dl, Omega_Dl)';
  Al = diag_pre_multiply(tau_Al, Omega_Al) * diag_pre_multiply(tau_Al, Omega_Al)';
}
model{
  if (do_likelihood){
    //RANDOM EFFECTS
    for(i in 1:N){
      b2_L[i,] ~ multi_normal(zeros_b2_L, Dl);
    }
    for(i in 1:K){
      b3_L[i,] ~ multi_normal(zeros_b3_L, Al);
    }

    //Likelihood for longitudinal component
    vector[M] linpred_y1 = calc_Linpred_y(beta_L_all,
                      b2_L,
                      b3_L,
                      X_L_Q_ast,
                      Z2_L_Q_ast,
                      Z3_L_Q_ast,
                      M,  
                      indnum,
                      studnum);
    y1 ~ normal(linpred_y1,sigma_e);     //update longitudinal outcome
  }
 
  //Priors
  if (do_prior){
  
    //longitudinal parameter priors
    beta_L_all ~ normal(beta_L_mu,beta_L_sig);  //prior for longitudinal fixed effects
    sigma_e ~ student_t(sigma_e_nu,0,sigma_e_sig);     //prior for longitudinal error term
    tau_Dl ~ normal(0,tau_Dl_sig);
    Omega_Dl~ lkj_corr_cholesky(Omega_Dl_eta);      //prior for corr matrix for long indrand
    tau_Al ~ normal(0,tau_Al_sig);
    Omega_Al~ lkj_corr_cholesky(Omega_Al_eta);    //prior for corr matrix for long studrand
  }
}
generated quantities{
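  //back-transform fixed effects from the Q* scale to the scale of the columns of X_L
  vector[pl] beta_L_all_out = X_L_R_ast_inverse * beta_L_all;
  //back-transform random effects row by row; assumes Z2_L_col and Z3_L_col index the
  //first ql (resp. rl) columns of X_L, so the leading block of X_L_R_ast_inverse is
  //the inverse of the matching block of X_L_R_ast
  matrix[N,ql] b2_L_out = b2_L * X_L_R_ast_inverse[Z2_L_col, Z2_L_col]';
  matrix[K,rl] b3_L_out = b3_L * X_L_R_ast_inverse[Z3_L_col, Z3_L_col]';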
  
}
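The back-transformations that the generated quantities block aims to compute can be summarised as follows. This is a sketch under the assumption that the random-effect columns are the leading columns of $X$; here $\theta$ denotes the coefficients on the $Q^*$ scale (`beta_L_all`), $\tilde b^{(2)}_{ki}$ and $\tilde D$ the individual-level random effects and their covariance on that scale (`b2_L`, `Dl`), and $R^*_q$ the leading $q \times q$ block of $R^*$:

$$
\beta = (R^*)^{-1}\theta, \qquad b^{(2)}_{ki} = (R^*_q)^{-1}\,\tilde{b}^{(2)}_{ki}, \qquad D = (R^*_q)^{-1}\,\tilde{D}\,(R^*_q)^{-\top}.
$$

Because $R^*$ is upper triangular, $(R^*_q)^{-1}$ is exactly the leading $q \times q$ block of $(R^*)^{-1}$, so it can be taken from `X_L_R_ast_inverse` without a second inversion. The study-level effects `b3_L` and covariance `Al` transform in the same way, using the leading $r \times r$ block.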