Hello!
I have been working on a model with many ragged arrays, and its runtime is very slow. To speed things up I am trying to parallelize whatever I can, or at least vectorize as much as possible. However, I am a bit stuck: compared to the reduce_sum demos I have read, it seems like relying on segment() to handle the ragged model parameters rules out most of my options. Yet with these big for loops I am sure there are opportunities to improve sampling efficiency that I'm just not seeing, and there isn't much existing guidance for this situation in other related posts.
My hope is that someone out there has ideas I could implement. I haven't fully gotten my head around all the approaches to vectorization and parallelization, so I'm sure there are obvious improvements I'm missing.
Thank you in advance!
The actual model script has a lot more going on, but I'm including just the relevant parts of the code here. (For reference, the tiny reduce_sum I have implemented barely improves efficiency at all, which makes sense given how little work it covers; I'm hoping to do much better.)
functions {
  // Dirichlet log density for a simplex supplied on the log scale
  real log_dirichlet_lpdf(vector log_theta, vector alpha) {
    int N = rows(log_theta);
    if (N != rows(alpha)) {
      reject("log_theta must contain the same number of elements as alpha");
    }
    return dot_product(alpha, log_theta) - log_theta[N]
           + lgamma(sum(alpha)) - sum(lgamma(alpha));
  }
  // Multinomial lpmf with probabilities supplied on the log scale,
  // including the multinomial normalizing constant
  real multinomial_log_lpmf(array[] int y, vector log_theta) {
    int N = sum(y);
    int K = num_elements(log_theta);
    real lp = lgamma(N + 1);
    for (k in 1:K) {
      lp += log_theta[k] * y[k] - lgamma(y[k] + 1);
    }
    return lp;
  }
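  // A vectorized variant (just a sketch; multinomial_log_vec_lpmf is a
  // name I made up): dot_product and the vectorized lgamma replace the
  // K-term loop, which should shrink the autodiff graph per call.
  real multinomial_log_vec_lpmf(array[] int y, vector log_theta) {
    vector[num_elements(y)] yv = to_vector(y);
    return lgamma(sum(yv) + 1) + dot_product(yv, log_theta)
           - sum(lgamma(yv + 1));
  }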
  // Partial sum over units i for reduce_sum, slicing the rows of y
  real partial_sum(array[,] int y_slice, int start, int end,
                   array[] vector log_theta) {
    real interm_sum = 0;
    for (i in 1:(end - start + 1)) {
      interm_sum += multinomial_log_lpmf(y_slice[i] | log_theta[start + i - 1]);
    }
    return interm_sum;
  }
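  // A ragged-aware partial sum (a sketch, not used below): reduce_sum
  // can still slice over units even with segment(), as long as the
  // segment start positions are data. Here r_start and r_len are
  // hypothetical per-unit start/length arrays that would be built once
  // in transformed data, and the sliced argument is just a dummy array
  // of unit indices (e.g. linspaced_int_array(n, 1, n)).
  real partial_sum_ragged(array[] int unit_slice, int start, int end,
                          vector log_theta_r, array[] int r_start,
                          array[] int r_len, array[] int y_r) {
    real lp = 0;
    for (i in start:end) {
      lp += multinomial_log_lpmf(segment(y_r, r_start[i], r_len[i]) |
                                 segment(log_theta_r, r_start[i], r_len[i]));
    }
    return lp;
  }
  // and in the model block, something like:
  // target += reduce_sum(partial_sum_ragged, unit_idx, grainsize,
  //                      log_theta_r, r_start, r_len, y_r);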
}
data {
  int<lower=1> n; // number of election units i
  real<lower=0> lambda; // exponential hyperprior inverse scale parameter
  int grainsize; // for reduce_sum() / partial_sum()
  int<lower=1> nR; // number of groups in R (indexed by zeta_r), and likewise below
  int<lower=1> nM;
  int<lower=1> nRM;
  int<lower=1> nCRM;
  int<lower=2> C; // number of candidates c for election 1 (the global election)
  array[nR] int<lower=2> R; // ragged segment lengths, one per group, and likewise below
  array[nM] int<lower=2> M;
  array[nRM] int<lower=2> RM;
  array[nCRM] int<lower=2> CRM;
  array[n] int<lower=1> zeta_r; // group index of unit i into R
  array[n] int<lower=1> zeta_m; // ... into M
  array[n] int<lower=1> zeta_rm; // ... into RM
  array[n] int<lower=1> zeta_crm; // ... into CRM
  array[n, C] int<lower=0> y_c; // number of votes for candidate c in unit i
  array[sum(R[zeta_r])] int<lower=0> y_r; // concatenated ragged vote counts
  array[nRM] int<lower=1, upper=sum(RM)> alpha_rm_s; // segment start positions into alpha_rm
  array[nCRM] int<lower=1, upper=(C * sum(CRM))> alpha_crm_s; // segment start positions into alpha_crm
}
parameters {
  vector<lower=0>[sum(RM)] alpha_rm;
  vector<lower=0>[C * sum(CRM)] alpha_crm;
  // ...
}
transformed parameters {
  // There are some functions in this block that constrain
  // varying-length segments of these vectors/arrays (used in the
  // model block) to simplices. So these are all ragged arrays,
  // handled following the guidance of the manual, with an
  // additional simplex constraint.
  vector[sum(RM[zeta_rm])] log_beta_rm;
  array[sum(CRM[zeta_crm])] vector[C] log_beta_crm;
  vector[sum(R[zeta_r])] log_theta_r;
  array[n] vector[C] log_theta_c;
  // ...
}
model {
  alpha_rm ~ exponential(lambda);
  alpha_crm ~ exponential(lambda);
  int pos_m = 1;
  int pos_r = 1;
  int pos_crm = 1;
  for (i in 1:n) {
    int pos_a = alpha_rm_s[zeta_rm[i]];
    for (r in 1:R[zeta_r[i]]) {
      // both segments have length M[zeta_m[i]]; log_dirichlet_lpdf
      // rejects mismatched lengths
      target += log_dirichlet_lpdf(segment(log_beta_rm, pos_m, M[zeta_m[i]]) |
                                   segment(alpha_rm, pos_a, M[zeta_m[i]]));
      pos_m += M[zeta_m[i]];
      pos_a += M[zeta_m[i]];
    }
    // ...
    int pos_aCRM = alpha_crm_s[zeta_crm[i]];
    for (crm in 1:CRM[zeta_crm[i]]) {
      log_beta_crm[crm + pos_crm - 1] ~ log_dirichlet(segment(alpha_crm, pos_aCRM, C));
      pos_aCRM += C;
    }
    target += multinomial_log_lpmf(segment(y_r, pos_r, R[zeta_r[i]]) |
                                   segment(log_theta_r, pos_r, R[zeta_r[i]]));
    // ...
    // target += multinomial_log_lpmf(y_c[i] | log_theta_c[i]);
    pos_r += R[zeta_r[i]];
    pos_crm += CRM[zeta_crm[i]];
  }
  // my attempt at parallelizing the commented-out multinomial_log_lpmf above
  target += reduce_sum(partial_sum, y_c, grainsize, log_theta_c);
}
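One thing I noticed while writing this up: on the log scale the multinomial terms are linear in y, so the per-unit calls over segments of y_r should telescope into a single global dot product, and the lgamma(N_i + 1) and lgamma(y + 1) pieces involve only data, so they are constants that can be computed once (or dropped entirely). If that is right, the whole pos_r loop collapses to something like the sketch below, where y_r_vec and mult_const are names I made up:

transformed data {
  vector[sum(R[zeta_r])] y_r_vec = to_vector(y_r);
  real mult_const = 0; // data-only multinomial constants
  {
    int pos = 1;
    for (i in 1:n) {
      // lgamma(N_i + 1) for each unit's segment total
      mult_const += lgamma(sum(segment(y_r_vec, pos, R[zeta_r[i]])) + 1);
      pos += R[zeta_r[i]];
    }
    mult_const -= sum(lgamma(y_r_vec + 1));
  }
}
model {
  // ...
  // replaces all the per-unit multinomial_log_lpmf calls over y_r
  target += mult_const + dot_product(y_r_vec, log_theta_r);
}

The Dirichlet loops look harder because the alpha segments are reused across units while the beta segments are not, but the dot_product part should collapse the same way with a gather index built once in transformed data (idx is hypothetical here: idx[j] would be the position in alpha_rm paired with log_beta_rm[j]):

  target += dot_product(alpha_rm[idx], log_beta_rm);

That would still leave the log_beta segment endpoints (the - log_theta[N] term in log_dirichlet_lpdf) and the lgamma(sum(alpha)) - sum(lgamma(alpha)) pieces in a per-segment loop, though over far fewer terms. Does this look right, or is there a reason these pieces need to stay inside the unit loop?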