Reduce computation time with highly correlated covariates

Hi. I have a linear mixed effects model

theta_i(t) = alpha_1 * x + alpha_2 * x * t + alpha_3 * x * t^2 + u_{i0} + u_{i1} * t + u_{i2} * t^2

The “xt” and “xt^2” terms are highly correlated; x is a fixed covariate and t represents time.
My stan code for this part is

theta =  alpha[1] * X + alpha[2] * X .* time + alpha[3] * X .* square(time) + u[1][subject] + u[2][subject] .* time + u[3][subject] .* square(time);

The model fitting is not bad: all chains reach convergence, and I only got warnings about transitions after warmup that exceeded the maximum treedepth. My concern is the computation time: it took about 150 seconds to run one chain (2,000 iterations). If I remove “xt^2”, it takes 45 seconds to run one chain; if I remove both “xt” and “xt^2”, it takes only 15 seconds.
Are there other ways to reduce computation time without removing covariates?
Any advice is welcome, thank you!

Some diagnostics: the values from get_num_leapfrog_per_iteration() are mostly 511, with some 1023.
pairs plot

Inference for Stan model: 84a5589a580f1afb0030bfd212537eb6.
4 chains, each with iter=2000; warmup=1000; thin=1; 
post-warmup draws per chain=1000, total post-warmup draws=4000.

          mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
alpha[1]  0.47    0.01 0.23  0.03  0.32  0.47  0.62  0.94  1300    1
alpha[2]  1.01    0.02 0.68 -0.30  0.54  1.00  1.48  2.32  1088    1
alpha[3] -0.35    0.01 0.47 -1.27 -0.65 -0.35 -0.02  0.54  1125    1

Samples were drawn using NUTS(diag_e) at Sat Jun 27 16:28:41 2020.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at 
convergence, Rhat=1).

BTW, I tried metric = dense_e; the computation time increased to 180 seconds.

I think what you’re looking for is the QR reparameterization trick.
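
Roughly, the pattern looks like this (a minimal sketch along the lines of the Stan User’s Guide; here x would be your predictor matrix with ms rows and columns (x, x*t, x*t^2), and names like alpha_tilde are just illustrative):

transformed data {
  // thin QR decomposition of the predictor matrix, scaled as in the
  // User's Guide so the columns of Q_ast are on roughly unit scale
  matrix[ms, 3] Q_ast = qr_thin_Q(x) * sqrt(ms - 1);
  matrix[3, 3] R_ast = qr_thin_R(x) / sqrt(ms - 1);
  matrix[3, 3] R_ast_inverse = inverse(R_ast);
}
parameters {
  vector[3] alpha_tilde; // coefficients on the (orthogonal) Q_ast scale
}
model {
  // ... use Q_ast * alpha_tilde wherever the model previously had x * alpha
}
generated quantities {
  vector[3] alpha = R_ast_inverse * alpha_tilde; // back on the original scale
}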


Hi Mike, thank you for your advice.
I applied the QR reparameterization to the predictor matrix (x, xt, xt^2).
Now my code is

theta =  Q_ast * alpha + u[1][subject] + u[2][subject] .* time + u[3][subject] .* square(time);

The pairs plot looks much better than the previous one.

          mean se_mean   sd  2.5%   25%   50%   75% 97.5% n_eff Rhat
alpha[1]  4.51    0.02 0.91  2.75  3.87  4.50  5.15  6.27  2090    1
alpha[2]  1.31    0.01 0.57  0.17  0.93  1.31  1.69  2.47  4871    1
alpha[3] -0.37    0.01 0.49 -1.38 -0.71 -0.37 -0.04  0.59  2694    1

However, the computation time is only slightly reduced, to 135~140 seconds, so I guess the QR reparameterization mainly fixes the collinearity problem.
The values from get_num_leapfrog_per_iteration() are now all 511, with no 1023 this time.
Is there any way I could reduce the running time to below 60 seconds?

Presumably this is something you’re going to be running repeatedly in some sort of production environment, hence the concern for even such a short single-run time?

I doubt it will make much difference, but you could pre-compute square(time) in the transformed data section instead of re-computing it every time the model is evaluated. Ditto X.*time and X.*square(time), presuming X is data. Finally, you could check whether there’s any redundancy in time and X; if there is, you can reduce them to just their unique values, compute their product with their respective alpha, and then index into that result to get the thetas (for terms that involve both, like X.*time, you’d want to find the unique combinations of X and time).
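
For example, the pre-computation would look something like this (a sketch, assuming time and X are data vectors of length ms as in your first post):

transformed data {
  // computed once when the data are loaded, instead of at every
  // log-density evaluation
  vector[ms] time_sq = square(time);
  vector[ms] X_time = X .* time;
  vector[ms] X_time_sq = X .* time_sq;
}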

Oh, since you mention this is a mixed effects model: have you tried both centered and non-centered versions? Sometimes that reparameterization can influence the sampling time.
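
For a single random-effect vector, the two forms look roughly like this (a sketch with illustrative names; mu_u and sigma_u stand for the population mean and SD):

// centered: sample the random effects directly
parameters {
  vector[N] u;
}
model {
  u ~ normal(mu_u, sigma_u);
}

// non-centered: sample standardized effects, then shift and rescale them
parameters {
  vector[N] u_raw;
}
transformed parameters {
  vector[N] u = mu_u + sigma_u * u_raw;
}
model {
  u_raw ~ std_normal();
}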

Thank you for the reply.

Yes, this is only for one small simulated dataset, containing about 800 data points. I plan to repeat the simulation at least 500 times.
As you suggested, I pre-computed terms like t^2 and X*t; the computation time for one chain is now reduced to 125 seconds. Considering there are 4 chains, that is still not satisfying.

I’ve already applied the non-centered parameterization to the random effects u.

My dataset is actually longitudinal, with Y as the longitudinal outcome, X as a baseline covariate, and time as the measurement time. That means that for each subject, X is the same across time points, and the time points are equally spaced. Something like

   id time        X        Y
1   1 0.00 5.159634 5.600736
2   1 0.25 5.159634 4.092070
3   1 0.50 5.159634 3.102044
4   1 0.75 5.159634 2.070282
8   2 0.00 4.892276 3.714980
9   2 0.25 4.892276 3.261769
15  3 0.00 5.026449 4.753733
16  3 0.25 5.026449 2.240989
17  3 0.50 5.026449 3.534182

Does this kind of structure in time and X contain the kind of “redundancy” you mean?

But have you checked how the centered parameterization performs? I believe there are cases (it has to do with how much data you have per subject vs. how many subjects you have) where one can expect centered to perform better than non-centered.

Yes, there’s redundancy, but now that I look at the formula in your first post, I see that you’re probably not going to get much speed-up by eliminating it. For example, if you have nT samples for each subject (one sample for each time point), then the computation of alpha[1]*X involves computing that product nT times per subject when once would be sufficient (since each subject has a single value in X). Since X is real valued, I presume there are no repeated values in that column across subjects; if there were, you could account for that redundancy as well.

But since you have only one observation per subject per time point, and all the other terms involve time, there’s no more redundancy to eliminate beyond that first, rather minor redundancy in alpha[1]*X. I guess you could still get rid of some redundancy in any terms involving X by first computing against the unique entries in X and then indexing the result before multiplying by time.
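
To illustrate the kind of indexing I mean (ignoring the QR trick for the moment; X_subj here is a hypothetical data vector holding one baseline value per subject):

// X_subj: vector[N] in the data block, one value per subject
// subject: int[ms] in the data block, the subject ID for each measurement
vector[N] aX = alpha[1] * X_subj;  // N multiplications instead of ms
vector[ms] theta_x = aX[subject];  // expand back to measurement level by indexing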

Try the centered parameterization and if you don’t find it helps and you need further guidance on implementing a redundancy-minimized computation, reply back here and I’ll work on it.


Many thanks for the prompt reply.

I first applied the centered parameterization; the running time for one chain was 200 seconds, with convergence problems (Rhat > 2). I switched back to the non-centered version and the model fit was good: all Rhat < 1.01, and 150 seconds for one chain.

Do you mean that I should pass X in as X[N], where N is the number of subjects? Currently I pass X in as X[ms], where ms is the total number of measurements. And should I then apply the QR reparameterization to the predictor matrix (Xt, Xt^2)?
BTW, there are 150 subjects in the dataset and about 800 measurements, which means about 5 measurements per subject on average.

Ah, darn, I’d forgotten about the QR stuff. I’m not sure my redundancy-reducing stuff is compatible with the QR trick. Hm, I’ll have to let this mull a bit. In the interim, could you post your full model? I might be able to see other areas where you can improve things (e.g. folks often have slow or badly-sampled models when using ridiculously uninformative priors, etc.).

Thanks for checking the model code for me.

data {
  int<lower=1> N; // Number of subjects
  int<lower=1> ms; // Number of longitudinal measurements
  int<lower=1, upper=N> subject[ms]; // subject ID for each measurement
  int<lower=1> K; // Number of longitudinal outcomes
  vector[ms] Y[K]; // longitudinal outcomes, one vector of ms measurements per outcome
  vector<lower=0>[ms] time;
  vector<lower=0>[ms] tt; // pre-computed square(time)
  matrix[ms,3] x; // predictor matrix (x, xt, xt^2)
}
parameters {
  vector[3] alpha; // fixed-effect coefficients (on the Q_ast scale)
  vector<lower=0>[K] sigma_e; // measurement-error SDs, one per outcome
  vector[3] mu; // means of the random effects
  vector<lower=0>[3] sigma; // SDs of the random effects
  vector[N] u_raw[3]; // standardized random effects (non-centered)
}
transformed parameters {
  vector[N] u[3];  // random effects
  vector[ms] theta; // the model to describe longitudinal trajectories
  matrix[ms,3] Q_ast;
  matrix[3,3] R_ast;
  matrix[3,3] R_ast_inverse;
  // thin and scale the QR decomposition
  Q_ast = qr_thin_Q(x) * sqrt(ms - 1);
  R_ast = qr_thin_R(x) / sqrt(ms - 1);
  R_ast_inverse = inverse(R_ast);
  for (j in 1:3) {
    u[j] = mu[j] + sigma[j] * u_raw[j];
  }
  theta =  Q_ast * alpha + u[1][subject] + u[2][subject] .* time + u[3][subject] .* tt;
}
model {
  for (k in 1:K) {
    Y[k] ~ normal(theta,sigma_e[k]);
  }
  for (j in 1:3) u_raw[j] ~ std_normal(); 
  alpha ~ normal(0,100);
  mu ~ normal(0,100);
  sigma_e ~ normal(0,100);
  sigma ~ normal(0,100);
}

There are 2 longitudinal markers (K = 2) with similar trends; my model is Y_ik(t) = theta_i(t) + e_ik, where e_ik is the measurement error for marker k. If I remove the fixed-effects part (every term containing X), the computation time for one chain is less than 10 seconds.

Oh! Do the QR stuff in the transformed data section, not transformed parameters; it only needs to be done once per data set, while putting it in the TP section means it gets done repeatedly.
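
Concretely, something like this (a sketch using the names from your posted model):

transformed data {
  // depends only on the data x, so it can be computed once here
  matrix[ms, 3] Q_ast = qr_thin_Q(x) * sqrt(ms - 1);
  matrix[3, 3] R_ast = qr_thin_R(x) / sqrt(ms - 1);
  matrix[3, 3] R_ast_inverse = inverse(R_ast);
}

and then drop those declarations and assignments from the transformed parameters block.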

Also, do you genuinely have uncertainty that spans 3 orders of magnitude for all your parameters? Generally, with proper data scaling, normal(0,1) is considered reasonably weakly informative, with normal(0,10) being the most extreme I’ve seen, so normal(0,100) might be forcing the sampler to explore very unreasonable areas of the parameter space.
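
For example, if the data are roughly on unit scale, something closer to this might be reasonable (a sketch; adjust the scales to whatever is actually plausible for your problem):

model {
  // likelihood and u_raw priors as before
  alpha ~ normal(0, 1);   // weakly informative once the data are scaled
  mu ~ normal(0, 1);
  sigma_e ~ normal(0, 1); // half-normal, given the <lower=0> declaration
  sigma ~ normal(0, 1);
}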

Also, this bit:

for (j in 1:3) u_raw[j] ~ std_normal();

can alternatively be expressed as:

to_vector(u_raw) ~ std_normal();

I’m not sure whether the latter is any faster computationally than the former, but it’s worth a look.

Also, mind editing it to remove the magic number 3? I think that some of those 3s should be either ms or K, but can’t figure out which and it would be easier to follow if it were explicit.

Oops, my bad. I moved the QR reparameterization to the transformed data section and tightened the priors to normal(0,10). The running time for one chain is now 90 seconds, a big improvement over the previous runs.

There are many 3s because I have 3 fixed-effects terms and 3 random-effects terms. I’ve reviewed the code thoroughly and it should be correct.
90 seconds for one chain is still not great, but I guess it is hard to improve much more?
Again, thank you for the advice, it’s really helpful.

Progress: I removed mu, which represented the mean of the random effects. The random effects are now generated as u[j] = sigma[j] * u_raw[j];, assuming the random effects have mean 0. It may be a stronger assumption than the previous prior, but the computation time for one chain is now below 10 seconds. I think it is worth doing this for efficiency.
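
So the loop in the transformed parameters block is now simply:

for (j in 1:3) {
  u[j] = sigma[j] * u_raw[j]; // random effects centered at 0
}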

I did not follow all the optimisations in this thread, but in one of the outputs I saw that n_eff is about 2000. You probably don’t need that level of precision and can get away with fewer effective samples. So one way to speed up the simulation would be to run shorter chains.

Hi stijn, thank you for your advice. I noticed that the n_eff for all parameters is above 2,000, so I think maybe 1,000 iterations per chain would be fine.

Oh! Sorry I didn’t catch that. Yeah, it’s way more common to model the random effects as having a mean of zero, and not doing so would have left weaker identifiability in the model, which would have slowed down sampling.

I think @stijn’s point is that the higher the n_eff, the more precision you get for increasingly extreme quantiles of the posterior. If you’re only interested in the median and, say, a 90% interval, you probably don’t need even 2000 n_eff.


Another speed-up thought for you: if you have multiple cores available, you could try out the new/experimental campfire warmup, which seems to be able to terminate warmup earlier than the standard warmup.

I only use the mean values of the estimated parameters for the subsequent analysis, so yes, I think I do not need an n_eff this high.
And thank you for the information about the campfire warmup; I will try it.