I’ve been playing about with the new reduce_sum() function for within-chain parallelisation, and I’m unclear on how it works when the x argument (the one that gets broken into x_subset and passed to the partial_sum function) is an array with more than one dimension. The documentation says that reduce_sum will take an array of any dimension, but the only examples I’ve been able to find use 1D arrays (e.g. https://mc-stan.org/users/documentation/case-studies/reduce_sum_tutorial.html).
My questions are: (1) When given an array with more than one dimension, how does reduce_sum choose to cut it (I’m mainly interested in the 2D case)? Based on what I’ve seen with a 2D integer-valued array, I think it’s cutting it by rows, but some clarification would be a great help! (2) Related to this, when it’s given an array with two or more dimensions, how does grainsize=1 (i.e. letting reduce_sum choose the number of slices) decide on the slices? Is it counting the size of the array element-wise, or in terms of its dimensions, or something else entirely?
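To make the question concrete, here is a minimal sketch of the kind of setup I mean. The names (partial_sum, y, lambda, N, M) and the Poisson likelihood are made up for illustration, and it assumes the newer array[...] declaration syntax. The comments describe what I think is happening (slicing along the first index), not confirmed behaviour.

```stan
functions {
  // Hypothetical partial sum. y_slice is whatever chunk of y reduce_sum hands over.
  real partial_sum(array[,] int y_slice, int start, int end, vector lambda) {
    real lp = 0;
    // Assumption: the slice is taken over rows (the first index), so
    // y_slice holds rows start..end of y and keeps all M columns.
    for (i in 1:size(y_slice)) {
      // Row i of the slice corresponds to row (start + i - 1) of y.
      lp += poisson_log_lpmf(y_slice[i] | lambda[start + i - 1]);
    }
    return lp;
  }
}
data {
  int<lower=1> N;              // number of rows
  int<lower=1> M;              // number of columns
  array[N, M] int<lower=0> y;  // 2D integer array passed as x
  int<lower=1> grainsize;      // 1 leaves the slicing to reduce_sum
}
parameters {
  vector[N] lambda;            // one log-rate per row
}
model {
  lambda ~ normal(0, 1);
  target += reduce_sum(partial_sum, y, grainsize, lambda);
}
```

If the slicing really is by rows, then size(y_slice) should equal end - start + 1 in every call, which is what I’d like to have confirmed.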
Thanks!