How to improve model sampling speed when applied to high-dimensional data

Hello! I am new to Stan and this is my first time posting here. I am analyzing data from an epidemiology study using a catalytic epidemiology model. My input is a 4-dimensional array of infected case counts, indexed by age, municipality, year, and sex. Because the data are split across so many strata, many combinations have 0 cases (so the data are somewhat sparse). I am running cmdstanr on a cluster. The model works well on 400 municipalities, but the run time is about 7 hours on the HPC (4 chains in total, one core per chain). However, I've had a lot of difficulty fitting the model to the whole dataset (5,570 municipalities in total); it cannot finish in a reasonable amount of time. I am looking for advice on how to make the model more efficient. Thank you so much for any suggestions!

//----- Time-dependent catalytic model -----//
  data {
    int nA; // N age groups
    int nT; // N time points
    int nL; // N locations
    int nCasePoints;
    int nPoints;
    array[2,nL,nT,nA] int cases; //Cases
    array[2,nL] matrix[nT,nA] pop; // population
    array[2,nL,nT,nA] int pointIndex; // index of each case point for log-likelihood vector
    array[nA] int aMin; // minimum age index for each age group (nA age groups, so nA minimum ages)
    array[nA] int aMax; // maximum age index for each age group
  }

parameters {
  array[nL,nT] real<upper=-1> log_lambda; // time-varying FOI (log scale)
  array[2,nA] real logit_rho; // sex and age-dependent reporting rates (logit scale)
  real<lower=1> phi; // overdispersion parameter
}

transformed parameters {
// use input data and parameters to calculate intermediate values of estimates
  array[2,nL] matrix<lower=0, upper=1>[nT,100] S; // susceptible
  array[2,nL] matrix<lower=0, upper=1>[nT,100] I; // infected
  array[2,nL] matrix<lower=0, upper=1>[nT,100] R; // immune
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Sg; // S aggregated to age groups (needed later for comparison with the observed data, which are aggregated to age groups)
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Ig;
  array[2,nL] matrix<lower=0, upper=1>[nT,nA] Rg;
  array[2,nL,nT,nA] real pCases; // predicted cases
  array[nL,nT] real lambda = exp(log_lambda); // time-varying FOI (linear scale)
  array[2,nA] real rho=inv_logit(logit_rho); // sex and age-dependent reporting rates (linear scale)
  array[2,nL,nT,nA] real ITot; // number of infections

  // initial conditions (first year of transmission)
  // initialize all the data for t=1, assume proportion of infected (I) is proportional to FOI (lambda)
  for(s in 1:2){
    for(l in 1:nL){
      S[s,l,1,] = 1 - lambda[l,1]*rep_row_vector(1,100); 
      I[s,l,1,] = lambda[l,1]*rep_row_vector(1,100);
      R[s,l,1,] = lambda[l,1]*rep_row_vector(1,100);
    }
  }
  
  // loop through subsequent yearly timesteps
  for(s in 1:2){
  // first setup the data for age=1; at age=1, all time/location/sex are susceptible with no infection, so S=1; I/R=0
    for (t in 2:nT){
      S[s,1:nL,t,1] = rep_array(1,nL); // new babies
      I[s,1:nL,t,1] = rep_array(0,nL);
      R[s,1:nL,t,1] = rep_array(0,nL);
    }
  // then use the age=1 data (for all sex/time/location) to loop and get data from age=2 to An
    for(l in 1:nL){
      for(t in 2:nT){
        S[s,l,t,2:100] = S[s,l,t-1,1:99] - lambda[l,t]*S[s,l,t-1,1:99];
        I[s,l,t,2:100] = lambda[l,t]*S[s,l,t-1,1:99];
        R[s,l,t,2:100] = 1-S[s,l,t,2:100];
      }
    }
  }
  
  // aggregate to age groups
  for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
    Sg[s,l,t,a] = mean(S[s,l,t,aMin[a]:aMax[a]]);
    Ig[s,l,t,a] = mean(I[s,l,t,aMin[a]:aMax[a]]);
    Rg[s,l,t,a] = mean(R[s,l,t,aMin[a]:aMax[a]]);
    
    ITot[s,l,t,a] = Ig[s,l,t,a]*pop[s,l,t,a];
    pCases[s,l,t,a] = rho[s,a]*ITot[s,l,t,a];
  }
  
}


model {
  
  //--- Priors ---//
  for(l in 1:nL){
    log_lambda[l,] ~ normal(-6,1);
  }

  phi ~ normal(4,1);

  for(s in 1:2){
    logit_rho[s,] ~ normal(-5,0.5);
  }
  
  //--- Likelihood ---//
    //for(t in 1:nT) for (s in 1:2) cases[t,,s] ~ poisson(pCases[s,t,]);
  for (s in 1:2) for (l in 1:nL) for(t in 1:nT) cases[s,l,t,] ~ neg_binomial_2(pCases[s,l,t,], phi);

}

generated quantities {
  array[nPoints] real log_lik;
  real log_lik_sum = 0.0;
  real rmse = 0.0;
  
  for (s in 1:2){
    for(l in 1:nL){
      for(t in 1:nT){
        for(a in 1:nA){
          log_lik[pointIndex[s,l,t,a]] = neg_binomial_2_lpmf(cases[s,l,t,a] | pCases[s,l,t,a], phi);
          log_lik_sum += log_lik[pointIndex[s,l,t,a]];
          rmse += (pCases[s,l,t,a] - cases[s,l,t,a])^2;
        }
      }
    }
  }
  
  rmse = (rmse/nPoints) ^ .5;
}

I took a stab at vectorizing some of the calculations. There are certainly more things you could vectorize, but I was getting very sleepy.

At least on my machine, with the very fake data that I generated, it cut the time by around 30-40%. None of my testing included the generated quantities block, so I commented out everything related to that.

If you really do have sparse data, it might be much more efficient to use a long, table-like data structure instead of arrays of matrices. Here’s the section of the manual: Sparse and Ragged Data Structures
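To make the idea concrete, here is a minimal sketch of the long-format structure (the names and index layout are placeholders I made up, not tested against your data): keep one row per observed sex/location/year/age-group combination, then index into pCases when building the mean vector.

data {
  int<lower=1> N;                     // number of observed combinations
  array[N] int<lower=1, upper=2> sex; // sex index for each row
  array[N] int<lower=1> loc;          // location index
  array[N] int<lower=1> year;         // time index
  array[N] int<lower=1> ageGrp;       // age-group index
  array[N] int<lower=0> obsCases;     // observed case count for each row
}

model {
  vector[N] mu;
  for (n in 1:N)
    mu[n] = pCases[sex[n], year[n], loc[n], ageGrp[n]];
  obsCases ~ neg_binomial_2(mu, phi); // single vectorized likelihood statement
}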

//----- Time-dependent catalytic model -----//
  data {
    int nA; // N age groups
    int nT; // N time points
    int nL; // N locations
    // int nCasePoints;
    // int nPoints;
    array[2,nT,nL,nA] int cases; //Cases
    array[2,nT] matrix[nL,nA] pop; // population
    // array[2,nL,nT,nA] int pointIndex; // index of each case point for log-likelihood vector
    array[nA] int aMin; // minimum age index for each age group (nA age groups, so nA minimum ages)
    array[nA] int aMax; // maximum age index for each age group
  }

parameters {
  matrix<upper=-1>[nL, nT] log_lambda; // time-varying FOI (log scale)
  matrix[2,nA] logit_rho; // sex and age-dependent reporting rates (logit scale)
  real<lower=1> phi; // overdispersion parameter
}

transformed parameters {
// use input data and parameters to calculate intermediate values of estimates
  array[2,nT] matrix<lower=0, upper=1>[nL,100] S; // susceptible
  array[2,nT] matrix<lower=0, upper=1>[nL,100] I; // infected
  array[2,nT] matrix<lower=0, upper=1>[nL,100] R; // immune
  // array[2,nL] matrix<lower=0, upper=1>[nT,nA] Sg; // aggregated to age groups (since later need to compare with the observed data which aggregated to age groups)
  array[2,nT] matrix<lower=0, upper=1>[nL,nA] Ig;
  // array[2,nL] matrix<lower=0, upper=1>[nT,nA] Rg;
  array[2,nT,nL,nA] real pCases; // predicted cases
  matrix[nL,nT] lambda = exp(log_lambda); // time-varying FOI (linear scale)
  matrix[2,nA] rho=inv_logit(logit_rho); // sex and age-dependent reporting rates (linear scale)
  array[2,nT,nL,nA] real ITot; // number of infections

  // initial conditions (first year of transmission)
  // initialize all the data for t=1, assume proportion of infected (I) is proportional to FOI (lambda)
  for(s in 1:2){
    for(l in 1:nL){
      S[s,1,l,] = 1 - lambda[l,1]*rep_row_vector(1,100); 
      I[s,1,l,] = lambda[l,1]*rep_row_vector(1,100);
      R[s,1,l,] = lambda[l,1]*rep_row_vector(1,100);
    }
  }
  
  // loop through subsequent yearly timesteps
  for(s in 1:2){
  // first setup the data for age=1; at age=1, all time/location/sex are susceptible with no infection, so S=1; I/R=0
    for (t in 2:nT){
      S[s,t,1:nL,1] = rep_vector(1,nL); // new babies
      I[s,t,1:nL,1] = rep_vector(0,nL);
      R[s,t,1:nL,1] = rep_vector(0,nL);
    }
  // then use the age=1 data (for all sex/time/location) to loop and get data from age=2 to An
    // for(l in 1:nL){
      for(t in 2:nT){
        I[s,t,:,2:100] = diag_pre_multiply(lambda[:,t], S[s,t-1, :,1:99]);

        S[s,t,:,2:100] = S[s,t-1,:,1:99] - I[s,t,:,2:100];
        R[s,t,:,2:100] = 1-S[s,t,:,2:100];
      }
    // }
  }
  
  // aggregate to age groups
  for(s in 1:2) for(l in 1:nL) for(a in 1:nA) for(t in 1:nT) {
    // Sg[s,l,t,a] = mean(S[s,l,t,aMin[a]:aMax[a]]);
    Ig[s,t,l,a] = mean(I[s,t,l,aMin[a]:aMax[a]]);
    // Rg[s,l,t,a] = mean(R[s,l,t,aMin[a]:aMax[a]]);
    
    ITot[s,t,l,a] = Ig[s,t,l,a]*pop[s,t,l,a];
    pCases[s,t,l,a] = rho[s,a]*ITot[s,t,l,a];
  }
  
}


model {
  
  //--- Priors ---//
//   for(l in 1:nL){
//     log_lambda[l,] ~ normal(-6,1);
//   }

  to_vector(log_lambda) ~ normal(-6.0,1.0);
  phi ~ normal(4,1);

//   for(s in 1:2){
//     logit_rho[s,] ~ normal(-5,0.5);
//   }
  to_vector(logit_rho) ~ normal(-5,0.5);
  
  //--- Likelihood ---//
    //for(t in 1:nT) for (s in 1:2) cases[t,,s] ~ poisson(pCases[s,t,]);
//   for (s in 1:2) for (l in 1:nL) for(t in 1:nT) cases[s,t,l] ~ neg_binomial_2(pCases[s,t,l], phi);
   to_array_1d(cases) ~ neg_binomial_2(to_array_1d(pCases), phi);

}


Finally, I commented out the Sg and Rg parameters and calculations, since they weren’t used in subsequent computation. If you need something related to those parameters after the fit, you should be able to summarize from the posterior distributions for S and R.


Also check the compilation options discussed in the blog post Options for improving Stan sampling speed – The Stan Blog.

@kaskogsholm I think lambda[l,1]*rep_row_vector(1,100) could be changed to rep_row_vector(lambda[l,1],100), although it probably doesn't affect the total speed much.
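That is, the initial-condition lines would become something like:

S[s,1,l,] = 1 - rep_row_vector(lambda[l,1], 100);
I[s,1,l,] = rep_row_vector(lambda[l,1], 100);
R[s,1,l,] = rep_row_vector(lambda[l,1], 100);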

Thank you so much for the help. I tried your vectorization suggestions on a chunk of my data and they cut the run time in half! I will keep vectorizing calculations to see if I can speed up the sampling further. Do you know whether within-chain parallelization with reduce_sum would help? I tried using reduce_sum to parallelize the likelihood calculation, but it actually became slower (even with only two threads, where I wouldn't expect much overhead), so I am a bit confused.
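For reference, my reduce_sum attempt looked roughly like this (a sketch from memory; the function name and grainsize are placeholders, and I ran with threading enabled and threads_per_chain set in cmdstanr):

functions {
  // partial sum over a slice of the flattened case counts
  real partial_nb(array[] int cases_slice, int start, int end,
                  array[] real mu, real phi) {
    return neg_binomial_2_lpmf(cases_slice | mu[start:end], phi);
  }
}

model {
  int grainsize = 1; // 1 leaves the slice size to the internal scheduler
  // ... priors as before ...
  target += reduce_sum(partial_nb, to_array_1d(cases), grainsize,
                       to_array_1d(pCases), phi);
}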

I’m glad we made some progress. Unfortunately, I don’t know much about reduce_sum or other related features in Stan since I’ve not gotten to the point of needing them myself.

@avehtari True, and I think there's probably some trick where you pre-allocate some of the rep_vectors that are being constructed on each iteration, based on something I've read here before, but I can't remember the specifics.
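Something along these lines, maybe (just a sketch of what I mean; the constant vector is built once in transformed data instead of inside the loops):

transformed data {
  // constant vector of ones, constructed once rather than every iteration
  row_vector[100] ones100 = rep_row_vector(1, 100);
}

// then, in transformed parameters:
S[s,1,l,] = 1 - lambda[l,1] * ones100;
I[s,1,l,] = lambda[l,1] * ones100;
R[s,1,l,] = lambda[l,1] * ones100;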