Variable Selection in Parametric Survival Analysis Models

I have a large survival analysis model (Weibull) with about 300k observations on 50+ variables. It takes a very long time to run (>1 day), so for practical reasons I do not want to follow iterative approaches that involve building many different models. I have read about sparsifying priors like the (regularized) horseshoe as well as the projection predictive method. What I don't understand is: with a normal (in both senses of the word) hierarchical prior, such as

...
beta ~ normal(0, sigma);
sigma ~ exponential(2);
...

There are clear self-regularizing properties, in that the prior is, roughly speaking, learned from the data. Or is my understanding incorrect? That feels like a good starting point: fit the model once, then discard any variable with, say, 80% or more of its posterior density "very close" to zero. Would this be an approximately fine approach to variable selection, given the practical constraints mentioned above, or am I missing key problems here?
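Concretely, the rule I have in mind is something like this rough R sketch (draws_beta here stands for a draws-by-coefficients matrix of posterior samples, and the threshold eps is something I would still have to choose):

eps <- 0.05                                   # what counts as "very close" to zero
near_zero <- colMeans(abs(draws_beta) < eps)  # proportion of posterior mass inside (-eps, eps)
keep <- which(near_zero < 0.80)               # drop variables with >= 80% of mass near zero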

The reason for asking is that I find it hard to make the necessary adjustments to my Weibull model to accommodate strategies like projection predictive inference, unless someone can point me in the right direction to some relevant resources.

I am currently looking at Projective Inference in High-dimensional Problems and the post Multivariate Variable Selection in Stan for further inspiration, but it might require me to carry out some derivations by hand first. Any advice will be very much appreciated.

With that many observations and a much smaller number of variables, the posterior should be quite informative and variable selection is easier.

This is fine, and variants of it can also be justified using decision theory; there are many papers on such thresholding approaches. However, if the variables are highly correlated, the posterior can be highly correlated too, and looking at the marginals can then be misleading; see e.g. this case study. Correlated variables were one reason to develop the projection predictive approach.

In your case of big data and a small number of variables:

  • the (regularized) horseshoe is likely to have only a small effect, but can make the computation much slower
  • the projection predictive approach can be slow, although the L1 search would probably be reasonably fast, and you could skip the cross-validation of the search path as there is unlikely to be much variation with such big data
  • thresholding is a sensible thing to start with, but also check the posterior correlations; see the sketch below
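For checking the correlations, something like this rough sketch works (assuming draws_beta is a draws-by-coefficients matrix for beta, e.g. from as.matrix(fit, pars = "beta") with rstan; the pair of coefficients in the last line is just an example):

R <- cor(draws_beta)                                # posterior correlations between coefficients
which(abs(R) > 0.5 & upper.tri(R), arr.ind = TRUE)  # pairs worth a closer look
bayesplot::mcmc_pairs(as.array(fit), pars = c("beta[3]", "beta[7]"))  # joint posterior of one such pair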

The paper [2109.04702] Latent space projection predictive inference describes how to do projection predictive variable selection with a Weibull survival model. The code is in the projpred git repo, but unfortunately not yet in the master branch, and thus we don't have a very easy-to-follow example. You would want to use varsel instead of cv_varsel (to skip the cross-validation of the search paths), use the option method="L1", and set nclusters to a small value. If you are brave enough to test this, we can help to get it going. I suggest first testing it with much smaller data.
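Roughly, the call would look like the sketch below (untested, and the details may differ in that branch; ref stands for the projpred reference-model object built for your Weibull model as described in the paper and the repo):

library(projpred)
vs <- varsel(ref, method = "L1", nclusters = 20)  # no cross-validation of the search path
plot(vs, stats = "elpd")                          # predictive performance along the search path
suggest_size(vs)                                  # suggested number of variables to keep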

Can you show your model code? Maybe there is something that could be made more efficient or could benefit from more recent Stan tricks and compiler optimizations.


Thank you @avehtari, this is extremely helpful and you have given me some good direction here, wow! I will do the following:

  • Run with my current code in the meantime and do the thresholding, but inspect the posterior correlations as you mentioned. I'm not sure, though, what would be considered too high?
  • I will familiarize myself with the paper you suggested and the code in the projpred git repo.

I would definitely like to see if we can test the code together, I will send you a direct message if that is okay.

I think I can share most of the code; some of it might be a bit sensitive (IP related). Should I perhaps share it in a different post, or can I go ahead and post here?

Post here, as it's on topic. Post what you can. I prefer to help in public threads, as it's likely that our discussion is helpful for other people, too.

Sure, let me share some code; it would be great to get some advice, since I have only built a handful of models in Stan so far. I have also added some comments, which I hope helps.

For more context: this is a "late entry" model, with left-truncated and right-censored data. Also, I have time-varying covariates, so all variables can change their value once per period (between start and stop times).

I know there are different parameterisations of the Weibull model; let me know if I perhaps made a mistake in hooking the linear part of the model up to the scale parameter. I changed a few things for clarity, but this is the gist of it.

data {
  int<lower=0> N; //Number of observations
  int<lower=0> P; // Number of numeric predictors
  int<lower=0> P_ind; // Number of factor variables (binary)
  vector[N] d; // Survival event variable 1:Death, 0:Censored
  vector<lower=0>[N] stop; // age at stop time
  vector<lower=0>[N] start; // age at start time
  int<lower = 1> X_ind[N, P_ind]; // Index matrix for categorical variables. Values are either 1 or 2
  matrix[N, P] X;  //Design matrix for numeric variables
}

parameters {
  real<lower=0> alpha; //Shape parameter of Weibull
  vector<lower=0>[P] beta_sigma; // Prior scale (stdev) for numeric-predictor coefficients
  vector<lower=0>[P_ind] beta_ind_sigma1; // Stdev for index variable =1 
  vector<lower=0>[P_ind] beta_ind_sigma2; //Stdev for index variable = 2 
  vector[P] z_bs; // Normal(0, 1) value for non-centered parameterisation for numeric variables.
  vector[P_ind] z_ibs1; // Normal(0, 1) value for non-centered parameterisation for index variables.
  vector[P_ind] z_ibs2;
}
transformed parameters {
  vector[P] beta; 
  matrix[P_ind, 2] beta_ind; //Two columns representing parameters for index variables
  beta = z_bs .* beta_sigma;
  beta_ind[,1] = z_ibs1 .* beta_ind_sigma1;
  beta_ind[,2] = z_ibs2 .* beta_ind_sigma2;
}

model {
  vector[N] sigma;
  vector[N] sigma_ind_linear;
  vector[P_ind] tempsum;

// Priors
  beta_sigma ~ exponential(3);
  beta_ind_sigma1 ~ exponential(3);
  beta_ind_sigma2 ~ exponential(3);
  z_bs ~ normal(0, 1);
  z_ibs1 ~ normal(0, 1);
  z_ibs2 ~ normal(0, 1);
  alpha ~ gamma(2, 2);

// Model
// Adding up index variables contribution to target
  for(i in 1:N){
    for(j in 1:P_ind){
        tempsum[j] =  beta_ind[j, X_ind[i, j]];
    }
    sigma_ind_linear[i] = sum(tempsum);
  }
// Linear model
  sigma = exp((X*beta) + sigma_ind_linear)+0.001; // adding delta for numerical stability.
  
  for(i in 1:N){
    if(d[i]==0){
      target += weibull_lccdf(stop[i]|alpha, sigma[i]) - weibull_lccdf(start[i]| alpha, sigma[i]);
    }
    else {
      target += weibull_lpdf(stop[i]|alpha, sigma[i]) - weibull_lccdf(start[i]| alpha, sigma[i]);
    }
  }
}

You can make your model code much faster by not looping over the Weibull functions. Just divide the data into observed and censored so that you don't need to check d[i]==0 in the loop, and you can vectorize both target += lines. See e.g. the Stan User's Guide and A Survival Model in Stan.

You can also vectorize the computation of sigma_ind_linear, but that has a smaller effect on speed; a sketch of one way to do it is below.
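For example (a sketch, assuming the X_ind values are always 1 or 2 as in your comment), you could precompute a 0/1 indicator matrix in transformed data and replace the double loop with matrix-vector products:

transformed data {
  matrix[N, P_ind] X_ind2 = to_matrix(X_ind) - 1; // 1 if the index variable is at level 2, 0 if at level 1
}
model {
  // same quantity as the double loop, without per-element indexing
  vector[N] sigma_ind_linear = (1 - X_ind2) * beta_ind[, 1] + X_ind2 * beta_ind[, 2];
  ...
}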


Ah yes, I have seen that before, quite a while ago. I will do that and post the code, and hopefully the improvement in execution time. I know how to do the vectorization of the likelihood, but I'm a bit unsure how you see the vectorization of the sigma_ind_linear part; let me have a look.

Unfortunately it seems to run as slowly as before, if not slower. I'll post my code here in case anyone can spot whether I made a mistake in the vectorization and in splitting up the data. There might, however, be a difference between the two versions, because so far I seem to get more sensible results with this one (so maybe there is a mistake somewhere in the code above, although it is not immediately clear to me how they could differ). I suppose this will now become my new benchmark, but 15k records take >2 hours and I have >300k records to fit, so I really hope I can still speed this up. It would also help if I could fit the model in batches, using the posterior of one run as the prior for the next, but as I understand it that is still not straightforward.

data {
  int<lower=0> N_obs;
  int<lower=0> N_cens;
  int<lower=0> P;
  int<lower=0> P_ind;
  vector<lower=0>[N_obs] stop_obs;
  vector<lower=0>[N_cens] stop_cens;
  vector<lower=0>[N_obs] start_obs;
  vector<lower=0>[N_cens] start_cens;
  int<lower = 1> X_ind_obs[N_obs, P_ind];
  int<lower = 1> X_ind_cens[N_cens, P_ind];
  matrix[N_obs, P] X_obs;
  matrix[N_cens, P] X_cens;
}

parameters {
  real<lower=0> alpha; //Shape parameter of Weibull
  real mu; // Intercept of the linear predictor (no explicit prior, i.e. implicitly flat)
  vector<lower=0>[P] beta_sigma; // Prior scale (stdev) for numeric-predictor coefficients
  vector<lower=0>[P_ind] beta_ind_sigma1; // Stdev for index variable = 1 
  vector<lower=0>[P_ind] beta_ind_sigma2; // Stdev for index variable = 2 
  vector[P] z_bs; // Normal(0, 1) value for non-centered parameterisation for numeric variables.
  vector[P_ind] z_ibs1; // Normal(0, 1) value for non-centered parameterisation for index variables.
  vector[P_ind] z_ibs2;
}
transformed parameters {
  vector[P] beta; 
  matrix[P_ind, 2] beta_ind; //Two columns representing parameters for index variables
  beta = z_bs .* beta_sigma;
  beta_ind[,1] = z_ibs1 .* beta_ind_sigma1;
  beta_ind[,2] = z_ibs2 .* beta_ind_sigma2;
}

model {
  vector[N_obs] sigma_obs;
  vector[N_cens] sigma_cens;
  vector[N_cens] sigma_ind_linear_cens;
  vector[N_obs] sigma_ind_linear_obs;
  vector[P_ind] tempsum;

// Priors
  beta_sigma ~ exponential(3);
  beta_ind_sigma1 ~ exponential(3);
  beta_ind_sigma2 ~ exponential(3);
  z_bs ~ normal(0, 1);
  z_ibs1 ~ normal(0, 1);
  z_ibs2 ~ normal(0, 1);
  alpha ~ gamma(2, 2);

// Model
// Adding up index variables contribution to target
for(i in 1:N_obs){
    for(j in 1:P_ind){
        tempsum[j] =  beta_ind[j, X_ind_obs[i, j]];
    }
    sigma_ind_linear_obs[i] = sum(tempsum);
  }

 for(i in 1:N_cens){
    for(j in 1:P_ind){
        tempsum[j] =  beta_ind[j, X_ind_cens[i, j]];
    }
    sigma_ind_linear_cens[i] = sum(tempsum);
  }
// Linear model
 sigma_obs = exp(mu + (X_obs*beta) + sigma_ind_linear_obs)+0.0001;
 sigma_cens = exp(mu + (X_cens*beta) + sigma_ind_linear_cens)+0.0001;
  
 target += weibull_lccdf(stop_cens|alpha, sigma_cens) - weibull_lccdf(start_cens| alpha, sigma_cens);
 target += weibull_lpdf(stop_obs|alpha, sigma_obs) - weibull_lccdf(start_obs| alpha, sigma_obs);

}

Easy speedups: 1) you tagged this as rstan, so switch to CmdStanR to use the latest Stan version; 2) add CXXFLAGS += -march=native -mtune=native to make/local (this can be done from R), which can drop computation time by 50%; 3) use stanc_options = list("O1") when calling cmdstan_model, which even in smaller problems has reduced computation time by 25%.

Further speedup can be obtained by using OpenBLAS or MKL to use more threads without changing the code; see Speedup by using external BLAS/LAPACK with CmdStan and CmdStanR/Py.

A further speedup that does require code changes (or using brms to generate your code) is to use reduce_sum; a sketch is below.
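A rough sketch for the uncensored part of your model (the censored part would get a similar partial-sum function; the function name and grainsize here are just illustrative, and you need to compile with threading enabled, e.g. cpp_options = list(stan_threads = TRUE), and set threads_per_chain when sampling):

functions {
  // partial sum of the uncensored log-likelihood contributions, sliced over observations
  real partial_obs_sum(real[] stop_slice, int start, int end,
                       real alpha, vector sigma, vector start_time) {
    return weibull_lpdf(to_vector(stop_slice) | alpha, sigma[start:end])
           - weibull_lccdf(start_time[start:end] | alpha, sigma[start:end]);
  }
}
model {
  int grainsize = 1; // let the scheduler choose the slice sizes
  ...
  target += reduce_sum(partial_obs_sum, to_array_1d(stop_obs), grainsize,
                       alpha, sigma_obs, start_obs);
}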

All these without GPUs.

Further speedup could be obtained by making the posterior easier to sample:

  • the non-centered parameterization below is not necessarily the best choice for big data, as the likelihood is very informative, and it's possible this will create a bad funnel:
  beta = z_bs .* beta_sigma;
  beta_ind[,1] = z_ibs1 .* beta_ind_sigma1;
  beta_ind[,2] = z_ibs2 .* beta_ind_sigma2;
  • why do you have the sigmas as vectors? You said P<<N, and as far as I can see these could be scalars, which would make the posterior much easier (see the sketch after this list):
  vector<lower=0>[P] beta_sigma; // Coefficient for numeric predictors
  vector<lower=0>[P_ind] beta_ind_sigma1; // Stdev for index variable =1 
  vector<lower=0>[P_ind] beta_ind_sigma2; //Stdev for index variable = 2 
  • all exponential priors are suspicious
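For the scalar sigmas with a centered parameterization, something along these lines (a sketch only; the half-normal prior scales are placeholders, not a recommendation for your data):

parameters {
  real<lower=0> alpha;            // Weibull shape
  real<lower=0> beta_sigma;       // one scale shared by all numeric coefficients
  real<lower=0> beta_ind_sigma1;  // one scale shared by all level-1 index coefficients
  real<lower=0> beta_ind_sigma2;  // one scale shared by all level-2 index coefficients
  vector[P] beta;                 // centered: coefficients sampled directly
  matrix[P_ind, 2] beta_ind;
}
model {
  beta_sigma ~ normal(0, 1);      // half-normal via the lower bound
  beta_ind_sigma1 ~ normal(0, 1);
  beta_ind_sigma2 ~ normal(0, 1);
  beta ~ normal(0, beta_sigma);
  beta_ind[, 1] ~ normal(0, beta_ind_sigma1);
  beta_ind[, 2] ~ normal(0, beta_ind_sigma2);
  ...
}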

If you can provide further information about the posterior convergence diagnostics, ESS values, and some mcmc_pairs plots, that might give additional useful clues for helping you.


Thanks so much for the thoughtful reply @avehtari, I appreciate your time. Wow, you have given me a lot of homework :) Let me start setting up Rtools and CmdStanR so I can implement the suggestions you made above. I will give some feedback soon. (You can see how new I am to all of these tools; e.g. I haven't used brms yet, so I will need to look at that as well.)

Regarding your other questions and suggestions:

Do you mean I should rather stick to the centered version of this?

So I was trying to set hierarchical priors for each of the numerical and index/categorical variables, i.e. I wanted to estimate the posterior of the variance of each of the linear model's parameters instead of specifying a single variance for all variables. I was hoping that doing it this way would give some form of regularization (adaptive priors), but I think I see the flaw in my reasoning now. Making the scale a scalar will still let me model the distribution of possible variances for the parameters, right?

I know that the standard deviations should be quite low (like 1, 2, or maybe 3 at most), hence the simple exponential prior. Can you clarify why they are suspicious?

Sure, I can share some output, but the code above executed beautifully without any convergence issues: well-mixed chains, posteriors in the ranges I would expect, no complaints about ESS being too low, and not a single divergent transition. So the fit looks really good on all fronts; it is just slow. Let me know if you need more specific plots, I will just have to be a bit discreet about the data/features I'm currently using.


Happy to help, and I'm glad that my short answers, written in a hurry, are clear enough for you to move forward.

Likely. Difficult to be certain without testing or being able to analyse the current posterior etc.

You can look at the posterior of each coefficient directly. Such a hierarchical presentation of the prior is usually used only for computational reasons and gives no benefit in terms of what information you can extract. Your hierarchical prior corresponds to a Laplace (double exponential) prior, which you could also implement directly without the hierarchical structure. The horseshoe and regularized horseshoe are examples of priors where the hierarchical structure seems to help computation, and even for them, the version with the smallest number of additional parameters seems to work best. (R)HS priors are often implemented in a non-centered way because they are often used in the n<<p case, but in the n>>p case the centered parameterization is likely to be more efficient.

As you have n>>p, it is also likely that the likelihood dominates and the shape of the prior has only a small effect. Thus it's good to start with the simplest normal prior and see what you can infer with that.
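E.g. just fixed-scale normal priors directly on the coefficients (the scale 2 here is a placeholder, not a recommendation):

  beta ~ normal(0, 2);
  to_vector(beta_ind) ~ normal(0, 2);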

Great

Let’s see first what you get out of the model with scalar sigmas


I don't have time to explore all possible combinations of your suggested optimisations to determine which ones contribute the most in terms of computational efficiency. But I'm implementing some of the easy ones on the CmdStan side first, and then I will report on the differences from the code-side changes that should make the posterior easier to sample.

What I have done now:

  • Switched to CmdStanR from Rstan
  • Added the following to make/local by running this R code (I hope I have done it correctly):
    • cpp_options = list("CXXFLAGS += -march=native -mtune=native", "STAN_CPP_OPTIMS=true")
    • cmdstanr::cmdstan_make_local(cpp_options = cpp_options, append = TRUE)
    • cmdstanr::rebuild_cmdstan(cores = 4)
  • I compiled the model with the following stanc_options in R:
    • mod <- cmdstan_model(stan_file = './my_model.stan', stanc_options = list("O1"))

So I have implemented suggestions 1-3.

This part I still find a bit confusing, and I'm a bit out of my depth. Also, it would take a bit of time to set up, so I might need to leave it as the last option to explore. I will also look into reduce_sum; again, this might require some more work, so I'm trying the easy (potential) wins first.

RESULT 1: My previous run took 2.5 hours (for 15k records), but this time it only took 1.5 hours, with the same good diagnostics as before (only 3 divergent transitions, no other issues). I had already noticed a speedup just from using CmdStanR, so I'm not sure which options added the most to the improvement. That is about 60% of the previous run time; of course there is variation in these estimates, but I assume the two runs are different enough that this is a really good improvement :)

Next, changing the sigmas to scalars and following the centered approach.


Great! I guess that the biggest speedup in your case comes from -march=native -mtune=native.

This is a great result! The combination of O1 and -march=native -mtune=native will be pretty powerful moving forward. The -march=native flag enables more efficient handling of calculations involving matrices, and O1 allows the values/adjoints to be fully represented as separate matrices. Using both options allows for the most efficient processing of large matrix parameters.


What version of RStan were you using?

I'm using RStan version 2.21.2 (and R version 4.1.32). Let me know if you need more info.

UPDATE, RESULT 2: I have changed to scalar sigmas as you suggested @avehtari, and I tried both the centered and the non-centered versions, on 15k and on 100k records. As one would expect, there is a significant drop in computation time with far fewer parameters. Also, the centered version ran faster, as you predicted:

  • 15k sample size
    • Centered: 35minutes
    • Non-centered: 48 minutes
  • 100k Sample size
    • Centered: About 7.4 hours
    • Non-centered: 13.7 hours.

I'm simply reporting execution times here; assessing whether the quality of the results is the same would take a bit more time to investigate, but from all the diagnostics, trace plots, etc. everything looks fine and very similar. The centered approach did return quite a few divergent transitions (340) when running on 15k samples, but that can be dealt with easily. For the 100k run it showed only 2 divergent transitions, and all diagnostics look good.

Summary so far:

  1. using CmdStan instead of Rstan,
  2. using cpp_options = list("CXXFLAGS += -march=native -mtune=native", "STAN_CPP_OPTIMS=true"),
  3. switching to centered parameterization, and
  4. obviously reducing the number of (hyper) parameters to estimate,
  5. using stanc_options = list("O1")

seems to be the winning combination so far :) My estimated run time for >300K observations is still going to be quite long, but at least I’m not going to wait weeks. I’ll keep experimenting with some of the other suggestions in parallel. Thanks again for the help so far @avehtari.

EDIT: added bullet (5), which I forgot to mention earlier.

Great!

More observations make the posterior under the centered parameterization closer to normal.

Now that you have switched to fewer parameters, it's likely that the posterior with >300K observations is close to normal, and you can try the optimize() method. Unfortunately CmdStanR doesn't (yet) provide the Hessian, but optimizing will give you an idea of what the timing could be, and then what to recommend next would depend on your exact model.
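For example (a sketch; mod is your compiled model and stan_data stands for your data list):

fit_opt <- mod$optimize(data = stan_data, seed = 1)
fit_opt$summary()   # penalized-ML / MAP point estimates for alpha, beta, ...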

When using sample() you could probably use a shorter warmup, and while experimenting with different models you can also use fewer post-warmup iterations.
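E.g. something like (again just a sketch, with stan_data standing for your data list):

fit <- mod$sample(data = stan_data, chains = 4, parallel_chains = 4,
                  iter_warmup = 500, iter_sampling = 500)   # instead of the default 1000 + 1000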

With reduce_sum you can get a roughly linear speedup with the number of physical cores, and sometimes even more, as cache efficiency can improve. But maybe you are getting into the range where sampling is already faster than the time you would spend on implementing additional speedups?

You didn't list stanc_options = list("O1"). Did you test with and without it? It is possible that in your case the size of the data dominates the cost, so that this specific new O1 optimization is minor.

With a larger number of observations, it is possible that the for loop with a lot of indexing is also taking a substantial amount of time.

Although the next point is beyond user-level tricks, I'll mention for Stan devs reading this that weibull_log_glm* functions would be useful here, but it seems there are still quite a few distributions without glm functions.


Makes sense, thanks for pointing that out!

I will definitely try optimize(), as for my practical situation (penalized) maximum likelihood estimates will be sufficient, so thanks for the suggestion. I'm quite new to Stan, so I'm learning a lot about other available functionality in this one thread :)

That is where my thoughts were going as well; even 100 fewer samples per chain will make a difference at this scale for me.

Yes, for now at least I think we are reaching that point, but it feels like you have given me enough interesting optimization options that I would still like to try some of them in a less rushed situation, and I will post relevant feedback here.

Sorry, I will edit the post and indicate the edit; I have indeed used this option as well, but I actually haven't tested without it.

I'm currently running the model on all the data and will give some final comments on timing soon. If time permits, I will also look into reduce_sum() and your earlier suggestion of using OpenBLAS or MKL.
