Trying within-chain parallelization with reduce_sum increases runtime a lot

I am in the process of developing a model that will be fairly complex in the end, but I’m following the advice to start simple and add complexity one element at a time. Things are going well, but the model is now complex enough that each fit is starting to take a while.

This seemed like the right time to look into within-chain parallelization (I have access to a computer with many cores and lots of RAM), but so far I have not had success. I am hoping to get some advice on how to parallelize this model efficiently so that fitting time stays tractable as I increase the model’s complexity.

Below I include three versions of the model. The first has no within-chain parallelization and takes about 10 minutes (4 chains, each with 1000 warmup and 1000 post-warmup iterations). Ten minutes isn’t long, but I intend to increase the complexity of the model quite a bit, which will also mean using about 4x as much data.

The second version uses reduce_sum, slicing over the response variable (rt) and passing some large vectors computed in the transformed parameters block as shared arguments. This took more than 2 hours to generate the same number of samples with 4 threads per chain.

The third version packs everything into an ntrial-by-7 array (one row per trial) and slices over its rows in reduce_sum. I didn’t let it finish, but it was on pace to take about 4 hours with 4 threads per chain. I’ve been using the automatically determined grainsize, and trying a few other values did not yield any obvious improvement. Update: with grainsize set to 135 (1/40th of the data length), it completed in just over an hour.

As an aside, I worked through the reduce_sum example here and observed roughly a 4x speedup with 4 threads per chain, so I think parallelization is working in general. I suspect I’m doing something catastrophically wrong in my attempt to parallelize this model.

Okay, thanks for taking a look. I’m happy to get any advice on speeding up the base model, but I am also really curious why the way I’m using reduce_sum is causing such a huge slowdown.


The base model
To give a sense of scale, here are the sizes I am using: ntrial=5400, nsub=60, nstim=90, nchar=8.

data {
    int<lower=1> ntrial;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar;
    int<lower=1, upper=nsub> sid[ntrial];
    int<lower=1, upper=nstim> stim_id[ntrial];
    vector[ntrial] t;
    vector[ntrial] rt;
    matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
    
    real<lower=0> U; // rt values greater than or equal to U are considered MISSING
}
transformed data {
    // identify trials that are 'late' (after the 10 s timeout)
    real nu = 10;
    int latecount = 0;
    int good_idx_tmp[ntrial];
    int late_idx_tmp[ntrial];
    int idx = 0;
    for (i in 1:ntrial) {
        if (rt[i] >= U) {
            latecount += 1;
            late_idx_tmp[latecount] = i;
        } else {
            idx += 1;
            good_idx_tmp[idx] = i;
        }
    }
    int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
    int late_idx[latecount] = late_idx_tmp[1:latecount];
    
    // QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
}
parameters {
    real P0;
    real<lower=0> sigma_P0;
    vector[nsub] zP0;
    vector[nchar] theta_p0;
    
    real Delta;
    real<lower=0> sigma_Delta;
    vector[nsub] zDelta;
    
    real Alpha;
    real<lower=0> sigma_Alpha;
    vector[nsub] zAlpha;

    real Sigma;
    real<lower=0> sigma_sigma;
    vector[nsub] zSigma;
    
    real<lower=0> sigma_stim;
    vector[nstim] zStim;
}
transformed parameters {
    vector[nsub] p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
    vector[nsub] delta = Delta + zDelta*sigma_Delta;
    vector[nsub] alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha);
    vector[nsub] sigma = exp(Sigma + zSigma*sigma_sigma);
    
    vector[nstim] stim = zStim*sigma_stim;
    
    vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    
}
model {
    rt[good_idx] ~ lognormal(mu[good_idx], sigma[sid[good_idx]]);
    target += lognormal_lccdf(U | mu[late_idx], sigma[sid[late_idx]]); 
    
    P0 ~ normal(1,1);
    sigma_P0 ~ normal(0, 1);
    zP0 ~ student_t(nu, 0, 1);
    theta_p0 ~ normal(0, 0.25);
    
    Delta ~ normal(-1, 1);
    sigma_Delta ~ normal(0, 1);
    zDelta ~ student_t(nu, 0, 1);
    
    Alpha ~ normal(1, 1);
    sigma_Alpha ~ normal(0, 0.33);
    zAlpha ~ student_t(nu, 0, 1);
    
    Sigma ~ normal(-1, 1);
    sigma_sigma ~ normal(0, 0.33);
    zSigma ~ student_t(nu, 0, 1);
    
    sigma_stim ~ normal(0, 0.33);
    zStim ~ student_t(nu, 0, 1);
    
} 
generated quantities {
    vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
    vector[nsub] pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
    vector[ntrial] rt_hat = to_vector(lognormal_rng(mu, sigma[sid]));
    for (i in 1:ntrial) {
        if (rt_hat[i] > 10)
            rt_hat[i] = 10;
    }
}

Simple reduce_sum

functions {
    real partial_sum(real[] rt_slice,
                     int start, int end,
                     vector mu,
                     vector sigma) {
        return lognormal_lpdf(rt_slice | mu[start:end], sigma[start:end]);
    }
}
data {
    int<lower=1> ntrial;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar;
    int<lower=1, upper=nsub> sid[ntrial];
    int<lower=1, upper=nstim> stim_id[ntrial];
    vector[ntrial] t;
    real rt[ntrial];
    matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
    
    real<lower=0> U; // rt values greater than or equal to U are considered MISSING
}
transformed data {
    // identify trials that are 'late' (after the 10 s timeout)
    real nu = 10;
    int latecount = 0;
    int good_idx_tmp[ntrial];
    int late_idx_tmp[ntrial];
    int idx = 0;
    for (i in 1:ntrial) {
        if (rt[i] >= U) {
            latecount += 1;
            late_idx_tmp[latecount] = i;
        } else {
            idx += 1;
            good_idx_tmp[idx] = i;
        }
    }
    int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
    int late_idx[latecount] = late_idx_tmp[1:latecount];
    
    // QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
}
parameters {
    real P0;
    real<lower=0> sigma_P0;
    vector[nsub] zP0;
    vector[nchar] theta_p0;
    
    real Delta;
    real<lower=0> sigma_Delta;
    vector[nsub] zDelta;
    
    real Alpha;
    real<lower=0> sigma_Alpha;
    vector[nsub] zAlpha;

    real Sigma;
    real<lower=0> sigma_sigma;
    vector[nsub] zSigma;
    
    real<lower=0> sigma_stim;
    vector[nstim] zStim;
}
transformed parameters {
    vector[nsub] p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
    vector[nsub] delta = Delta + zDelta*sigma_Delta;
    vector[nsub] alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha);
    vector[nsub] sigma = exp(Sigma + zSigma*sigma_sigma);
    
    vector[nstim] stim = zStim*sigma_stim;
    
    vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    
}
model {
    // parallelize likelihood with reduce_sum
    int grainsize = 1; // automatically determine grainsize
    target += reduce_sum(partial_sum, rt[good_idx],
                         grainsize,
                         mu[good_idx], sigma[sid[good_idx]]);

    //expect few of these, so not worth parallelizing
    target += lognormal_lccdf(U | mu[late_idx], sigma[sid[late_idx]]); 
    
    P0 ~ normal(1,1);
    sigma_P0 ~ normal(0, 1);
    zP0 ~ student_t(nu, 0, 1);
    theta_p0 ~ normal(0, 0.25);
    
    Delta ~ normal(-1, 1);
    sigma_Delta ~ normal(0, 1);
    zDelta ~ student_t(nu, 0, 1);
    
    Alpha ~ normal(1, 1);
    sigma_Alpha ~ normal(0, 0.33);
    zAlpha ~ student_t(nu, 0, 1);
    
    Sigma ~ normal(-1, 1);
    sigma_sigma ~ normal(0, 0.33);
    zSigma ~ student_t(nu, 0, 1);
    
    sigma_stim ~ normal(0, 0.33);
    zStim ~ student_t(nu, 0, 1);
    
} 
generated quantities {
    vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
    vector[nsub] pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
    vector[ntrial] rt_hat = to_vector(lognormal_rng(mu, sigma[sid]));
    for (i in 1:ntrial) {
        if (rt_hat[i] > 10)
            rt_hat[i] = 10;
    }
}
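A likely reason this version is slow: `mu[good_idx]` and `sigma[sid[good_idx]]` are ntrial-length vectors of parameters passed as shared arguments, and as far as I understand reduce_sum copies parameter-valued shared arguments into each worker’s autodiff stack, while `mu` itself is still computed serially in transformed parameters. The sketch below instead builds `mu` inside the partial sum from the small subject- and stimulus-level vectors, so the only parameter-valued shared arguments have length nsub or nstim. It is an untested starting point: `partial_sum_trials` and the `*_good`/`*_late` arrays are hypothetical names, and for brevity it drops the QR covariate term (`theta_p0`) and the generated quantities.

```stan
functions {
  // Log likelihood of good trials start..end. mu is built here from small
  // subject- and stimulus-level vectors, so no ntrial-length parameter
  // vector is passed as a shared argument.
  real partial_sum_trials(array[] real rt_slice, int start, int end,
                          array[] int sid, array[] int stim_id, vector t,
                          vector stim, vector p0, vector delta,
                          vector alpha, vector sigma) {
    array[end - start + 1] int s = sid[start:end];
    vector[end - start + 1] mu = stim[stim_id[start:end]] + p0[s]
        + delta[s] .* (1 - exp(-alpha[s] .* t[start:end]));
    return lognormal_lpdf(rt_slice | mu, sigma[s]);
  }
}
data {
  int<lower=1> grainsize;
  int<lower=1> ntrial;
  int<lower=1> nsub;
  int<lower=1> nstim;
  array[ntrial] int<lower=1, upper=nsub> sid;
  array[ntrial] int<lower=1, upper=nstim> stim_id;
  vector[ntrial] t;
  array[ntrial] real rt;
  real<lower=0> U; // rt values >= U are treated as censored (late)
}
transformed data {
  real nu = 10;
  int ngood = 0;
  for (i in 1:ntrial) {
    if (rt[i] < U) ngood += 1;
  }
  array[ngood] int good_idx;
  array[ntrial - ngood] int late_idx;
  {
    int g = 0;
    int l = 0;
    for (i in 1:ntrial) {
      if (rt[i] < U) {
        g += 1;
        good_idx[g] = i;
      } else {
        l += 1;
        late_idx[l] = i;
      }
    }
  }
  // Good and late trial data, split once here.
  array[ngood] real rt_good = rt[good_idx];
  array[ngood] int sid_good = sid[good_idx];
  array[ngood] int stim_good = stim_id[good_idx];
  vector[ngood] t_good = t[good_idx];
  array[ntrial - ngood] int sid_late = sid[late_idx];
  array[ntrial - ngood] int stim_late = stim_id[late_idx];
  vector[ntrial - ngood] t_late = t[late_idx];
}
parameters {
  real P0; real<lower=0> sigma_P0; vector[nsub] zP0;
  real Delta; real<lower=0> sigma_Delta; vector[nsub] zDelta;
  real Alpha; real<lower=0> sigma_Alpha; vector[nsub] zAlpha;
  real Sigma; real<lower=0> sigma_sigma; vector[nsub] zSigma;
  real<lower=0> sigma_stim; vector[nstim] zStim;
}
transformed parameters {
  // Only nsub- and nstim-length quantities; no ntrial-length mu.
  vector[nsub] p0 = P0 + zP0 * sigma_P0;
  vector[nsub] delta = Delta + zDelta * sigma_Delta;
  vector[nsub] alpha = 1 + exp(Alpha + zAlpha * sigma_Alpha);
  vector[nsub] sigma = exp(Sigma + zSigma * sigma_sigma);
  vector[nstim] stim = zStim * sigma_stim;
}
model {
  target += reduce_sum(partial_sum_trials, rt_good, grainsize,
                       sid_good, stim_good, t_good,
                       stim, p0, delta, alpha, sigma);
  // Few late trials, so the censored term stays serial.
  target += lognormal_lccdf(U | stim[stim_late] + p0[sid_late]
      + delta[sid_late] .* (1 - exp(-alpha[sid_late] .* t_late)), sigma[sid_late]);

  P0 ~ normal(1, 1); sigma_P0 ~ normal(0, 1); zP0 ~ student_t(nu, 0, 1);
  Delta ~ normal(-1, 1); sigma_Delta ~ normal(0, 1); zDelta ~ student_t(nu, 0, 1);
  Alpha ~ normal(1, 1); sigma_Alpha ~ normal(0, 0.33); zAlpha ~ student_t(nu, 0, 1);
  Sigma ~ normal(-1, 1); sigma_sigma ~ normal(0, 0.33); zSigma ~ student_t(nu, 0, 1);
  sigma_stim ~ normal(0, 0.33); zStim ~ student_t(nu, 0, 1);
}
```

The sliced argument (`rt_good`) and the per-trial index arrays are data, which reduce_sum does not need to copy as autodiff variables; grainsize is passed as data so it can be tuned without recompiling.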

Packed reduce_sum

functions {
    real[] get_mu(real[,] packed) {
        // packed: stim, p0, delta, alpha, and t.
        int nrow = size(packed);
        array[nrow] real mu;
        for (i in 1:nrow) {
            mu[i] = packed[i,1] + packed[i,2] + packed[i,3]*(1 - exp(-packed[i,4]*packed[i,5]));
        }
        return mu;
    }
    real partial_sum(real[,] packed,
                     int start, int end) {
        // packed: rt, stim, p0, delta, alpha, t, and sigma.
        // vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
        return lognormal_lpdf(packed[,1] | get_mu(packed[,2:6]), packed[,7]);
    }
}
data {
    int<lower=1> grainsize;
    int<lower=1> ntrial;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar;
    int<lower=1, upper=nsub> sid[ntrial];
    int<lower=1, upper=nstim> stim_id[ntrial];
    array[ntrial] real t;
    array[ntrial] real rt;
    matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
    
    real<lower=0> U; // rt values greater than or equal to U are considered MISSING
}
transformed data {
    // identify trials that are 'late' (after the 10 s timeout)
    real nu = 10;
    int latecount = 0;
    int good_idx_tmp[ntrial];
    int late_idx_tmp[ntrial];
    int idx = 0;
    for (i in 1:ntrial) {
        if (rt[i] >= U) {
            latecount += 1;
            late_idx_tmp[latecount] = i;
        } else {
            idx += 1;
            good_idx_tmp[idx] = i;
        }
    }
    int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
    int late_idx[latecount] = late_idx_tmp[1:latecount];
    
    // QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
    
}
parameters {
    real P0;
    real<lower=0> sigma_P0;
    vector[nsub] zP0;
    vector[nchar] theta_p0;
    
    real Delta;
    real<lower=0> sigma_Delta;
    vector[nsub] zDelta;
    
    real Alpha;
    real<lower=0> sigma_Alpha;
    vector[nsub] zAlpha;

    real Sigma;
    real<lower=0> sigma_sigma;
    vector[nsub] zSigma;
    
    real<lower=0> sigma_stim;
    vector[nstim] zStim;
}
transformed parameters {
    array[nsub] real p0 = to_array_1d(P0 + zP0*sigma_P0 + Q_ast*theta_p0);
    array[nsub] real delta = to_array_1d(Delta + zDelta*sigma_Delta);
    array[nsub] real alpha = to_array_1d(1 + exp(Alpha + zAlpha*sigma_Alpha));
    array[nsub] real sigma = to_array_1d(exp(Sigma + zSigma*sigma_sigma));
    
    array[nstim] real stim = to_array_1d(zStim*sigma_stim);
    
    //vector[ntrial] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    
}
model {
    // I want to pack a multidimensional array for better/faster slicing.
    // I need: rt, stim, p0, delta, alpha, t, and sigma.
    array[ntrial, 7] real packed;
    packed[,1] = rt;
    packed[,2] = stim[stim_id];
    packed[,3] = p0[sid];
    packed[,4] = delta[sid];
    packed[,5] = alpha[sid];
    packed[,6] = t;
    packed[,7] = sigma[sid];
    
    // parallelize likelihood with reduce_sum
    target += reduce_sum(partial_sum, packed[good_idx, ],
                         grainsize);

    //expect few of these, so not worth parallelizing
    target += lognormal_lccdf(U | get_mu(packed[late_idx, 2:6]), packed[late_idx, 7]); 
    
    P0 ~ normal(1,1);
    sigma_P0 ~ normal(0, 1);
    zP0 ~ student_t(nu, 0, 1);
    theta_p0 ~ normal(0, 0.25);
    
    Delta ~ normal(-1, 1);
    sigma_Delta ~ normal(0, 1);
    zDelta ~ student_t(nu, 0, 1);
    
    Alpha ~ normal(1, 1);
    sigma_Alpha ~ normal(0, 0.33);
    zAlpha ~ student_t(nu, 0, 1);
    
    Sigma ~ normal(-1, 1);
    sigma_sigma ~ normal(0, 0.33);
    zSigma ~ student_t(nu, 0, 1);
    
    sigma_stim ~ normal(0, 0.33);
    zStim ~ student_t(nu, 0, 1);
    
} 
generated quantities {
    vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
    vector[nsub] pEnd; 
    vector[ntrial] rt_hat;
    for (i in 1:nsub) {
        pEnd[i] = exp(p0[i] + delta[i]*(1 - exp(-alpha[i])));
    }
    for (i in 1:ntrial) {
        rt_hat[i] = lognormal_rng(stim[stim_id[i]] + p0[sid[i]] + delta[sid[i]]*(1 - exp(-alpha[sid[i]]*t[i])), sigma[sid[i]]); //hope this is fast
        if (rt_hat[i] > 10)
            rt_hat[i] = 10;
    }
}

And in case anyone wants it:
A data generator
This doesn’t generate sid, stim_id, t, or Z, so it might be of limited use. If anyone wants some Python code to generate those, I’m happy to provide it.

data {
    int<lower=1> total_trials;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar; //number of subject characteristics included
    
    int sid[total_trials];
    int stim_id[total_trials];
    vector[total_trials] t;
    matrix[nsub, nchar] Z; //assumed to have mean zero and scale 1.
}
transformed data {
    vector[nsub] zeros = rep_vector(0, nsub);
    real nu = 10;
    
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
}
parameters {}
model {}
generated quantities{
    real P0 = normal_rng(1, 1);
    real sigma_P0 = fabs(normal_rng(0, 1));
    vector[nchar] theta_p0 = to_vector(normal_rng(rep_vector(0, nchar), 0.25));
    vector[nsub] p0 = P0 + to_vector(student_t_rng(nu, zeros, 1))*sigma_P0 + Q_ast*theta_p0;
    vector[nchar] beta_p0 = R_ast_inverse*theta_p0;
    
    real Delta = normal_rng(-1, 1);
    real sigma_Delta = fabs(normal_rng(0, 1));
    vector[nsub] delta = Delta + to_vector(student_t_rng(nu, zeros, 1))*sigma_Delta;
    
    real Alpha = normal_rng(1, 1);
    real sigma_Alpha = fabs(normal_rng(0, 0.33));
    vector[nsub] alpha = 1 + exp(Alpha + to_vector(student_t_rng(nu, zeros, 1))*sigma_Alpha);
    
    real Sigma = normal_rng(-1, 1);
    real sigma_sigma = fabs(normal_rng(0, 0.33));
    vector[nsub] sigma = exp(Sigma + to_vector(student_t_rng(nu, zeros, 1))*sigma_sigma);
    
    real sigma_stim = fabs(normal_rng(0, 0.33));
    vector[nstim] stim = to_vector(student_t_rng(nu, rep_vector(0, nstim), 1))*sigma_stim;
    
    vector[total_trials] mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    vector[total_trials] rt = to_vector(lognormal_rng(mu, sigma[sid]));
    for (i in 1:total_trials) {
        if (rt[i] > 10)
            rt[i] = 10;
    }
}

An addendum: is there any reason to think I might have more luck using a map_rect-based approach?

No, map_rect won’t do better than reduce_sum.

I’d recommend writing your model in brms if that’s possible. Then you can turn on threading with a single option.

Have a read of the threading vignette in the brms package to get more intuition about this.

But before you continue with this, I would strongly encourage you to run profiling. You need to find out the runtime cost of the transformed parameters block versus the likelihood. Profiling has been available for a few releases now.

Thanks for the suggestion! I looked at the threading vignette for brms, and it looks like it automatically parallelizes the log likelihood calculation. Since the likelihood for my model is lognormal, I assume it is not very expensive (but I will confirm that with profiling). In that case, I would not expect the brms approach to give me much of a speedup. That’s what the ‘simple reduce_sum’ model I posted is meant to do, and it resulted in a substantial slowdown.

The threading vignette also notes the need to play around with grainsize, so that’s probably something I need to do.

I can also try to implement some of the changes suggested in the thread “How to most efficiently reduce_sum in a hierarchical logistic model”.

Thanks again!

So I looked at the profiling output, and in the unparallelized model about half the runtime was spent in the first likelihood statement and the other half in the big transformed parameter (mu).

I followed the advice in this post: Reduce_sum with hierarchical vector auto-regression - #9 by wds15

I ended up slicing by participant. This meant reshaping my data into a (ragged-ish) 2D array, padded to the largest per-subject trial count. Once I did that, even non-threaded performance improved a bit, which surprised me because I replaced a vectorized operation with a loop, and I thought that was usually bad.
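An alternative to the padded 2D array, sketched here under stated assumptions: sort the good trials by subject once in transformed data, keep each subject’s first position (`pos`) and trial count (`n`), and let reduce_sum slice over a dummy array of subject ids. This avoids padded entries when subjects have very different trial counts. It is a stripped-down, untested sketch: all names are hypothetical, the hierarchical structure is replaced by simple placeholder priors, and the censored late-trial term is omitted.

```stan
functions {
  // Log likelihood for subjects start..end. Good trials are assumed sorted by
  // subject: subject i owns positions pos[i] .. pos[i] + n[i] - 1.
  real partial_sum_subj(array[] int subj_slice, int start, int end,
                        array[] real rt, array[] int stim_id, vector t,
                        array[] int pos, array[] int n,
                        vector stim, vector p0, vector delta,
                        vector alpha, vector sigma) {
    real lp = 0;
    for (i in start:end) {
      if (n[i] > 0) {  // skip subjects with no good trials
        int a = pos[i];
        int b = pos[i] + n[i] - 1;
        vector[n[i]] mu = stim[stim_id[a:b]] + p0[i]
                          + delta[i] * (1 - exp(-alpha[i] * t[a:b]));
        lp += lognormal_lpdf(rt[a:b] | mu, sigma[i]);
      }
    }
    return lp;
  }
}
data {
  int<lower=1> grainsize;
  int<lower=1> ntrial;
  int<lower=1> nsub;
  int<lower=1> nstim;
  array[ntrial] int<lower=1, upper=nsub> sid;
  array[ntrial] int<lower=1, upper=nstim> stim_id;
  vector[ntrial] t;
  array[ntrial] real rt;
  real<lower=0> U;
}
transformed data {
  int ngood = 0;
  for (i in 1:ntrial) {
    if (rt[i] < U) ngood += 1;
  }
  array[ngood] int good_idx;
  {
    int g = 0;
    for (i in 1:ntrial) {
      if (rt[i] < U) {
        g += 1;
        good_idx[g] = i;
      }
    }
  }
  // Reorder good trials so that each subject's trials are contiguous.
  array[ngood] int ord = good_idx[sort_indices_asc(sid[good_idx])];
  array[ngood] real rt_s = rt[ord];
  array[ngood] int stim_s = stim_id[ord];
  vector[ngood] t_s = t[ord];
  // n[i]: number of good trials for subject i; pos[i]: its first position.
  array[nsub] int n = rep_array(0, nsub);
  for (g in 1:ngood) {
    n[sid[ord[g]]] += 1;
  }
  array[nsub] int pos;
  array[nsub] int subj;  // dummy array that reduce_sum slices over
  pos[1] = 1;
  subj[1] = 1;
  for (i in 2:nsub) {
    pos[i] = pos[i - 1] + n[i - 1];
    subj[i] = i;
  }
}
parameters {
  vector[nsub] p0;
  vector[nsub] delta;
  vector<lower=1>[nsub] alpha;
  vector<lower=0>[nsub] sigma;
  vector[nstim] stim;
}
model {
  target += reduce_sum(partial_sum_subj, subj, grainsize,
                       rt_s, stim_s, t_s, pos, n,
                       stim, p0, delta, alpha, sigma);
  // Placeholder priors for this sketch only.
  p0 ~ normal(1, 1);
  delta ~ normal(-1, 1);
  alpha ~ normal(3, 2);
  sigma ~ normal(0, 1);
  stim ~ normal(0, 0.33);
}
```

Within a subject, trial order does not matter, so an unstable sort is fine; what matters is that `pos` and `n` are computed from the same sorted order.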

Once I had the big transform and the likelihood sliced up, moving to multi-threading gave me the big speedups I expected. Four threads per chain gave me approximately a 4x speedup over the original model and about 3.5x over the unthreaded version of the new model. Adding more threads (up to all the physical cores on the machine) gave marginal additional speedups (10 s here and there), and tuning grainsize also gave some minor speedups, but unless runtimes get a lot longer, I doubt I’ll see meaningful gains from further fiddling.

Thanks for the helpful posts here and in other threads!

Here’s the sped up model, in case anyone is interested:

functions {
    // computes the log likelihood for one subject
    real subj_ll(real[] rt, vector stim, real p0, real delta, real alpha, real[] t, real sigma){
        int sz = size(rt);
        real mu[sz];
        for (i in 1:sz) {
            mu[i] = stim[i] + p0 + delta*(1 - exp(-alpha*t[i]));
        }
        return lognormal_lpdf(rt | mu, sigma);
    }
    
    // receives a slice of subjects and sums the log likelihood of each subject in it
    real partial_sum(array[,] real rt, int start, int end, array[,] int stim_id, vector stim, vector p0, vector delta, vector alpha, array[,] real t, vector sigma, array[] int ragged_n) {
        int N = end - start + 1;
        real accum = 0;
        for (si in 1:N) {
            int i = si + start - 1;
            accum += subj_ll(rt[si,1:ragged_n[i]], stim[stim_id[i,1:ragged_n[i]]], p0[i], delta[i], alpha[i], t[i,1:ragged_n[i]], sigma[i]);
        }
        return accum;
    }
}
data {
    int<lower=1> grainsize;
    int<lower=1> ntrial;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar;
    int<lower=1, upper=nsub> sid[ntrial];
    int<lower=1, upper=nstim> stim_id[ntrial];
    vector[ntrial] t;
    vector[ntrial] rt;
    matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
    
    real<lower=0> U; // rt values greater than or equal to U are considered MISSING
}
transformed data {
    // identify trials that are 'late' (after the 10 s timeout)
    real nu = 10;
    int latecount = 0;
    int good_idx_tmp[ntrial];
    int late_idx_tmp[ntrial];
    int idx = 0;
    for (i in 1:ntrial) {
        if (rt[i] >= U) {
            latecount += 1;
            late_idx_tmp[latecount] = i;
        } else {
            idx += 1;
            good_idx_tmp[idx] = i;
        }
    }
    int good_idx[ntrial-latecount] = good_idx_tmp[1:(ntrial-latecount)];
    int late_idx[latecount] = late_idx_tmp[1:latecount];
    
    // QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
    
    // now the project is to transform my long-form data into ragged square arrays.

    // work out how many good trials per subject we have
    array[nsub] int ragged_n = rep_array(0, nsub);
    for (i in 1:(ntrial-latecount)){
        ragged_n[sid[good_idx[i]]] += 1;
    }
    int W = max(ragged_n);
    
    // we need to square up stim_id, t, and rt
    array[nsub, W] int stim_id_sq;
    array[nsub, W] real t_sq;
    array[nsub, W] real rt_sq;
    
    array[nsub] int sub_idx = rep_array(1, nsub);
    for (g in 1:(ntrial-latecount)) {
        int i = good_idx[g];
        stim_id_sq[sid[i], sub_idx[sid[i]]] = stim_id[i];
        t_sq[sid[i], sub_idx[sid[i]]] = t[i];
        rt_sq[sid[i], sub_idx[sid[i]]] = rt[i];
        
        sub_idx[sid[i]] += 1;
    }
}
parameters {
    real P0;
    real<lower=0> sigma_P0;
    vector[nsub] zP0;
    vector[nchar] theta_p0;
    
    real Delta;
    real<lower=0> sigma_Delta;
    vector[nsub] zDelta;
    vector[nchar] theta_delta;
    
    real Alpha;
    real<lower=0> sigma_Alpha;
    vector[nsub] zAlpha;
    vector[nchar] theta_alpha;

    real Sigma;
    real<lower=0> sigma_sigma;
    vector[nsub] zSigma;
    //vector[nchar] theta_sigma;
    
    real<lower=0> sigma_stim;
    vector[nstim] zStim;
}
transformed parameters {
    vector[nsub] p0;
    vector[nsub] delta;
    vector[nsub] alpha;
    vector[nsub] sigma;
    vector[nstim] stim;
    
    profile("transform1") {
        p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
        delta = Delta + zDelta*sigma_Delta + Q_ast*theta_delta;
        alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha + Q_ast*theta_alpha);
        sigma = exp(Sigma + zSigma*sigma_sigma);
        stim = zStim*sigma_stim;
    }
    //profile("transform2") { //230s
    //    mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    //}
}
model {
    profile("xlik"){ 
        // partial_sum(array[,] real rt, int start, int end, array[,] int stim_id, vector stim, vector p0, vector delta, vector alpha, array[,] real t, vector sigma, array[] int ragged_n)
        //target += partial_sum(rt_sq, 1, nsub, stim_id_sq, stim, p0, delta, alpha, t_sq, sigma, ragged_n);
        target += reduce_sum(partial_sum, rt_sq, grainsize, stim_id_sq, stim, p0, delta, alpha, t_sq, sigma, ragged_n);
    }
    profile("lik2"){
        // lccdf(mu, sigma)
        target += lognormal_lccdf(U | stim[stim_id[late_idx]] + p0[sid[late_idx]] + delta[sid[late_idx]].*(1 - exp(-alpha[sid[late_idx]].*t[late_idx])), sigma[sid[late_idx]]); 
    }
    
    profile("priors"){
        P0 ~ normal(1,1);
        sigma_P0 ~ normal(0, 1);
        zP0 ~ student_t(nu, 0, 1);
        theta_p0 ~ normal(0, 0.25);

        Delta ~ normal(-1, 1);
        sigma_Delta ~ normal(0, 1);
        zDelta ~ student_t(nu, 0, 1);
        theta_delta ~ normal(0, 0.25);

        Alpha ~ normal(1, 1);
        sigma_Alpha ~ normal(0, 0.33);
        zAlpha ~ student_t(nu, 0, 1);
        theta_alpha ~ normal(0, 0.25);

        Sigma ~ normal(-1, 1);
        sigma_sigma ~ normal(0, 0.33);
        zSigma ~ student_t(nu, 0, 1);
        //theta_sigma ~ normal(0, 0.25);

        sigma_stim ~ normal(0, 0.33);
        zStim ~ student_t(nu, 0, 1);
    }
} 
generated quantities {
    vector[nchar] beta_p0;
    vector[nchar] beta_delta;
    vector[nchar] beta_alpha;
    vector[nsub] pEnd;
    vector[ntrial] rt_hat;
    
    profile("gen_betas"){
        beta_p0 = R_ast_inverse*theta_p0;
        beta_delta = R_ast_inverse*theta_delta;
        beta_alpha = R_ast_inverse*theta_alpha;
    }
        
    profile("gen_pEnd") {
        pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
    }
    profile("gen_rt_hat") {
        rt_hat = to_vector(lognormal_rng(stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t)), sigma[sid]));

        for (i in 1:ntrial) {
            if (rt_hat[i] > 10)
                rt_hat[i] = 10;
        }
    }
}
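On the loop-versus-vectorization point: the loop in `subj_ll` only builds `mu`; the `lognormal_lpdf` call itself stays vectorized over the subject’s trials. If you want to compare against a fully vectorized form, here is a drop-in sketch with the same signature (`subj_ll_vec` is a hypothetical name, and I have not timed it against the loop).

```stan
functions {
  // Same inputs and result as subj_ll above, but mu is built with vector
  // arithmetic instead of an explicit loop.
  real subj_ll_vec(array[] real rt, vector stim, real p0, real delta,
                   real alpha, array[] real t, real sigma) {
    vector[size(rt)] mu = stim + p0 + delta * (1 - exp(-alpha * to_vector(t)));
    return lognormal_lpdf(rt | mu, sigma);
  }
}
```

Swapping `subj_ll(` for `subj_ll_vec(` in `partial_sum` (and adding the function to the functions block) should be the only change needed to try it.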

How can you profile which part of the model takes longer? I did not know about that.

It’s super helpful! I found this vignette on how to profile Stan code useful. Basically, you wrap the code of interest in a ‘profile’ block:

profile("priors") {
  target += std_normal_lpdf(beta);
  target += std_normal_lpdf(alpha);
}
profile("likelihood") {
  target += bernoulli_logit_lpmf(y | X * beta + alpha);
}

In my environment, after I compile and run the code this produces a csv file called ‘profile.csv’ with information about the runtimes of enclosed code. So you might see that the code called “priors” ran 10000 times and took 12s total, and the code called “likelihood” ran the same number of times and took 160s total. It tells you some other things too, but the total runtime is what I found most useful.
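For reference, here is the snippet above placed in a complete program. The data names and shapes (`N`, `K`, `X`, `y`) are assumptions, since the example only shows the model block. The profile blocks only record timing; they do not change what the model computes.

```stan
data {
  int<lower=0> N;                    // number of observations (assumed)
  int<lower=0> K;                    // number of predictors (assumed)
  matrix[N, K] X;
  array[N] int<lower=0, upper=1> y;
}
parameters {
  vector[K] beta;
  real alpha;
}
model {
  profile("priors") {
    target += std_normal_lpdf(beta);
    target += std_normal_lpdf(alpha);
  }
  profile("likelihood") {
    target += bernoulli_logit_lpmf(y | X * beta + alpha);
  }
}
```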

One thing to be careful about is variable scope. Say you have this line in a transformed parameters block:

real alpha = some_function(beta);

If you wanted to know how long ‘some_function’ takes, you might just wrap that line in a profile block, but that would make alpha local to the profile block and hide it from the rest of the program. Instead, declare it outside the block and assign it inside:

real alpha;
profile("CustomFunction") {
  alpha = some_function(beta);
}