Speeding up Locomotive Wheel Wear Prediction

Hi All,

I have been working on a wheel-wear problem for some time, and I have moved from brms with splines back to plain Stan code. The .stan models were fitting the data well, but while preparing them for production before Christmas I ran into significant bottlenecks in computation time. Hence I am hoping for criticism on how to structure this better.

Problem context:

  • Train locomotive and wagon wheels have a wear parameter and a distance recorded each week (y and x, respectively, in the data).
  • In one type of fleet I have ~12,000 wheels in service. Some of these have been in service for 5 years, some for less, so a given “wheel life” could have 250 data points, or just 1 if the wheel was changed less than a week ago.
  • This is one of the larger fleets, but another 130 fleets have data too.
  • Originally I was going to use historical wheels (dating back 5 years) to inform the population, but that is largely overkill, and computationally I am already struggling with the data we have.

An example of the data fed into the model is shown below:

> str(sdata)
List of 9
 $ N             : int 341437 #Number of raw data points
 $ J             : num 2960   #Wheel lives (could be 1 to hundreds of wear-distance records depending on wheel age)
 $ x             : num [1:341437] 0 0.0317 0.0644 0.0926 0.1118 ...
 $ y             : num [1:341437] 29.5 29.4 29.3 29.3 29.5 ...
 $ idn           : num [1:341437] 1 1 1 1 1 1 1 1 1 1 ...
 $ run_estimation: num 1
 $ thresh        : num 35
 $ wear_init_mu  : num 29.2
 $ wear_init_sd  : num 0.289
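For reference, `idn` has to map every measurement to a contiguous wheel-life index from 1 to J. A minimal sketch of building it from raw records, assuming each record carries some arbitrary wheel-life label (the function name and the example labels are hypothetical, not from the original pipeline):

```python
import numpy as np

def make_stan_index(wheel_life_id):
    """Map arbitrary wheel-life labels to contiguous 1-based indices.

    Returns (idn, J): idn[i] is the 1-based index of measurement i's wheel
    life and J is the number of distinct wheel lives, which is the layout
    the int<lower=1,upper=J> idn[N] declaration expects.
    """
    labels, inverse = np.unique(np.asarray(wheel_life_id), return_inverse=True)
    idn = inverse + 1            # np.unique gives 0-based codes; Stan is 1-based
    return idn.astype(int), int(labels.size)

# Hypothetical raw records: one wheel-life label per weekly measurement
wheel_life_id = ["W17-L2", "W17-L2", "W03-L1", "W17-L2", "W03-L1", "W99-L1"]
idn, J = make_stan_index(wheel_life_id)
print(idn, J)
```

Because `np.unique` sorts the labels, the index order is alphabetical by label; Stan does not care about that order, only that the indices are contiguous.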

Context continued:

  • The rationale was to have previously used and active wheels inform the predictions for current wheels, so we can better plan when these are likely to reach the threshold (thresh).
  • We get the individual posteriors for the gradients (slopes b), and if we deem them not “good”, we lean on the population value of the gradient (mu_b) instead.
  • I have also developed a similar approach for a non-linear model with 3 parameters.
  • The intercepts (a) are only required for the individual wheel models. When using population predictions, I apply them (not shown here) to the last recorded value of the wheel in question.
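To make the fallback concrete: under the linear model y = a + b·x, a wheel last recorded at (x_last, y_last) reaches `thresh` after a further (thresh - y_last)/b of distance. A minimal sketch, assuming the slope posterior is available as an array of draws (the function name and all numbers are hypothetical). Passing draws of `mu_b` gives the population-slope version; passing a wheel's own `b` draws gives the individual one:

```python
import numpy as np

def distance_at_threshold(x_last, y_last, slope_draws, thresh=35.0):
    """Distance at which wear is projected to reach `thresh`, one value per slope draw.

    Projects the linear model forward from the wheel's last recorded point
    (x_last, y_last). If the wheel is already at or past the threshold, the
    remaining wear is clipped to zero and x_last is returned for every draw.
    """
    slope_draws = np.asarray(slope_draws, dtype=float)
    remaining = max(thresh - y_last, 0.0)    # wear left before the threshold
    return x_last + remaining / slope_draws

# Hypothetical wheel: last recorded wear 31.0 at distance 1.2
mu_b_draws = np.array([1.5, 2.0, 2.5])       # made-up posterior draws of mu_b
d = distance_at_threshold(1.2, 31.0, mu_b_draws)
print(np.percentile(d, [10, 50, 90]))
```

Using draws of `mu_b` ignores wheel-to-wheel variation; drawing a fresh slope from normal(mu_b, sigma_b) for each posterior draw would instead give a predictive distribution for a typical new wheel.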

Trying to speed this up:

  • I have removed a number of the parameters being captured to try to speed this up (only capturing a and b, albeit a lot of them).
  • When run for ~3000 wheel lives at 500 iterations, the model took more than 12 h before I removed those captured parameters; it now takes about 4 h (although the transition-time estimates reported at the start of sampling don’t seem to change).
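A rough way to see why saving fewer quantities matters so much: every saved draw writes one value per sampler diagnostic, parameter, and transformed parameter. The sketch below counts values per draw for the model shown below (a[J], b[J], mu_b, sigma_b, sigma_y, plus y_hat[N] because it is declared in transformed parameters) and converts that to file size. The ~10 bytes per value as text and the 250 saved draws per chain (500 iterations with half used for warmup) are assumptions, not measurements:

```python
def values_per_draw(N, J, save_y_hat=True):
    """Number of values written per saved draw for the wheel-wear model."""
    sampler_cols = 7                  # lp__, accept_stat__, stepsize__, treedepth__,
                                      # n_leapfrog__, divergent__, energy__
    param_cols = 2 * J + 3            # a[J], b[J], mu_b, sigma_b, sigma_y
    tp_cols = N if save_y_hat else 0  # y_hat[N] from transformed parameters
    return sampler_cols + param_cols + tp_cols

def approx_csv_gb(n_values, draws, bytes_per_value=10):
    """Very rough text-file size in GB, assuming ~10 characters per value."""
    return n_values * draws * bytes_per_value / 1e9

N, J = 341437, 2960                       # from str(sdata) above
with_yhat = values_per_draw(N, J, True)   # 347367
without = values_per_draw(N, J, False)    # 5930
print(with_yhat, round(approx_csv_gb(with_yhat, 250), 2))
print(without, round(approx_csv_gb(without, 250), 3))
```

Under these assumptions y_hat alone is about 98% of every row, which would put a single chain's output near 0.9 GB. Declaring y_hat as a local variable inside the model block instead keeps it out of the saved output entirely, since model-block locals are never written.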

Operating Sys:

  • Linux, Ubuntu 20.04
  • 56 CPUs with 2.2GHz each
  • Built in a Docker container from rocker/geospatial:4.2.1
  • rstan version: 2.21.7

Stan model

data {
  int<lower=1> N;                    // number of individual wear measurements
  int<lower=1> J;                    // number of wheel lives
  vector[N] x;                       // distance
  vector[N] y;                       // wear
  int<lower=1,upper=J> idn[N];       // wheel-life index of each measurement
  int<lower=0,upper=1> run_estimation;
  real<lower=1> thresh;
  real wear_init_mu;
  real wear_init_sd;
}

parameters {
  vector<lower=15,upper=45>[J] a;    // intercept for each wheel
  // real mu_a;
  vector<lower=0>[J] b;              // slope for each wheel
  real<lower=0> mu_b;                // mean slope (lower bound needed for the gamma prior)
  // real<lower=0,upper=100> sigma_a;  // sd of intercept (wheel level, but across all)
  real<lower=0,upper=100> sigma_b;   // sd of slope (wheel level, but across all)
  real<lower=0,upper=100> sigma_y;   // residual error of obs
}

transformed parameters {
  vector[N] y_hat;                   // linear predictor for each measurement

  for (i in 1:N)
    y_hat[i] = a[idn[i]] + x[i] * b[idn[i]];
}

model {
  // mu_a ~ normal(wear_init_mu, 1);
  // sigma_a ~ normal(wear_init_sd, 1);
  // a ~ normal(mu_a, sigma_a);
  a ~ normal(wear_init_mu, 1);

  mu_b ~ gamma(1.5, 1);              // added 21 Nov
  sigma_b ~ gamma(1, 1);
  b ~ normal(mu_b, sigma_b);
  // b ~ normal(1.5, 1);
  sigma_y ~ uniform(0, 100);
  y ~ normal(y_hat, sigma_y);
}


  • Is there merit in changing the data structure from vectors (for x and y)?
  • Would rewriting this in some other way help?
  • I’m open to criticism, scathing or otherwise.
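On the data-structure question: vectors are already the natural choice for x and y. In Stan, the per-observation loop that builds y_hat can be replaced by one vectorized expression using multi-indexing, `a[idn] + x .* b[idn]`, passed straight to `y ~ normal(..., sigma_y)` in the model block, which is typically faster and saves nothing of length N. The NumPy sketch below only illustrates that indexing pattern (it is not Stan code), checking on made-up data that the gather-based form matches the loop; indices are 0-based because it is Python:

```python
import numpy as np

rng = np.random.default_rng(1)
J, N = 4, 10                                  # tiny made-up problem size
a = rng.normal(29.2, 1.0, size=J)             # per-wheel intercepts
b = np.abs(rng.normal(1.5, 0.5, size=J))      # per-wheel slopes
x = rng.uniform(0.0, 2.0, size=N)             # distances
idn = rng.integers(0, J, size=N)              # 0-based wheel index per row

# Loop version, mirroring the transformed-parameters block
y_hat_loop = np.empty(N)
for i in range(N):
    y_hat_loop[i] = a[idn[i]] + x[i] * b[idn[i]]

# Vectorized version: gather per-row intercepts and slopes, then one elementwise op
y_hat_vec = a[idn] + x * b[idn]

print(np.allclose(y_hat_loop, y_hat_vec))     # True
```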



I made a small edit to the title so folks know what you are after, and this reply will bump it to the top.


Have you tried using a non-centered parameterization (25.7 Reparameterization | Stan User’s Guide) for the slope parameters b? That can simplify the posterior geometry for hierarchical models like this, which makes the HMC better behaved (and faster).
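The idea, sketched in NumPy rather than Stan: instead of sampling b directly from normal(mu_b, sigma_b), sample standardized offsets b_raw ~ normal(0, 1) and set b = mu_b + sigma_b * b_raw. This has the same distribution but decouples b_raw from sigma_b in the posterior. Two caveats to check: the model above declares b with <lower=0>, which this transform does not enforce on its own (one option is to model log-slopes instead), and the hyperparameter values here are made up:

```python
import numpy as np

rng = np.random.default_rng(42)
mu_b, sigma_b = 1.5, 0.4        # made-up hyperparameter values
J = 200_000                     # many draws so the moments are easy to compare

# Centered: draw slopes directly from the population distribution
b_centered = rng.normal(mu_b, sigma_b, size=J)

# Non-centered: draw standard-normal offsets, then shift and scale
b_raw = rng.normal(0.0, 1.0, size=J)
b_noncentered = mu_b + sigma_b * b_raw

print(b_centered.mean(), b_noncentered.mean())
print(b_centered.std(), b_noncentered.std())
```

In Stan the same thing is written either with an explicit `vector[J] b_raw;` parameter, `b_raw ~ std_normal();`, and b computed in transformed parameters, or with the affine `offset`/`multiplier` declaration on b.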

What do the HMC diagnostics look like when you fit the model (Rhat, effective sample size, etc.)? Another thing to look at is treedepth: if the model requires a high treedepth, that will slow down computation. If that’s the case, reparameterizing may help.
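Why tree depth matters for run time: NUTS takes up to 2^depth - 1 leapfrog steps per iteration, and each step costs one gradient evaluation over all N observations. A back-of-envelope sketch; the gradient time is a placeholder, not a measurement from this model (Stan prints its own gradient-evaluation estimate at the start of each chain):

```python
def max_leapfrog_steps(treedepth):
    """Upper bound on leapfrog steps (gradient evaluations) in one NUTS iteration."""
    return 2 ** treedepth - 1

def chain_hours(iterations, treedepth, grad_seconds):
    """Worst-case chain run time if every iteration reaches `treedepth`."""
    return iterations * max_leapfrog_steps(treedepth) * grad_seconds / 3600

grad_seconds = 0.01   # placeholder: substitute Stan's reported gradient time
for depth in (5, 8, 10):
    print(depth, max_leapfrog_steps(depth), round(chain_hours(500, depth, grad_seconds), 2))
```

Each extra level of depth roughly doubles the cost per iteration, which is why a reparameterization that lowers the typical tree depth can shorten run time substantially.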


Thanks for bumping it up @Ara_Winter and for the response @hsusmann.

I haven’t tried the non-centered parameterization yet. I will look into it when I get a “lull in the battle”. Thank you!

I moved everything over to CmdStan, which is a lot faster (although I did successfully sink a lot of time into getting the GPU working in Docker, which turned out to be more computationally intensive on the CPUs than the normal approach).

The beauty of it, though, is that I now have huge memory-consumption events that kill the server when I try to read the .csv files back in to wrangle the outputs for prediction. The four 1 GB .csv files consume close to 30 GB in the container with read_cmdstan_csv, so it’s another interesting challenge…
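One way to keep memory bounded is to stream each CmdStan CSV and keep only the columns needed for prediction, rather than loading every saved quantity. A minimal Python sketch, assuming the standard CmdStan layout (comment lines starting with '#', one header row, indexed names such as `b.1`, `b.2`) and that only `mu_b` and the `b` slopes are wanted; the function name and column choice are hypothetical. On the R side, cmdstanr's `read_cmdstan_csv()` has a `variables` argument that may achieve the same (worth checking for the installed version). If y_hat is still declared in transformed parameters, it likely makes up most of each file:

```python
import csv
import io

import numpy as np

def read_selected_columns(lines, keep):
    """Stream a CmdStan-style CSV and return {column: np.ndarray} for kept columns.

    `lines` is any iterable of text lines (e.g. an open file). Lines starting
    with '#' (config, adaptation info, timing) are skipped. A column is kept
    if its name equals an entry in `keep` or starts with that entry plus '.'
    (so 'b' keeps 'b.1', 'b.2', ... but not 'mu_b' or 'sigma_b').
    Only one parsed row is held at a time besides the kept values.
    """
    rows = csv.reader(line for line in lines if not line.startswith("#"))
    header = next(rows)
    idx = [i for i, name in enumerate(header)
           if any(name == k or name.startswith(k + ".") for k in keep)]
    cols = {header[i]: [] for i in idx}
    for row in rows:
        if not row:
            continue                      # skip blank lines, if any
        for i in idx:
            cols[header[i]].append(float(row[i]))
    return {name: np.array(vals) for name, vals in cols.items()}

# Tiny made-up example in CmdStan layout
sample = io.StringIO(
    "# model = wheel_wear\n"
    "lp__,accept_stat__,mu_b,sigma_b,b.1,b.2,y_hat.1\n"
    "# Adaptation terminated\n"
    "-10.2,0.91,1.48,0.39,1.21,1.77,29.9\n"
    "-10.5,0.88,1.52,0.41,1.19,1.80,30.1\n"
)
draws = read_selected_columns(sample, keep=["mu_b", "b"])
print(sorted(draws))
```

With a real file, the same call would take an open handle, e.g. `with open(path) as f: draws = read_selected_columns(f, keep=["mu_b", "b"])`, where `path` is one of the CmdStan output files.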

Thanks again.