Reduce_sum with array #dims > 1

I’ve been playing about with the new reduce_sum() function for within-chain parallelisation, and I’m unclear on how it works when the x argument (that will be broken into x_subset for passing to the partial_sum function) is passed an array with more than 1 dimension. The documentation says that reduce_sum will take an array of any dimension, but I’ve not been able to find an example of it in use other than those with 1D arrays (e.g. https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html).

My questions are: (1) when given an array with more than one dimension, how does reduce_sum choose to cut it (I’m mainly interested in the 2D case)? Based on what I’ve been doing with a 2D integer-valued array, I think it’s cutting it by rows, but some clarification would be a great help! (2) Related to this, when it’s given an array with 2 or more dimensions, how will grainsize=1 (which lets reduce_sum choose the number of slices) choose these? Is it counting the size of the array element-wise, or in terms of its dimensions, or something else entirely?

Thanks!

Say we have N data items and C is the size of the partial sum that gets chosen. The array is always sliced along its first dimension, so:

vector[5] total[N]; => vector[5] partial[C];

real total[N,3]; => real partial[C,3];

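The grainsize only controls how large C is, so it is counted in elements of the first dimension (rows in the 2D case), not in individual array entries.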
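
To make this concrete, here is a minimal sketch for the 2D integer case from the question. The model (independent Poisson counts with one rate per column) and the names partial_sum, y, K and lambda are assumptions for illustration only. It uses the same older array syntax as the declarations above; newer Stan releases write these as array[N, K] int y and array[,] int y_slice.

```stan
functions {
  // y_slice holds rows start..end of the full array y.
  // Only the first dimension is cut; every row keeps all K columns.
  real partial_sum(int[,] y_slice, int start, int end, vector lambda) {
    real lp = 0;
    // Number of rows in this slice (this is the C from the reply above).
    int C = end - start + 1;
    for (i in 1:C) {
      // y_slice[i] is an int[] of length K, i.e. row start + i - 1 of y.
      lp += poisson_lpmf(y_slice[i] | lambda);
    }
    return lp;
  }
}
data {
  int<lower=1> N;            // number of rows (the sliced dimension)
  int<lower=1> K;            // number of columns (never sliced)
  int<lower=0> y[N, K];
  int<lower=1> grainsize;    // counted in rows of y
}
parameters {
  vector<lower=0>[K] lambda; // hypothetical per-column rate
}
model {
  lambda ~ exponential(1);
  // y is the sliced argument; lambda is passed whole to every partial sum.
  target += reduce_sum(partial_sum, y, grainsize, lambda);
}
```

Inside partial_sum, start and end index rows of the original y, which is what you need if another unsliced argument has to be indexed in step with the slice.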

Great. Thanks for that- really helpful!