How to improve model sampling speed when applied to high-dimensional data

Hello! I am new to Stan and this is my first time posting here. I am analyzing data from an epidemiology study using a catalytic epidemiology model. My input is a four-dimensional array of infected case counts, indexed by age, municipality, year, and sex. Because the data are split across so many dimensions, many combinations have 0 cases (so the data are somewhat sparse). I am using cmdstanr on a cluster. The model works well on 400 municipalities, but the run time is about 7 hours on the HPC (4 chains in total; one core per chain). However, I've had a lot of difficulty fitting the model to the whole dataset (5570 municipalities in total); it cannot finish in a reasonable amount of time. I am looking for advice on how to make the model more efficient. Thank you so much for any suggestions!

//----- Time-dependent catalytic model -----//
  data {
    int nA; // N age groups
    int nT; // N time points
    int nL; // N locations
    int nCasePoints;
    int nPoints; // total number of observed data points (length of the log-likelihood vector)
    array[2,nL,nT,nA] int cases; //Cases
    array[2,nL] matrix[nT,nA] pop; // population
    array[2,nL,nT,nA] int pointIndex; // index of each case point for log-likelihood vector
    array[nA] int aMin; // minimum single-year age index for each of the nA age groups
    array[nA] int aMax; // maximum single-year age index for each age group
  }

parameters {
  array[nL,nT] real<upper=-1> log_lambda; // time-varying FOI (log scale)
  array[2,nA] real logit_rho; // sex and age-dependent reporting rates (logit scale)
  real<lower=1> phi; // overdispersion parameter
}

transformed parameters {
// use input data and parameters to calculate intermediate values of estimates
  array[2,nL] matrix<lower=0, upper=1>[nT,100] S; // susceptible
  array[2,nL] matrix<lower=0, upper=1>[nT,100] I; // infected
  array[2,nL] matrix<lower=0, upper=1>[nT,100] R; // immune
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Sg; // aggregated to age groups (to compare with the observed data, which are aggregated to age groups)
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Ig;
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Rg;
  array[2,nL,nT,nA] real pCases; // predicted cases
  array[nL,nT] real lambda = exp(log_lambda); // time-varying FOI (linear scale)
  array[2,nA] real rho=inv_logit(logit_rho); // sex and age-dependent reporting rates (linear scale)
  array[2,nL,nT,nA] real ITot; // number of infections

  // initial conditions (first year of transmission)
  // initialize all the data for t=1, assume proportion of infected (I) is proportional to FOI (lambda)
  for(s in 1:2){
    for(l in 1:nL){
      S[s,l,1,] = 1 - lambda[l,1]*rep_row_vector(1,100); 
      I[s,l,1,] = lambda[l,1]*rep_row_vector(1,100);
      R[s,l,1,] = lambda[l,1]*rep_row_vector(1,100);
    }
  }
  
  // loop through subsequent yearly timesteps
  for(s in 1:2){
  // first set up the data for age=1; at age 1, everyone (all sexes/times/locations) is susceptible with no infection, so S=1 and I=R=0
    for (t in 2:nT){
      S[s,1:nL,t,1] = rep_array(1,nL); // new babies
      I[s,1:nL,t,1] = rep_array(0,nL);
      R[s,1:nL,t,1] = rep_array(0,nL);
    }
  // then use the age=1 data (for all sexes/times/locations) to loop forward and fill in ages 2 to 100
    for(l in 1:nL){
      for(t in 2:nT){
        S[s,l,t,2:100] = S[s,l,t-1,1:99] - lambda[l,t]*S[s,l,t-1,1:99];
        I[s,l,t,2:100] = lambda[l,t]*S[s,l,t-1,1:99];
        R[s,l,t,2:100] = 1-S[s,l,t,2:100];
      }
    }
  }
  
  // aggregate to age groups
  for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
    Sg[s,l,t,a] = mean(S[s,l,t,aMin[a]:aMax[a]]);
    Ig[s,l,t,a] = mean(I[s,l,t,aMin[a]:aMax[a]]);
    Rg[s,l,t,a] = mean(R[s,l,t,aMin[a]:aMax[a]]);
    
    ITot[s,l,t,a] = Ig[s,l,t,a]*pop[s,l,t,a];
    pCases[s,l,t,a] = rho[s,a]*ITot[s,l,t,a];
  }
  
}


model {
  
  //--- Priors ---//
  for(l in 1:nL){
    log_lambda[l,] ~ normal(-6,1);
  }

  phi ~ normal(4,1);

  for(s in 1:2){
    logit_rho[s,] ~ normal(-5,0.5);
  }
  
  //--- Likelihood ---//
    //for(t in 1:nT) for (s in 1:2) cases[t,,s] ~ poisson(pCases[s,t,]);
  for (s in 1:2) for (l in 1:nL) for(t in 1:nT) cases[s,l,t,] ~ neg_binomial_2(pCases[s,l,t,], phi);

}

generated quantities {
  array[nPoints] real log_lik;
  real log_lik_sum = 0.0;
  real rmse = 0.0;
  
  for (s in 1:2){
    for(l in 1:nL){
      for(t in 1:nT){
        for(a in 1:nA){
          log_lik[pointIndex[s,l,t,a]] = neg_binomial_2_lpmf(cases[s,l,t,a] | pCases[s,l,t,a], phi);
          log_lik_sum += log_lik[pointIndex[s,l,t,a]];
          rmse += (pCases[s,l,t,a] - cases[s,l,t,a])^2;
        }
      }
    }
  }
  
  rmse = (rmse/nPoints) ^ .5;
}

I took a stab at vectorizing some of the calculations. There are certainly more things you could vectorize, but I was getting very sleepy.

At least on my machine, with the very fake data that I generated, it cut the time by around 30-40%. None of my testing included the generated quantities block, so I commented out everything related to that.

If you really do have sparse data, it might be much more efficient to use a long, table-like data structure instead of arrays of matrices. Here’s the section of the manual: Sparse and Ragged Data Structures
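To make that concrete, here is a minimal sketch of what a long-format data block and likelihood could look like; nObs/sex/loc/year/age are placeholder names I've made up, and pCases is assumed to be computed the same way as in your model. (The full vectorized version of your model that I actually tested follows below.)

data {
  int<lower=1> nObs;                      // one row per observed (sex, municipality, year, age-group) cell
  array[nObs] int<lower=0> cases;         // case count for that cell
  array[nObs] int<lower=1, upper=2> sex;  // index columns pointing back into the model quantities
  array[nObs] int<lower=1> loc;
  array[nObs] int<lower=1> year;
  array[nObs] int<lower=1> age;
}

model {
  vector[nObs] mu;
  for (n in 1:nObs)
    mu[n] = pCases[sex[n], year[n], loc[n], age[n]]; // look up the prediction for each observed row
  cases ~ neg_binomial_2(mu, phi);
}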

//----- Time-dependent catalytic model -----//
  data {
    int nA; // N age groups
    int nT; // N time points
    int nL; // N locations
    // int nCasePoints;
    // int nPoints;
    array[2,nT,nL,nA] int cases; //Cases
    array[2,nT] matrix[nL,nA] pop; // population
    // array[2,nL,nT,nA] int pointIndex; // index of each case point for log-likelihood vector
    array[nA] int aMin; // minimum single-year age index for each of the nA age groups
    array[nA] int aMax; // maximum single-year age index for each age group
  }

parameters {
  matrix<upper=-1>[nL, nT] log_lambda; // time-varying FOI (log scale)
  matrix[2,nA] logit_rho; // sex and age-dependent reporting rates (logit scale)
  real<lower=1> phi; // overdispersion parameter
}

transformed parameters {
// use input data and parameters to calculate intermediate values of estimates
  array[2,nT] matrix<lower=0, upper=1>[nL,100] S; // susceptible
  array[2,nT] matrix<lower=0, upper=1>[nL,100] I; // infected
  array[2,nT] matrix<lower=0, upper=1>[nL,100] R; // immune
  // array[2,nL] matrix<lower=0, upper=1>[nT,nA] Sg; // aggregated to age groups (since later need to compare with the observed data which aggregated to age groups)
  array[2,nT] matrix<lower=0, upper=1>[nL,nA] Ig;
  // array[2,nL] matrix<lower=0, upper=1>[nT,nA] Rg;
  array[2,nT,nL,nA] real pCases; // predicted cases
  matrix[nL,nT] lambda = exp(log_lambda); // time-varying FOI (linear scale)
  matrix[2,nA] rho=inv_logit(logit_rho); // sex and age-dependent reporting rates (linear scale)
  array[2,nT,nL,nA] real ITot; // number of infections

  // initial conditions (first year of transmission)
  // initialize all the data for t=1, assume proportion of infected (I) is proportional to FOI (lambda)
  for(s in 1:2){
    for(l in 1:nL){
      S[s,1,l,] = 1 - lambda[l,1]*rep_row_vector(1,100); 
      I[s,1,l,] = lambda[l,1]*rep_row_vector(1,100);
      R[s,1,l,] = lambda[l,1]*rep_row_vector(1,100);
    }
  }
  
  // loop through subsequent yearly timesteps
  for(s in 1:2){
  // first set up the data for age=1; at age 1, everyone (all sexes/times/locations) is susceptible with no infection, so S=1 and I=R=0
    for (t in 2:nT){
      S[s,t,1:nL,1] = rep_vector(1,nL); // new babies
      I[s,t,1:nL,1] = rep_vector(0,nL);
      R[s,t,1:nL,1] = rep_vector(0,nL);
    }
  // then use the age=1 data (for all sexes/times/locations) to loop forward and fill in ages 2 to 100
    // for(l in 1:nL){
      for(t in 2:nT){
        I[s,t,:,2:100] = diag_pre_multiply(lambda[:,t], S[s,t-1, :,1:99]);

        S[s,t,:,2:100] = S[s,t-1,:,1:99] - I[s,t,:,2:100];
        R[s,t,:,2:100] = 1-S[s,t,:,2:100];
      }
    // }
  }
  
  // aggregate to age groups
  for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
    // Sg[s,l,t,a] = mean(S[s,l,t,aMin[a]:aMax[a]]);
    Ig[s,t,l,a] = mean(I[s,t,l,aMin[a]:aMax[a]]);
    // Rg[s,l,t,a] = mean(R[s,l,t,aMin[a]:aMax[a]]);
    
    ITot[s,t,l,a] = Ig[s,t,l,a]*pop[s,t,l,a];
    pCases[s,t,l,a] = rho[s,a]*ITot[s,t,l,a];
  }
  
}


model {
  
  //--- Priors ---//
//   for(l in 1:nL){
//     log_lambda[l,] ~ normal(-6,1);
//   }

  to_vector(log_lambda) ~ normal(-6.0,1.0);
  phi ~ normal(4,1);

//   for(s in 1:2){
//     logit_rho[s,] ~ normal(-5,0.5);
//   }
  to_vector(logit_rho) ~ normal(-5,0.5);
  
  //--- Likelihood ---//
    //for(t in 1:nT) for (s in 1:2) cases[t,,s] ~ poisson(pCases[s,t,]);
//   for (s in 1:2) for (l in 1:nL) for(t in 1:nT) cases[s,t,l] ~ neg_binomial_2(pCases[s,t,l], phi);
   to_array_1d(cases) ~ neg_binomial_2(to_array_1d(pCases), phi);

}


Finally, I commented out the Sg and Rg parameters and calculations, since they weren’t used in subsequent computation. If you need something related to those parameters after the fit, you should be able to summarize from the posterior distributions for S and R.


Also check the compilation options in the blog post Options for improving Stan sampling speed – The Stan Blog.

@kaskogsholm I think lambda[l,1]*rep_row_vector(1,100) could be changed to rep_row_vector(lambda[l,1],100), although it probably doesn't affect the total speed much.
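Applied to the t=1 initialization of the vectorized model, that would look roughly like this (extending the same idea to the S line is my own addition):

  for (s in 1:2) {
    for (l in 1:nL) {
      // bake the scalar into rep_row_vector instead of scaling a vector of ones
      S[s,1,l,] = rep_row_vector(1 - lambda[l,1], 100);
      I[s,1,l,] = rep_row_vector(lambda[l,1], 100);
      R[s,1,l,] = rep_row_vector(lambda[l,1], 100);
    }
  }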

Thank you so much for the help. I have tried your vectorization suggestions on a chunk of my data and they cut the time in half! I will keep vectorizing calculations to see if I can speed up the sampling further. Do you know whether within-chain parallelization with reduce_sum would help with the speed? I tried using reduce_sum to parallelize the likelihood calculation, but it became even slower (even with only two threads, where I wouldn't expect much overhead), so I am a little confused.

I’m glad we made some progress. Unfortunately, I don’t know much about reduce_sum or other related features in Stan since I’ve not gotten to the point of needing them myself.

@avehtari True, and I think there's probably some trick where you pre-allocate some of the rep_vectors that are being constructed each iteration, based on something I've read here before, but I can't remember the specifics.
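If that trick applies here, I imagine it would look something like moving the constant vectors into transformed data so they are built once instead of on every gradient evaluation; a rough sketch under that assumption (the names ones_loc/zeros_loc are mine):

transformed data {
  vector[nL] ones_loc = rep_vector(1, nL);   // constant age-1 susceptible fractions
  vector[nL] zeros_loc = rep_vector(0, nL);  // constant age-1 infected/immune fractions
}

  // then, inside the existing loop over s in transformed parameters:
    for (t in 2:nT) {
      S[s,t,1:nL,1] = ones_loc;  // new babies
      I[s,t,1:nL,1] = zeros_loc;
      R[s,t,1:nL,1] = zeros_loc;
    }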

Thank you so much for the reply, and sorry for the late response. I have vectorized the model further, for example the part that aggregates to age groups, but the effect was minor. Could you give me some further advice about where further vectorization would be worthwhile, so I can reduce the run time more? I would really appreciate that!

Also, I have checked my data again and found that although many case counts are 0, there are no NA values, and the zeros should be included in the calculations since they may simply reflect a low infection rate. Instead of switching to a long table format, I am considering trying a zero-inflated model, but I am not sure whether this will add to the model complexity and thus the runtime.
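For concreteness, the zero-inflated version I have in mind would look roughly like the sketch below, where theta would be a new mixing parameter and cases_flat/pCases_flat stand for the flattened counts and predictions; it adds a branch per observation, which is why I worry about the runtime:

model {
  // ...
  for (n in 1:nPoints) {
    if (cases_flat[n] == 0) {
      // either a structural zero (probability theta) or a negative binomial zero
      target += log_sum_exp(bernoulli_lpmf(1 | theta),
                            bernoulli_lpmf(0 | theta)
                              + neg_binomial_2_lpmf(0 | pCases_flat[n], phi));
    } else {
      target += bernoulli_lpmf(0 | theta)
                + neg_binomial_2_lpmf(cases_flat[n] | pCases_flat[n], phi);
    }
  }
}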

I would check out the compilation options linked in @avehtari’s post. Especially setting -march=native has made a significant difference for me in the past.

If, after that, you’ve run out of places to vectorize, the next thing to do is to try reduce_sum. I can’t really help you with that, but I would identify the slowest section by profiling, then try to use reduce_sum on that portion, following the user guide.
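For the profiling step, Stan's profile blocks can be wrapped around the sections you suspect are slow, and cmdstanr then reports the timings via fit$profiles(). For example (the labels are arbitrary):

transformed parameters {
  // ...
  profile("age_aggregation") {
    for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
      Ig[s,t,l,a] = mean(I[s,t,l,aMin[a]:aMax[a]]);
      ITot[s,t,l,a] = Ig[s,t,l,a]*pop[s,t,l,a];
      pCases[s,t,l,a] = rho[s,a]*ITot[s,t,l,a];
    }
  }
}

model {
  // ...
  profile("likelihood") {
    to_array_1d(cases) ~ neg_binomial_2(to_array_1d(pCases), phi);
  }
}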

You can also post your current code, and perhaps someone more clever than me will see an improvement. I suspect that in the section where you aggregate to age groups something more optimal is possible. You should be able to eliminate the ITot variable. Some code:

  for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
    real Ig = mean(I[s,t,l,aMin[a]:aMax[a]]);
    pCases[s,t,l,a] = rho[s,a]*Ig*pop[s,t,l,a];
  }

You may have already made this optimization, and I'm not sure how much it will help, but it would allow you to avoid the large allocations for ITot and Ig made on every pass through the transformed parameters block.


Thank you so much for your patience! I have tried the compilation method in @avehtari's post by putting this into R:

cpp_options = list("CXXFLAGS += -march=native -mtune=native -DEIGEN_USE_BLAS -DEIGEN_USE_LAPACKE", "LDLIBS += -lblas -llapack -llapacke")
cmdstanr::cmdstan_make_local(cpp_options = cpp_options, append = TRUE)
cmdstanr::rebuild_cmdstan(cores = 4)

However, weirdly it does not really accelerate the sampling. When I just add -march=native -mtune=native, the runtime is similar to before. When I add the full set of flags referring to the external libraries, the runtime even increases… I am not sure what is going on. I expected this to help my code, since it involves heavy matrix calculations…

Well, it was worth a try. I’d guess that the running time is dominated by cache misses due to all the very large parameter containers involved, so speeding up the linear algebra or improving the vectorization might not help much. In your position, I would focus on within-chain parallelization. The simplest place to start would be with this line:

model {
...
   to_array_1d(cases) ~ neg_binomial_2(to_array_1d(pCases), phi);
...
}

This looks very similar to the example for reduce_sum in the user guide: Parallelization

After looking at this line, you might also be able to avoid the call to to_array_1d by using a 1d array from the beginning, with some edits to the aggregation step (you would have to manually calculate a mapping between a single index variable and the four you have now). I don't know enough about Stan's internals to know whether that will actually help.
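For what it's worth, here is a rough sketch of how that line might be wrapped in reduce_sum, modelled on the user guide example; partial_nb and the grainsize data entry are names I've made up, and you would also need to compile with stan_threads = TRUE and set threads_per_chain in cmdstanr:

functions {
  // negative binomial log-likelihood of one slice of the flattened case counts;
  // start/end index into the shared flattened predictions
  real partial_nb(array[] int cases_slice, int start, int end,
                  array[] real pCases_flat, real phi) {
    return neg_binomial_2_lpmf(cases_slice | pCases_flat[start:end], phi);
  }
}

model {
  // ...
  target += reduce_sum(partial_nb, to_array_1d(cases), grainsize,
                       to_array_1d(pCases), phi);
}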