Parallelizing model with ragged parameters using slicing operations

Hello!

I have been working on a model with many ragged arrays and a very slow runtime. To speed things up I am trying to parallelize whatever I can, or at least vectorize as much as possible. However, I am a bit stuck: compared with the reduce_sum demos I have read, using segment() to handle the ragged model parameters seems to rule out most of my options. With these big for loops, I am sure there are opportunities to improve sampling efficiency that I’m just not seeing, and I haven’t found much guidance for this situation in related posts.

My hope is that maybe someone out there has some ideas I could implement? I haven’t fully gotten my head around all the approaches to vectorization and parallelization, so I’m sure there are some obvious improvements I can make.

Thank you in advance!


The actual model script has a lot more going on, but I’m doing what I can here to give you just the relevant parts of the code. (For reference, the tiny reduce_sum I have implemented barely improves the efficiency at all, which makes sense, and I’m hoping to do much better.)

functions {
  real log_dirichlet_lpdf(vector log_theta, vector alpha) {
    int N = rows(log_theta);
    if (N != rows(alpha)) {
      reject("Input must contain same number of elements as alpha");
    }
      
    return dot_product(alpha, log_theta) - log_theta[N]
          + lgamma(sum(alpha)) - sum(lgamma(alpha));
  }

  real multinomial_log_lpmf(array[] int y, vector log_theta) {
    int N = sum(y);
    int K = num_elements(log_theta);
    real lp = lgamma(N + 1);
    for (k in 1:K) {
      lp += log_theta[k] * y[k] - lgamma(y[k] + 1);
    }
  
    return lp;
  }

  real partial_sum(array[,] int y_slice, int start, int end, array[] vector theta) {
    real interm_sum = 0;
    for (i in 1:(end - start + 1)) {
      interm_sum += multinomial_log_lpmf(y_slice[i] | theta[start + i - 1]);
    }
    return interm_sum;
  }
}
data {
  int<lower=1> n; // number of election units i
  real<lower=0> lambda; // exponential hyperprior inverse scale parameter
  int grainsize; // for partial_sum()
  int<lower=1> nR;
  int<lower=1> nM;
  int<lower=1> nRM;
  int<lower=1> nCRM;
  int<lower=2> C; // number of candidates c for election 1 (the global election)
  array[nR] int<lower=2> R;
  array[nM] int<lower=2> M;
  array[nRM] int<lower=2> RM;
  array[nCRM] int<lower=2> CRM;
  array[n] int<lower=1> zeta_r;
  array[n] int<lower=1> zeta_m;
  array[n] int<lower=1> zeta_rm;
  array[n] int<lower=1> zeta_crm;
  array[n, C] int<lower=0> y_c; // number of votes for candidate c in unit i
  array[sum(R[zeta_r])] int<lower=0> y_r;
  array[nRM] int<lower=1, upper=sum(RM)> alpha_rm_s;
  array[nCRM] int<lower=1, upper=(C * sum(CRM))> alpha_crm_s;
}

parameters {
  vector<lower=0>[sum(RM)] alpha_rm;
  vector<lower=0>[C * sum(CRM)] alpha_crm;
  // ...
}
transformed parameters {
  // Some functions in this block constrain varying-length segments
  // of these vectors/arrays (used in the model block) to simplices.
  // So these are all ragged arrays, handled following the guidance
  // of the manual, with an additional simplex constraint.
  vector[sum(RM[zeta_rm])] log_beta_rm;
  array[sum(CRM[zeta_crm])] vector[C] log_beta_crm;
  vector[sum(R[zeta_r])] log_theta_r;
  array[n] vector[C] log_theta_c;
  // ...
}

model {
  alpha_rm ~ exponential(lambda);
  alpha_crm ~ exponential(lambda);

  int pos_m = 1;
  int pos_r = 1;
  int pos_crm = 1;
  for (i in 1:n) {
    int pos_a = alpha_rm_s[zeta_rm[i]];
    for (r in 1:R[zeta_r[i]]) {
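       // Both segments must have the same length (M[zeta_m[i]]),
       // otherwise log_dirichlet_lpdf rejects.
       target += log_dirichlet_lpdf(segment(log_beta_rm, pos_m, M[zeta_m[i]])
                                    | segment(alpha_rm, pos_a, M[zeta_m[i]]));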
       pos_m += M[zeta_m[i]];
       pos_a += M[zeta_m[i]];
    }
    // ...
    int pos_aCRM = alpha_crm_s[zeta_crm[i]];
    for (crm in 1:CRM[zeta_crm[i]]) {
      log_beta_crm[crm + pos_crm - 1] ~ log_dirichlet(segment(alpha_crm, pos_aCRM, C));
      pos_aCRM += C;
    }

    target += multinomial_log_lpmf(segment(y_r, pos_r, R[zeta_r[i]])
                                   | segment(log_theta_r, pos_r, R[zeta_r[i]]));
    // ...
    // target += multinomial_log_lpmf(y_c[i] | log_theta_c[i]);
    pos_r += R[zeta_r[i]];
    pos_crm += CRM[zeta_crm[i]];
  }
  // this is my attempt at parallelizing the commented out multinomial_log_lpmf above
  target += reduce_sum(partial_sum, y_c, grainsize, log_theta_c);
}

There is no need to slice the data. You can instead slice an integer sequence from 1 to n, i.e. the index of your for loop, and pass the data as shared arguments. This is what brms does as well, so maybe take a brms-generated model with threading as an example and go from there.
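
A minimal sketch of that index-slicing idea (not the brms-generated code itself) could look like the program below. The names partial_log_lik, unit_idx, and theta_c are hypothetical, the simplex parameter is only a stand-in for the transformed log_theta_c of the original model, and the built-in multinomial_logit_lpmf is swapped in for the user-defined multinomial_log_lpmf so that the example is self-contained.

```stan
functions {
  // Partial sum over a slice of unit indices (hypothetical name).
  // The data and parameters arrive whole, as shared arguments.
  real partial_log_lik(array[] int idx_slice, int start, int end,
                       array[,] int y_c, array[] vector log_theta_c) {
    real lp = 0;
    for (j in 1:size(idx_slice)) {
      int i = idx_slice[j];  // global unit index
      // softmax(log(theta)) == theta, so the logit form accepts log-simplex input
      lp += multinomial_logit_lpmf(y_c[i] | log_theta_c[i]);
    }
    return lp;
  }
}
data {
  int<lower=1> n;
  int<lower=2> C;
  int<lower=1> grainsize;
  array[n, C] int<lower=0> y_c;
}
transformed data {
  // The sequence 1..n is the thing being sliced, not the data.
  array[n] int unit_idx = linspaced_int_array(n, 1, n);
}
parameters {
  array[n] simplex[C] theta_c;  // stand-in for the original transformed parameters
}
model {
  array[n] vector[C] log_theta_c;
  for (i in 1:n) {
    log_theta_c[i] = log(theta_c[i]);
  }
  target += reduce_sum(partial_log_lik, unit_idx, grainsize, y_c, log_theta_c);
}
```

Because the function receives the global index i, anything else that is indexed by unit can be added as a further shared argument and handled inside the same loop.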

I see, that’s actually a really good piece of advice, thank you! It makes sense in retrospect, but I hadn’t realized from the User’s Guide, which doesn’t say so explicitly, that I should put as much of the model block as possible inside a single reduce_sum function (instead of a separate one for each function I want to parallelize). I’ve implemented multithreading in other languages, but I often forget that Stan is no different and the approach should be similar.
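
For the ragged parts, one point to watch is that running counters such as pos_r cannot be carried across slices, because each slice may start at an arbitrary unit. A rough sketch of one way around this, assuming the start positions are precomputed once in transformed data, is below. The names partial_ragged, len_r, start_r, and z are hypothetical (len_r plays the role of R[zeta_r[i]]), the unconstrained vector z stands in for log_theta_r, and the built-in multinomial_logit_lpmf is used in place of the user-defined multinomial_log_lpmf.

```stan
functions {
  // Partial sum over unit indices for a ragged likelihood (hypothetical name).
  real partial_ragged(array[] int idx_slice, int start, int end,
                      array[] int y_r, vector z,
                      array[] int len_r, array[] int start_r) {
    real lp = 0;
    for (j in 1:size(idx_slice)) {
      int i = idx_slice[j];
      // Look up the segment for unit i directly; no running counter needed.
      lp += multinomial_logit_lpmf(segment(y_r, start_r[i], len_r[i])
                                   | segment(z, start_r[i], len_r[i]));
    }
    return lp;
  }
}
data {
  int<lower=1> n;
  int<lower=1> grainsize;
  array[n] int<lower=2> len_r;         // segment length per unit
  array[sum(len_r)] int<lower=0> y_r;  // ragged counts, stored flat
}
transformed data {
  array[n] int unit_idx = linspaced_int_array(n, 1, n);
  array[n] int start_r;                // first position of each unit's segment
  {
    int pos = 1;
    for (i in 1:n) {
      start_r[i] = pos;
      pos += len_r[i];
    }
  }
}
parameters {
  vector[sum(len_r)] z;  // stand-in for the ragged log-probabilities
}
model {
  z ~ std_normal();
  target += reduce_sum(partial_ragged, unit_idx, grainsize,
                       y_r, z, len_r, start_r);
}
```

The same lookup pattern should extend to the other position counters (pos_m, pos_crm), each with its own precomputed start array.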

On that point, I have a related question. I see that I’m basically going to be replacing the for (i in 1:n) {} loop in the model block with reduce_sum, but how do I handle parallelization of the transformed parameters block?

I may not fully understand how the two blocks relate in terms of parallelization (I assure you I do not), but in order to perform the simplex constraint that I mentioned above, I call the following function in the transformed parameters block:

functions {
  vector inv_ilr_log_simplex_constrain_lp(vector y) {
    int N = rows(y) + 1;
    vector[N - 1] ns = linspaced_vector(N - 1, 1, N - 1);
    vector[N - 1] w = y ./ sqrt(ns .* (ns + 1));
    vector[N] z = append_row(reverse(cumulative_sum(reverse(w))), 0) - append_row(0, ns .* w);
    real r = log_sum_exp(z);
    vector[N] log_x = z - r;
    target += 0.5 * log(N);
    target += log_x[N];
    return log_x;
  }
}

I’m sure that I’m not interpreting this correctly, but it seems that if I have increments to the log density (target +=) outside of the reduce_sum in the model block, then I won’t see the kind of speedup that I would otherwise hope for. Is the approach just to have another, separate reduce_sum in the transformed parameters block? How would you approach this?

Stan has support for profiling, and you should instrument your code with it. It is relatively easy to use and well documented. That should tell you whether the above function is indeed important for performance. As a basic rule of thumb, Stan programs are slowed down by costly numerics and a large autodiff tree. The autodiff tree grows with every additional quantity that is derived from or related to the parameters.
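
As a minimal sketch of what that instrumentation looks like, here is a toy model (not the model from this thread) with three profile blocks. The section names "transform", "priors", and "likelihood" are arbitrary labels chosen for illustration.

```stan
data {
  int<lower=1> N;
  vector[N] y;
}
parameters {
  real mu;
  real<lower=0> sigma;
}
transformed parameters {
  // Declared outside the profile block so it stays in scope afterwards.
  vector[N] z;
  profile("transform") {
    z = (y - mu) / sigma;
  }
}
model {
  profile("priors") {
    mu ~ std_normal();
    sigma ~ exponential(1);
  }
  profile("likelihood") {
    // Equivalent to y ~ normal(mu, sigma), written via the standardized z.
    target += std_normal_lpdf(z) - N * log(sigma);
  }
}
```

Each named section is then timed separately, which makes it possible to compare the transformation cost against the likelihood cost directly.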

Thank you for this suggestion as well! In addition to successfully implementing your advice for reduce_sum in the model block, I also profiled the relevant sections of code. I’m sure that this will not scale linearly with the data and compute, but right now the parameter transformations I perform outside of the reduce_sum take about as much compute time as the reduce_sum itself. I am not going to move those computations into the model block, because it’s important that these variables are written to the output. Given that limitation, and besides ensuring that the code in the transformed parameters block is vectorized where possible, is there anything else I can do so that the transformations take advantage of multithreading?

Then you really should move the computations from the transformed parameters block into the function called by reduce_sum. It is much more efficient to calculate those quantities again in the generated quantities block. Remember that the generated quantities block is evaluated only once per draw, not at every leapfrog step, and no gradients are calculated for it. The speedup from reduce_sum will more than make up for the recalculation in generated quantities.
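
A rough sketch of that pattern, under simplifying assumptions, is below. The names unit_log_theta and partial_units are hypothetical, and log_softmax is only a stand-in for the actual simplex transform in this thread. The sketch also skips the Jacobian term: as far as I understand, a function ending in _lp cannot be called from a partial-sum function or from generated quantities, so that adjustment would have to be computed explicitly and added to the partial sum.

```stan
functions {
  // Hypothetical per-unit transform, shared by the partial sum
  // and by generated quantities.
  vector unit_log_theta(vector z) {
    return log_softmax(z);
  }
  real partial_units(array[] int idx_slice, int start, int end,
                     array[,] int y_c, array[] vector z) {
    real lp = 0;
    for (j in 1:size(idx_slice)) {
      int i = idx_slice[j];
      lp += std_normal_lpdf(z[i]);  // prior on the raw parameters
      // The transform now runs inside the threaded function.
      lp += multinomial_logit_lpmf(y_c[i] | unit_log_theta(z[i]));
    }
    return lp;
  }
}
data {
  int<lower=1> n;
  int<lower=2> C;
  int<lower=1> grainsize;
  array[n, C] int<lower=0> y_c;
}
transformed data {
  array[n] int unit_idx = linspaced_int_array(n, 1, n);
}
parameters {
  array[n] vector[C] z;
}
model {
  target += reduce_sum(partial_units, unit_idx, grainsize, y_c, z);
}
generated quantities {
  // Recomputed once per draw, without gradients, so it is still written to the output.
  array[n] vector[C] log_theta_c;
  for (i in 1:n) {
    log_theta_c[i] = unit_log_theta(z[i]);
  }
}
```

Keeping the transform in one function means the model and the generated quantities cannot drift apart.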

This hadn’t occurred to me, thanks again!