Trying within-chain parallelization with reduce_sum increases runtime a lot

So I looked at the profiling, and in the un-parallelized model about half the runtime was in the first likelihood statement and the other half was in the big transformed-parameters computation.

I followed the advice in this post: Reduce_sum with hierarchical vector auto-regression - #9 by wds15

I ended up slicing by participant. This meant I had to reshape my data into a (ragged-ish) 2d array, but once I did that it actually sped up non-threaded performance a bit, which was surprising, as I’d replaced a vectorized operation with a loop, and I thought that was usually bad.

Once I had the big transform and the likelihood sliced up, moving to multi-threading gave me the big speed-ups I expected. 4 threads per chain gave approx. a 4x speedup vs. the original model and about 3.5x vs. the unthreaded equivalent model. Adding more threads (up to all the physical cores on the machine) gave marginal additional speedups (tens of seconds here and there), and messing with grainsize also bought some minor gains, but unless runtimes start getting a lot longer, I doubt I’ll see meaningful speedups from additional fiddling.
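In case it helps anyone set this up: here’s a minimal sketch of the threading invocation, assuming cmdstanpy (cmdstanr takes equivalent cpp_options and threads_per_chain arguments); the file path and the stan_data dict are placeholders.

from cmdstanpy import CmdStanModel

# compile with threading enabled; STAN_THREADS has to be set at compile
# time, threads_per_chain alone does nothing without it
model = CmdStanModel(
    stan_file="model.stan",              # placeholder path
    cpp_options={"STAN_THREADS": True},
)

fit = model.sample(
    data={**stan_data, "grainsize": 1},  # grainsize is data in this model; 1 lets the scheduler pick
    chains=4,
    parallel_chains=4,
    threads_per_chain=4,                 # threads used by reduce_sum within each chain
)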

Thanks for the helpful posts here and in other threads!

Here’s the sped-up model, in case anyone is interested:

functions {
    // computes the log likelihood for one subject
    real subj_ll(array[] real rt, vector stim, real p0, real delta, real alpha, array[] real t, real sigma){
        int sz = size(rt);
        array[sz] real mu;
        for (i in 1:sz) {
            mu[i] = stim[i] + p0 + delta*(1 - exp(-alpha*t[i]));
        }
        return lognormal_lpdf(rt | mu, sigma);
    }
    
    // receives a slice of subjects and accumulates the log likelihood over the slice
    real partial_sum(array[,] real rt, int start, int end, array[,] int stim_id, vector stim, vector p0, vector delta, vector alpha, array[,] real t, vector sigma, array[] int ragged_n) {
        int N = end - start + 1;
        real accum = 0;
        for (si in 1:N) {
            int i = si + start - 1;
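            // rt is the sliced argument, so it is indexed relative to the slice (si);
            // the shared arguments are indexed into the full arrays (i)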
            accum += subj_ll(rt[si,1:ragged_n[i]], stim[stim_id[i,1:ragged_n[i]]], p0[i], delta[i], alpha[i], t[i,1:ragged_n[i]], sigma[i]);
        }
        return accum;
    }
}
data {
    int<lower=1> grainsize;
    int<lower=1> ntrial;
    int<lower=1> nsub;
    int<lower=1> nstim;
    int<lower=1> nchar;
    array[ntrial] int<lower=1, upper=nsub> sid;
    array[ntrial] int<lower=1, upper=nstim> stim_id;
    vector[ntrial] t;
    vector[ntrial] rt;
    matrix[nsub, nchar] Z; //assumed to be mean zero, scale 1
    
    real<lower=0> U; // rt values greater than or equal to U are considered MISSING
}
transformed data {
    // identify trials that are 'late' (at or past the 10 s timeout U)
    real nu = 10; // degrees of freedom for the student_t priors below
    int latecount = 0;
    array[ntrial] int good_idx_tmp;
    array[ntrial] int late_idx_tmp;
    int idx = 0;
    for (i in 1:ntrial) {
        if (rt[i] >= U) {
            latecount += 1;
            late_idx_tmp[latecount] = i;
        } else {
            idx += 1;
            good_idx_tmp[idx] = i;
        }
    }
    array[ntrial - latecount] int good_idx = good_idx_tmp[1:(ntrial - latecount)];
    array[latecount] int late_idx = late_idx_tmp[1:latecount];
    
    // QR Reparameterization, see https://mc-stan.org/docs/2_27/stan-users-guide/QR-reparameterization-section.html
    matrix[nsub, nchar] Q_ast;
    matrix[nchar, nchar] R_ast;
    matrix[nchar, nchar] R_ast_inverse;
    // thin and scale the QR decomposition
    Q_ast = qr_thin_Q(Z) * sqrt(nsub - 1);
    R_ast = qr_thin_R(Z) / sqrt(nsub - 1);
    R_ast_inverse = inverse(R_ast);
    
    // now the task is to transform my long-form data into ragged, squared-up arrays.

    // work out how many good trials per subject we have
    array[nsub] int ragged_n = rep_array(0, nsub);
    for (i in 1:(ntrial-latecount)){
        ragged_n[sid[good_idx[i]]] += 1;
    }
    int W = max(ragged_n);
    
    // we need to square up stim_id, t, and rt
    array[nsub, W] int stim_id_sq;
    array[nsub, W] real t_sq;
    array[nsub, W] real rt_sq;
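    // rows are padded to width W; only the first ragged_n[s] entries of row s
    // are ever read, since partial_sum slices each row by ragged_n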
    
    array[nsub] int sub_idx = rep_array(1, nsub);
    for (g in 1:(ntrial-latecount)) {
        int i = good_idx[g];
        stim_id_sq[sid[i], sub_idx[sid[i]]] = stim_id[i];
        t_sq[sid[i], sub_idx[sid[i]]] = t[i];
        rt_sq[sid[i], sub_idx[sid[i]]] = rt[i];
        
        sub_idx[sid[i]] += 1;
    }
}
parameters {
    real P0;
    real<lower=0> sigma_P0;
    vector[nsub] zP0;
    vector[nchar] theta_p0;
    
    real Delta;
    real<lower=0> sigma_Delta;
    vector[nsub] zDelta;
    vector[nchar] theta_delta;
    
    real Alpha;
    real<lower=0> sigma_Alpha;
    vector[nsub] zAlpha;
    vector[nchar] theta_alpha;

    real Sigma;
    real<lower=0> sigma_sigma;
    vector[nsub] zSigma;
    //vector[nchar] theta_sigma;
    
    real<lower=0> sigma_stim;
    vector[nstim] zStim;
}
transformed parameters {
    vector[nsub] p0;
    vector[nsub] delta;
    vector[nsub] alpha;
    vector[nsub] sigma;
    vector[nstim] stim;
    
    profile("transform1") {
        p0 = P0 + zP0*sigma_P0 + Q_ast*theta_p0;
        delta = Delta + zDelta*sigma_Delta + Q_ast*theta_delta;
        alpha = 1 + exp(Alpha + zAlpha*sigma_Alpha + Q_ast*theta_alpha);
        sigma = exp(Sigma + zSigma*sigma_sigma);
        stim = zStim*sigma_stim;
    }
    //profile("transform2") { //230s
    //    mu = stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t));
    //}
}
model {
    profile("xlik"){ 
        // partial_sum(array[,] real rt, int start, int end, array[,] int stim_id, vector stim, vector p0, vector delta, vector alpha, array[,] real t, vector sigma, array[] int ragged_n)
        //target += partial_sum(rt_sq, 1, nsub, stim_id_sq, stim, p0, delta, alpha, t_sq, sigma, ragged_n);
        target += reduce_sum(partial_sum, rt_sq, grainsize, stim_id_sq, stim, p0, delta, alpha, t_sq, sigma, ragged_n);
    }
    profile("lik2"){
        // right-censored 'late' trials: each contributes P(rt >= U) via the complementary CDF
        target += lognormal_lccdf(U | stim[stim_id[late_idx]] + p0[sid[late_idx]] + delta[sid[late_idx]].*(1 - exp(-alpha[sid[late_idx]].*t[late_idx])), sigma[sid[late_idx]]); 
    }
    
    profile("priors"){
        P0 ~ normal(1,1);
        sigma_P0 ~ normal(0, 1);
        zP0 ~ student_t(nu, 0, 1);
        theta_p0 ~ normal(0, 0.25);

        Delta ~ normal(-1, 1);
        sigma_Delta ~ normal(0, 1);
        zDelta ~ student_t(nu, 0, 1);
        theta_delta ~ normal(0, 0.25);

        Alpha ~ normal(1, 1);
        sigma_Alpha ~ normal(0, 0.33);
        zAlpha ~ student_t(nu, 0, 1);
        theta_alpha ~ normal(0, 0.25);

        Sigma ~ normal(-1, 1);
        sigma_sigma ~ normal(0, 0.33);
        zSigma ~ student_t(nu, 0, 1);
        //theta_sigma ~ normal(0, 0.25);

        sigma_stim ~ normal(0, 0.33);
        zStim ~ student_t(nu, 0, 1);
    }
} 
generated quantities {
    vector[nchar] beta_p0;
    vector[nchar] beta_delta;
    vector[nchar] beta_alpha;
    vector[nsub] pEnd;
    vector[ntrial] rt_hat;
    
    profile("gen_betas"){
        beta_p0 = R_ast_inverse*theta_p0;
        beta_delta = R_ast_inverse*theta_delta;
        beta_alpha = R_ast_inverse*theta_alpha;
    }
        
    profile("gen_pEnd") {
        pEnd = exp(p0 + delta.*(1 - exp(-alpha)));
    }
    profile("gen_rt_hat") {
        rt_hat = to_vector(lognormal_rng(stim[stim_id] + p0[sid] + delta[sid].*(1 - exp(-alpha[sid].*t)), sigma[sid]));

        for (i in 1:ntrial) {
            if (rt_hat[i] > U)
                rt_hat[i] = U; // censor replicated rts at the timeout, like the data
        }
    }
}