I am developing a model that will be fairly complex in the end, but I’m following the advice to start simple and add complexity one element at a time. Things are going well, but the model is now complex enough that each iteration is starting to take a while.
This seemed like the right time to look into within-chain parallelization (I have access to a computer with many cores and plenty of RAM), but so far I have not had success. I am hoping for advice on how to parallelize this model efficiently so that iteration time stays tractable as the model grows.
Below I’ll include three versions of the model. The first has no within-chain parallelization and takes about 10 minutes (4 chains, 1000 warmup, 1000 post-warmup). Ten minutes isn’t long, but I intend to increase the model’s complexity quite a bit, which will also entail using about 4x as much data.
The second version uses reduce_sum, slicing over the response variable (rt) and passing some large vectors computed in the transformed parameters block as shared arguments. It took more than 2 hours to generate the same number of samples with 4 threads per chain.
The third version packs everything into an N by M array and slices over that for reduce_sum. I didn’t let it finish, but it was on pace to take about 4 hours with 4 threads per chain. I’d been using the automatically determined grainsize; fiddling with a few values did not yield any obvious improvements. Update: with the grainsize set to 135 (1/40th of the data length), it completed in just over an hour.
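For reference, that grainsize is just the data length divided by 40 (5400 / 40 = 135). A minimal sketch, assuming the heuristic keeps holding as the data grows, of deriving it in transformed data rather than hand-tuning (ntrial comes from the data block, as in the models below):
transformed data {
  // hypothetical heuristic: split the work into roughly 40 chunks
  int grainsize = ntrial %/% 40; // 5400 %/% 40 = 135
}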
As an aside, I worked through the reduce_sum example here and observed roughly a 4x speedup with 4 threads per chain, so parallelization seems to be working in general. I suspect I’m doing something catastrophically wrong in my own attempt.
Okay, thanks for taking a look. I’m happy to get any advice on speeding up the base model, but I’m also really curious why my use of reduce_sum causes such a huge slowdown (I’ve sketched one untried guess after the second model below).
The base model
To provide a sense of scale, here are the sizes I’m using: ntrial=5400, nsub=60, nstim=90, nchar=8.
data {
int<lower=1> ntrial;
int<lower=1> nsub;
int<lower=1> nstim;
int<lower=1> nchar;
int<lower=1, upper=nsub> sid[ntrial];
int<lower=1, upper=nstim> stim_id[ntrial];
vector[ntrial] t;
vector[ntrial] rt;
matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
real<lower=0> U; // rt values at or above U are right-censored (the trial timed out)
}
transformed data {
real nu = 10; // degrees of freedom for the student-t priors on the z terms
// identify trials that are 'late' (at or past the 10 s timeout)
int latecount = 0;
int good_idx_tmp[ntrial];
int late_idx_tmp[ntrial];
int idx = 0;
for (i in 1:ntrial) {
if (rt[i] >= U) {
latecount += 1;
late_idx_tmp[latecount] = i;
} else {
idx += 1;
good_idx_tmp[idx] = i;
}
}
int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
int late_idx[latecount] = late_idx_tmp[1:latecount];
// QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
matrix[nsub, nchar] Q_ast;
matrix[nchar, nchar] R_ast;
matrix[nchar, nchar] R_ast_inverse;
// thin and scale the QR decomposition
Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
R_ast_inverse = inverse(R_ast);
}
parameters {
real P0;
real<lower=0> sigma_P0;
vector[nsub] zP0;
vector[nchar] theta_p0;
real Delta;
real<lower=0> sigma_Delta;
vector[nsub] zDelta;
real Alpha;
real<lower=0> sigma_Alpha;
vector[nsub] zAlpha;
real Sigma;
real<lower=0> sigma_sigma;
vector[nsub] zSigma;
real<lower=0> sigma_stim;
vector[nstim] zStim;
}
transformed parameters {
vector[nsub] p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
vector[nsub] delta = Delta + zDelta*sigma_Delta;
vector[nsub] alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha);
vector[nsub] sigma = exp(Sigma + zSigma*sigma_sigma);
vector[nstim] stim = zStim*sigma_stim;
vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
}
model {
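// right-censored likelihood: observed trials get the lognormal lpdf,
// timed-out trials contribute Pr(rt >= U) through the lccdf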
rt[good_idx] ~ lognormal(mu[good_idx], sigma[sid[good_idx]]);
target += lognormal_lccdf(U | mu[late_idx], sigma[sid[late_idx]]);
P0 ~ normal(1,1);
sigma_P0 ~ normal(0, 1);
zP0 ~ student_t(nu, 0, 1);
theta_p0 ~ normal(0, 0.25);
Delta ~ normal(-1, 1);
sigma_Delta ~ normal(0, 1);
zDelta ~ student_t(nu, 0, 1);
Alpha ~ normal(1, 1);
sigma_Alpha ~ normal(0, 0.33);
zAlpha ~ student_t(nu, 0, 1);
Sigma ~ normal(-1, 1);
sigma_sigma ~ normal(0, 0.33);
zSigma ~ student_t(nu, 0, 1);
sigma_stim ~ normal(0, 0.33);
zStim ~ student_t(nu, 0, 1);
}
generated quantities {
vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
vector[nsub] pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
vector[ntrial] rt_hat = to_vector(lognormal_rng(mu, sigma[sid]));
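// censor the posterior predictive draws at the 10 s timeout, matching the data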
for (i in 1:ntrial) {
if (rt_hat[i] > 10)
rt_hat[i] = 10;
}
}
Simple reduce_sum
functions {
real partial_sum(real[] rt_slice,
int start, int end,
vector mu,
vector sigma) {
// mu and sigma are already restricted to the good trials, so start:end
// lines up with rt_slice
return lognormal_lpdf(rt_slice | mu[start:end], sigma[start:end]);
}
}
data {
int<lower=1> ntrial;
int<lower=1> nsub;
int<lower=1> nstim;
int<lower=1> nchar;
int<lower=1, upper=nsub> sid[ntrial];
int<lower=1, upper=nstim> stim_id[ntrial];
vector[ntrial] t;
real rt[ntrial];
matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
real<lower=0> U; // rt values at or above U are right-censored (the trial timed out)
}
transformed data {
real nu = 10; // degrees of freedom for the student-t priors on the z terms
// identify trials that are 'late' (at or past the 10 s timeout)
int latecount = 0;
int good_idx_tmp[ntrial];
int late_idx_tmp[ntrial];
int idx = 0;
for (i in 1:ntrial) {
if (rt[i] >= U) {
latecount += 1;
late_idx_tmp[latecount] = i;
} else {
idx += 1;
good_idx_tmp[idx] = i;
}
}
int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
int late_idx[latecount] = late_idx_tmp[1:latecount];
// QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
matrix[nsub, nchar] Q_ast;
matrix[nchar, nchar] R_ast;
matrix[nchar, nchar] R_ast_inverse;
// thin and scale the QR decomposition
Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
R_ast_inverse = inverse(R_ast);
}
parameters {
real P0;
real<lower=0> sigma_P0;
vector[nsub] zP0;
vector[nchar] theta_p0;
real Delta;
real<lower=0> sigma_Delta;
vector[nsub] zDelta;
real Alpha;
real<lower=0> sigma_Alpha;
vector[nsub] zAlpha;
real Sigma;
real<lower=0> sigma_sigma;
vector[nsub] zSigma;
real<lower=0> sigma_stim;
vector[nstim] zStim;
}
transformed parameters {
vector[nsub] p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
vector[nsub] delta = Delta + zDelta*sigma_Delta;
vector[nsub] alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha);
vector[nsub] sigma = exp(Sigma + zSigma*sigma_sigma);
vector[nstim] stim = zStim*sigma_stim;
vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
}
model {
// parallelize likelihood with reduce_sum
int grainsize = 1; // a grainsize of 1 lets the scheduler pick chunk sizes automatically
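// caution: mu[good_idx] and sigma[sid[good_idx]] below are ntrial-length
// parameter vectors passed as shared arguments; if I read the docs right,
// these are copied for every chunk that reduce_sum schedules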
target += reduce_sum(partial_sum, rt[good_idx],
grainsize,
mu[good_idx], sigma[sid[good_idx]]);
//expect few of these, so not worth parallelizing
target += lognormal_lccdf(U | mu[late_idx], sigma[sid[late_idx]]);
P0 ~ normal(1,1);
sigma_P0 ~ normal(0, 1);
zP0 ~ student_t(nu, 0, 1);
theta_p0 ~ normal(0, 0.25);
Delta ~ normal(-1, 1);
sigma_Delta ~ normal(0, 1);
zDelta ~ student_t(nu, 0, 1);
Alpha ~ normal(1, 1);
sigma_Alpha ~ normal(0, 0.33);
zAlpha ~ student_t(nu, 0, 1);
Sigma ~ normal(-1, 1);
sigma_sigma ~ normal(0, 0.33);
zSigma ~ student_t(nu, 0, 1);
sigma_stim ~ normal(0, 0.33);
zStim ~ student_t(nu, 0, 1);
}
generated quantities {
vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
vector[nsub] pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
vector[ntrial] rt_hat = to_vector(lognormal_rng(mu, sigma[sid]));
for (i in 1:ntrial) {
if (rt_hat[i] > 10)
rt_hat[i] = 10;
}
}
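For what it’s worth, one variant I’ve sketched but not yet benchmarked slices over good_idx itself and computes mu inside the partial sum, so the only parameter-dependent shared arguments are the small per-subject and per-stimulus vectors (lengths 60 and 90) rather than the ntrial-length mu. My understanding of the reduce_sum docs is that data arguments are passed by reference while parameter arguments are copied for each chunk, which would explain the slowdown above. The function name is mine, and rt stays a vector here (as in the base model) since it is no longer the sliced argument:
functions {
  real partial_sum_idx(int[] idx_slice,
                       int start, int end,
                       vector rt, vector t,
                       int[] sid, int[] stim_id,
                       vector stim, vector p0, vector delta,
                       vector alpha, vector sigma) {
    // compute mu only for the trials in this slice
    vector[num_elements(idx_slice)] mu
      = stim[stim_id[idx_slice]] + p0[sid[idx_slice]]
        + delta[sid[idx_slice]] .* (1 - exp(-alpha[sid[idx_slice]] .* t[idx_slice]));
    return lognormal_lpdf(rt[idx_slice] | mu, sigma[sid[idx_slice]]);
  }
}
and in the model block:
target += reduce_sum(partial_sum_idx, good_idx, grainsize,
                     rt, t, sid, stim_id,
                     stim, p0, delta, alpha, sigma);
If that reading of the docs is right, I’d expect this to behave much more like the toy example mentioned above.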
Packed reduce_sum
functions {
real[] get_mu(real[,] packed) {
// packed columns: 1 = stim, 2 = p0, 3 = delta, 4 = alpha, 5 = t
int nrow = size(packed);
array[nrow] real mu;
for (i in 1:nrow) {
mu[i] = packed[i,1] + packed[i,2] + packed[i,3]*(1 - exp(-packed[i,4]*packed[i,5]));
}
return mu;
}
real partial_sum(real[,] packed,
int start, int end) {
// packed columns: 1 = rt, 2 = stim, 3 = p0, 4 = delta, 5 = alpha, 6 = t, 7 = sigma
// equivalent to: mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t))
return lognormal_lpdf(packed[:, 1] | get_mu(packed[:, 2:6]), packed[:, 7]);
}
}
data {
int<lower=1> grainsize;
int<lower=1> ntrial;
int<lower=1> nsub;
int<lower=1> nstim;
int<lower=1> nchar;
int<lower=1, upper=nsub> sid[ntrial];
int<lower=1, upper=nstim> stim_id[ntrial];
array[ntrial] real t;
array[ntrial] real rt;
matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
real<lower=0> U; // rt values at or above U are right-censored (the trial timed out)
}
transformed data {
real nu = 10; // degrees of freedom for the student-t priors on the z terms
// identify trials that are 'late' (at or past the 10 s timeout)
int latecount = 0;
int good_idx_tmp[ntrial];
int late_idx_tmp[ntrial];
int idx = 0;
for (i in 1:ntrial) {
if (rt[i] >= U) {
latecount += 1;
late_idx_tmp[latecount] = i;
} else {
idx += 1;
good_idx_tmp[idx] = i;
}
}
int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
int late_idx[latecount] = late_idx_tmp[1:latecount];
// QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
matrix[nsub, nchar] Q_ast;
matrix[nchar, nchar] R_ast;
matrix[nchar, nchar] R_ast_inverse;
// thin and scale the QR decomposition
Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
R_ast_inverse = inverse(R_ast);
}
parameters {
real P0;
real<lower=0> sigma_P0;
vector[nsub] zP0;
vector[nchar] theta_p0;
real Delta;
real<lower=0> sigma_Delta;
vector[nsub] zDelta;
real Alpha;
real<lower=0> sigma_Alpha;
vector[nsub] zAlpha;
real Sigma;
real<lower=0> sigma_sigma;
vector[nsub] zSigma;
real<lower=0> sigma_stim;
vector[nstim] zStim;
}
transformed parameters {
array[nsub] real p0 = to_array_1d(P0 + zP0*sigma_P0 + Q_ast*theta_p0);
array[nsub] real delta = to_array_1d(Delta + zDelta*sigma_Delta);
array[nsub] real alpha = to_array_1d(1 + exp(Alpha + zAlpha*sigma_Alpha));
array[nsub] real sigma = to_array_1d(exp(Sigma + zSigma*sigma_sigma));
array[nstim] real stim = to_array_1d(zStim*sigma_stim);
//vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
}
model {
// I want to pack a multidimensional array for better/faster slicing.
// I need: rt, stim, p0, delta, alpha, t, and sigma.
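// caution (my understanding): because packed mixes parameters with data,
// the whole array is autodiff-typed, so rt and t get promoted to parameters
// and the lpdf differentiates with respect to them as well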
array[ntrial, 7] real packed;
packed[:, 1] = rt;
packed[:, 2] = stim[stim_id];
packed[:, 3] = p0[sid];
packed[:, 4] = delta[sid];
packed[:, 5] = alpha[sid];
packed[:, 6] = t;
packed[:, 7] = sigma[sid];
// parallelize likelihood with reduce_sum
target += reduce_sum(partial_sum, packed[good_idx], grainsize);
//expect few of these, so not worth parallelizing
target += lognormal_lccdf(U | get_mu(packed[late_idx, 2:6]), packed[late_idx, 7]);
P0 ~ normal(1,1);
sigma_P0 ~ normal(0, 1);
zP0 ~ student_t(nu, 0, 1);
theta_p0 ~ normal(0, 0.25);
Delta ~ normal(-1, 1);
sigma_Delta ~ normal(0, 1);
zDelta ~ student_t(nu, 0, 1);
Alpha ~ normal(1, 1);
sigma_Alpha ~ normal(0, 0.33);
zAlpha ~ student_t(nu, 0, 1);
Sigma ~ normal(-1, 1);
sigma_sigma ~ normal(0, 0.33);
zSigma ~ student_t(nu, 0, 1);
sigma_stim ~ normal(0, 0.33);
zStim ~ student_t(nu, 0, 1);
}
generated quantities {
vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
vector[nsub] pEnd;
vector[ntrial] rt_hat;
for (i in 1:nsub) {
pEnd[i] = exp(p0[i] + delta[i]*(1 - exp(-alpha[i])));
}
for (i in 1:ntrial) {
rt_hat[i] = lognormal_rng(stim[stim_id[i]] + p0[sid[i]] + delta[sid[i]]*(1 - exp(-alpha[sid[i]]*t[i])), sigma[sid[i]]); //hope this is fast
if (rt_hat[i] > 10)
rt_hat[i] = 10;
}
}
And in case anyone wants it:
A data generator
This doesn’t generate sid, stim_id, t, or Z, so it might be of limited use on its own. If anyone wants some Python code to generate those, I’m happy to provide it.
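That said, here is a rough Stan sketch of how they could be generated. The design is a guess: the sizes suggest a fully crossed layout, since ntrial = nsub × nstim = 60 × 90 = 5400, so each subject sees each stimulus exactly once; the scale of t is also a guess (pEnd evaluates the curve at t = 1, so I drew t from (0, 1)).
data {
  int<lower=1> nsub;
  int<lower=1> nstim;
  int<lower=1> nchar;
}
generated quantities {
  array[nsub*nstim] int sid;
  array[nsub*nstim] int stim_id;
  vector[nsub*nstim] t;
  matrix[nsub, nchar] Z;
  {
    int k = 0;
    for (s in 1:nsub) {
      for (j in 1:nstim) {
        k += 1;
        sid[k] = s;     // fully crossed: subject s ...
        stim_id[k] = j; // ... sees stimulus j exactly once
        t[k] = uniform_rng(0, 1); // hypothetical time scale
      }
    }
  }
  for (s in 1:nsub) {
    for (c in 1:nchar) {
      Z[s, c] = normal_rng(0, 1); // mean zero, scale 1, as the models assume
    }
  }
}
In a finite sample you may want to standardize the columns of Z afterwards so they really are mean zero, scale 1.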
data {
int<lower=1> total_trials;
int<lower=1> nsub;
int<lower=1> nstim;
int<lower=1> nchar; //number of subject characteristics included
int sid[total_trials];
int stim_id[total_trials];
vector[total_trials] t;
matrix[nsub, nchar] Z; //assumed to have mean zero and scale 1.
}
transformed data {
vector[nsub] zeros = rep_vector(0, nsub);
real nu = 10; // degrees of freedom for the student-t priors
matrix[nsub, nchar] Q_ast;
matrix[nchar, nchar] R_ast;
matrix[nchar, nchar] R_ast_inverse;
// thin and scale the QR decomposition
Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
R_ast_inverse = inverse(R_ast);
}
parameters {}
model {}
generated quantities{
real P0 = normal_rng(1, 1);
real sigma_P0 = fabs(normal_rng(0, 1));
vector[nchar] theta_p0 = to_vector(normal_rng(rep_vector(0, nchar), 0.25));
vector[nsub] p0 = P0 + to_vector(student_t_rng(nu, zeros, 1))*sigma_P0 + Q_ast*theta_p0;
vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
real Delta = normal_rng(-1, 1);
real sigma_Delta = fabs(normal_rng(0, 1));
vector[nsub] delta = Delta + to_vector(student_t_rng(nu, zeros, 1))*sigma_Delta;
real Alpha = normal_rng(1, 1);
real sigma_Alpha = fabs(normal_rng(0, 0.33));
vector[nsub] alpha = 1 + exp(Alpha + to_vector(student_t_rng(nu, zeros, 1))*sigma_Alpha);
real Sigma = normal_rng(-1, 1);
real sigma_sigma = fabs(normal_rng(0, 0.33));
vector[nsub] sigma = exp(Sigma + to_vector(student_t_rng(nu, zeros, 1))*sigma_sigma);
real sigma_stim = fabs(normal_rng(0, 0.33));
vector[nstim] stim = to_vector(student_t_rng(nu, rep_vector(0, nstim), 1))*sigma_stim;
vector[total_trials] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
vector[total_trials] rt = to_vector(lognormal_rng(mu, sigma[sid]));
for (i in 1:total_trials) {
if (rt[i] > 10)
rt[i] = 10;
}
}