How to use reduce_sum for a function returning a vector (implementing Dirichlet process mixtures)

For example, I have a linear model $Y_i = X_i\beta + Z_i\alpha_i + \epsilon_i$, where the $\alpha_i$ are subject-level random effects.
We assume that $\alpha_i \sim G$ with $G \sim \mathrm{DP}(a, G_0)$, where the base measure $G_0$ is the multivariate normal $N(0, \Sigma)$.
Using the stick-breaking representation with at most $K$ classes, the distribution can be written as $\sum_{k=1}^K \pi_k \delta_{\alpha_{ik}}$ with $\pi_k = V_k \prod_{l<k}(1 - V_l)$, $V_k \sim \mathrm{Beta}(1, a)$ and $\alpha_{ik} \sim N(0, \Sigma)$.
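The truncated stick-breaking weights can be sketched in a few lines of Python (an illustration of the formula, not Stan code). The input values are made up, and fixing the last $V_K = 1$ so that the $K$ weights sum to one is an assumption (a common truncation choice), not something stated above.

```python
import math


def stick_breaking_weights(v):
    """Truncated stick-breaking: pi_k = V_k * prod_{l<k} (1 - V_l).

    `v` holds V_1, ..., V_K. Setting the last entry to 1 (an assumed
    truncation choice) makes the K weights sum to one.
    """
    weights = []
    remaining = 1.0  # length of stick left after the first k-1 breaks
    for v_k in v:
        weights.append(v_k * remaining)
        remaining *= 1.0 - v_k
    return weights


if __name__ == "__main__":
    # Illustrative values only; in the model V_k ~ Beta(1, a) for k < K.
    pi = stick_breaking_weights([0.5, 0.5, 0.5, 1.0])
    print(pi)  # [0.5, 0.25, 0.125, 0.125]
```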

A standard implementation of this in Stan requires evaluating $\log\left(\sum_{k=1}^K \pi_k \exp\{l(\alpha_{ik})\}\right)$, where $l(\alpha_{ik})$ denotes the log-likelihood evaluated at the $k$-th realization of $\alpha_i$ from $N(0, \Sigma)$.
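On the log scale this quantity is $\log\sum_k \exp(\log\pi_k + l(\alpha_{ik}))$, i.e. a log-sum-exp over the $K$ classes, which is what Stan's `log_sum_exp` computes. A plain-Python sketch with hypothetical helper names:

```python
import math


def log_sum_exp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_mixture_likelihood(log_pi, loglik):
    """log(sum_k pi_k * exp(l_k)), computed as log_sum_exp(log pi_k + l_k).

    log_pi[k] : log of the stick-breaking weight pi_k
    loglik[k] : log-likelihood of the data under the k-th realization alpha_ik
    """
    return log_sum_exp([lp + ll for lp, ll in zip(log_pi, loglik)])
```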

Implementing this in Stan is straightforward, and I have done so.
The problem is that our data contain a very large number of subjects, so evaluating each $l(\alpha_{ik})$ is very costly. This motivates us to use `reduce_sum` so that the computation can be split across parallel threads.

As I understand it, this can be done in two ways:
(1) Use reduce_sum separately for each k to compute each $l(\alpha_{ik})$; once all K are computed, combine them with log_sum_exp.
(2) Write a function that returns a vector of length K whose k-th element is the partial sum contributing to $l(\alpha_{ik})$, and use reduce_sum on that function.
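The difference between the two options can be sketched in Python. Plain loops stand in for reduce_sum's slicing, and the row layout (`row[k]` is a row's log-likelihood contribution under class k) is a hypothetical choice for illustration. Both options give the same answer; option (2) is simply a single pass whose partial results are K-vectors summed componentwise, which is the part reduce_sum does not support.

```python
import math


def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def shard_partial_sums(shard, K):
    """Length-K list: this shard's contribution to each class log-likelihood."""
    return [sum(row[k] for row in shard) for k in range(K)]


def option_1(shards, log_pi):
    """K separate scalar reductions (one per class), then log_sum_exp."""
    loglik = []
    for k in range(len(log_pi)):
        # one scalar "reduce_sum" over all shards for class k
        loglik.append(sum(sum(row[k] for row in shard) for shard in shards))
    return log_sum_exp([lp + ll for lp, ll in zip(log_pi, loglik)])


def option_2(shards, log_pi):
    """One reduction whose partial results are K-vectors, summed componentwise."""
    K = len(log_pi)
    loglik = [0.0] * K
    for shard in shards:
        partial = shard_partial_sums(shard, K)
        loglik = [a + b for a, b in zip(loglik, partial)]
    return log_sum_exp([lp + ll for lp, ll in zip(log_pi, loglik)])
```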

(1) involves K separate reduce_sum calls, and I doubt its efficiency.
(2) would be efficient because it needs only one reduce_sum call, but the function passed to reduce_sum would return a vector, whereas the documentation states that the function must return a real (a map to $\mathbb{R}$). Conceptually it is very simple, a componentwise vector sum instead of a scalar sum, but I don't know whether reduce_sum can do that. I would be very grateful for suggestions from any of the experts or developers.

This is not possible with reduce_sum, since it can only return a single scalar. Have a look at the documentation for map_rect instead.

Yes, your only option is to return a scalar from the function given to reduce_sum, so you have to return partial sums over pi_k * l(alpha_k). The function forming the partial sums can itself apply log_sum_exp to the set of terms it is responsible for, but it then has to return the result on the natural scale, i.e. exponentiate it.

If you are worried about numerical stability, you can use reduce_sum_static, which lets you control how large the partial sums are. Making them large enough should work, since you can then apply log_sum_exp within each (hopefully large) partial sum and let reduce_sum_static do the remaining summation on the natural scale.

Hi @andrjohns, I shall take a look at it. Thanks!

Hi @wds15,
I can write K functions, each calculating $\pi_k L(\alpha_k)$ for $k = 1, \dots, K$, where $L(\alpha_k)$ is the likelihood (not the log-likelihood; that was a mistake). Then I can use K reduce_sum calls, one for each k. Can you comment on the parallel speed-up when using multiple reduce_sum calls? Or is there another way to do it?

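Note that $\log\left(a_1(1{:}n) + a_2(1{:}n)\right) \neq \log\left(a_1(1{:}n_1) + a_2(1{:}n_1)\right) + \log\left(a_1((n_1+1){:}n) + a_2((n_1+1){:}n)\right)$, where $a_k(s{:}t)$ denotes the likelihood of observations $s$ through $t$ under component $k$.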
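A quick numerical check of this inequality in Python, using made-up per-observation likelihoods for two components and no mixture weights: each component's log-likelihood splits additively across chunks, but the log of their sum does not.

```python
import math

# Per-observation likelihoods under two components (made-up numbers).
a1 = [0.2, 0.5, 0.1, 0.4]
a2 = [0.3, 0.1, 0.6, 0.2]
n1 = 2  # split point between the two chunks


def prod(xs):
    p = 1.0
    for x in xs:
        p *= x
    return p


# a_k(s:t) is the product of component-k likelihoods of observations s..t.
full = math.log(prod(a1) + prod(a2))
chunked = (math.log(prod(a1[:n1]) + prod(a2[:n1]))
           + math.log(prod(a1[n1:]) + prod(a2[n1:])))

# Each single component's log-likelihood IS additive over chunks.
additive_ok = math.isclose(math.log(prod(a1)),
                           math.log(prod(a1[:n1])) + math.log(prod(a1[n1:])))

print(full, chunked, additive_ok)
```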
That’s why this has become tricky.
Pardon me if I am misunderstanding anything.

So you basically have k different likelihoods for your data? How about starting with one reduce_sum over these k likelihoods? If each of them is very expensive, I would probably try the static version of reduce_sum with a grainsize of 1. If you have more CPUs to spend, you can nest the reduce_sum calls (but make sure that the outer reduce_sum really splits up huge amounts of work).

Hi @wds15,

You are right: I have to evaluate k different likelihoods and combine them with log_sum_exp. The problem with using a single reduce_sum is that the log of the mixture of likelihoods lacks the additivity property $g(x_1) + g(x_2) = g(x_1 + x_2)$ that splitting a sum into slices relies on.
Are you suggesting using k reduce_sum calls in total? Note that k is small, between 5 and 10.

Do you need the k likelihoods per data row, or k different likelihoods evaluated over the entire data set? Could you write down a schematic version of the Stan model without reduce_sum here?
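The distinction being asked about can be sketched in Python (illustrative numbers; `row[k]` is a hypothetical layout holding a row's log-likelihood under class k). When every row has its own mixture, the total is a sum over rows and can be sliced freely, which is what reduce_sum needs. When the whole data set shares one class, the log-sum-exp sits outside the sum over rows and slicing changes the answer.

```python
import math


def log_sum_exp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def per_row_mixture(rows, log_pi):
    # Each row gets its own mixture: sum_i log sum_k pi_k L_k(y_i).
    # This is additive over rows, so any slicing gives the same total.
    return sum(log_sum_exp([lp + r[k] for k, lp in enumerate(log_pi)])
               for r in rows)


def whole_data_mixture(rows, log_pi):
    # All rows share one class: log sum_k pi_k prod_i L_k(y_i).
    K = len(log_pi)
    loglik = [sum(r[k] for r in rows) for k in range(K)]
    return log_sum_exp([lp + ll for lp, ll in zip(log_pi, loglik)])
```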

Hi @wds15,

This is the key question. I need k different likelihoods evaluated over the entire data set. Here is part of the code for better understanding:

```stan
// Part of the model block: N subjects, K = 5 mixture classes.
// log_hazard1..5, log_cum_hazard1..5 and temp1 are declared (and temp1
// initialized) elsewhere in the model.
real ps1 = 0;  // log-likelihood of the whole data set under class 1
real ps2 = 0;
real ps3 = 0;
real ps4 = 0;
real ps5 = 0;
array[K] real ps;  // pre-2.26 syntax: real ps[K];

for (n in 1:N) {                                // subjects
  for (j in (part_dur[n]+2):part_dur[n+1]) {    // observations within subject n
    log_hazard1 = dot_product(b_mat_start1[n], bs_start[j])*dot_product(b_mat_surv1[n], bs_surv[j]) + beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv;
    log_cum_hazard1 = exp(beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv)*(duration[j]*(.5))*dot_product(wt, exp(dot_product(b_mat_start1[n], bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*((b_mat_surv1[n])'))));
    ps1 += log_hazard1 - log_cum_hazard1;

    log_hazard2 = dot_product(b_mat_start2[n], bs_start[j])*dot_product(b_mat_surv2[n], bs_surv[j]) + beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv;
    log_cum_hazard2 = exp(beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv)*(duration[j]*(.5))*dot_product(wt, exp(dot_product(b_mat_start2[n], bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*((b_mat_surv2[n])'))));
    ps2 += log_hazard2 - log_cum_hazard2;

    log_hazard3 = dot_product(b_mat_start3[n], bs_start[j])*dot_product(b_mat_surv3[n], bs_surv[j]) + beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv;
    log_cum_hazard3 = exp(beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv)*(duration[j]*(.5))*dot_product(wt, exp(dot_product(b_mat_start3[n], bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*((b_mat_surv3[n])'))));
    ps3 += log_hazard3 - log_cum_hazard3;

    log_hazard4 = dot_product(b_mat_start4[n], bs_start[j])*dot_product(b_mat_surv4[n], bs_surv[j]) + beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv;
    log_cum_hazard4 = exp(beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv)*(duration[j]*(.5))*dot_product(wt, exp(dot_product(b_mat_start4[n], bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*((b_mat_surv4[n])'))));
    ps4 += log_hazard4 - log_cum_hazard4;

    log_hazard5 = dot_product(b_mat_start5[n], bs_start[j])*dot_product(b_mat_surv5[n], bs_surv[j]) + beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv;
    log_cum_hazard5 = exp(beta_surv*XX1[j] + link_surv*indicator[temp1] + intcpt_surv)*(duration[j]*(.5))*dot_product(wt, exp(dot_product(b_mat_start5[n], bs_start[j])*(nodes[(15*(j-1) + 1):(15*j)]*((b_mat_surv5[n])'))));
    ps5 += log_hazard5 - log_cum_hazard5;

    temp1 += 1;
  }
}

// Add the log mixture weights, then marginalize over classes on the log scale.
ps[1] = log(eta[1]) + ps1;
ps[2] = log(eta[2]) + ps2;
ps[3] = log(eta[3]) + ps3;
ps[4] = log(eta[4]) + ps4;
ps[5] = log(eta[5]) + ps5;
target += log_sum_exp(ps);
```

Here, n indexes subjects, and the nested loop over j handles the observations within a subject. ps1, …, ps5 are the five log-likelihoods.

I hope this makes things a little clearer. Please let me know your suggestions.

Then I'd use reduce_sum_static with a grainsize of 1 and let it calculate the sum over the data set. You have to return not the log-likelihood but its exponentiated value; what you get back from reduce_sum then has to be log-transformed again. Let's do this first.

… but if I were you, I would first check whether summing without log_sum_exp is a problem for numerical stability. To do so, just replace log_sum_exp(ps) with log(sum(exp(ps))) and check whether things still work. If they do, the inner sum of exp(ps) can be replaced by reduce_sum_static.

Hi @wds15,

log(sum(exp(ps))) is not at all stable for this computation: I was getting log(0) errors, but log_sum_exp came to the rescue.
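The failure of the literal check can be reproduced in Python (illustrative values, not taken from the model): whole-data-set log-likelihoods are large negative numbers, so each exp underflows to 0.0 in double precision and the log of the sum becomes log(0). Shifting by the maximum first, as log_sum_exp does, keeps the result finite.

```python
import math


def naive_log_sum_exp(xs):
    # log(sum(exp(ps))) evaluated literally, as in the suggested check.
    return math.log(sum(math.exp(x) for x in xs))


def log_sum_exp(xs):
    # Shift by the max so the largest term becomes exp(0) = 1.
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


# Whole-data-set log-likelihoods are large and negative (made-up values).
ps = [-1200.0, -1210.0, -1195.0]

try:
    naive_log_sum_exp(ps)
except ValueError as err:  # every exp underflows to 0.0; math.log(0.0) raises
    print("naive:", err)

print("stable:", log_sum_exp(ps))
```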

Here, a single reduce_sum_static would be appropriate if the function passed to it could return the vector ps. Unfortunately, vector-returning functions are not supported. Note that the vector ps is additive componentwise, but log_sum_exp(ps) is not. The only way I can think of is to use reduce_sum_static 5 times, once each for ps1, …, ps5.

Do you think there is a better way?

Yes, then you can only use reduce_sum calls in sequence. Maybe a "reduce log_sum_exp" function should be written…
It should not be hard to do, but it is some work given the documentation and testing.
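One reading of this suggestion is a reduction that sums terms on the natural scale while keeping every intermediate result on the log scale. Its combine step only needs the identity $\mathrm{lse}(A \cup B) = \mathrm{lse}(\mathrm{lse}(A), \mathrm{lse}(B))$, which is associative, so slices can be reduced independently and then merged. A Python sketch; `reduce_log_sum_exp` is a hypothetical name, not an existing Stan function.

```python
import math
from functools import reduce


def log_sum_exp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def combine(a, b):
    """Merge two partial results: lse(A u B) = lse([lse(A), lse(B)])."""
    return log_sum_exp([a, b])


def reduce_log_sum_exp(terms, grainsize):
    # Each slice is reduced independently (the work one thread would do).
    slices = [terms[i:i + grainsize] for i in range(0, len(terms), grainsize)]
    partials = [log_sum_exp(s) for s in slices]
    # Partial results are merged with the associative combine step.
    return reduce(combine, partials)
```

Because `combine` is associative, the result does not depend on the grainsize, up to floating-point rounding.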

Hi @wds15,

Even just making reduce_sum compatible with vector-returning functions would solve many problems, including this one. In that case, reduce_sum would return a vector whose elements are the componentwise sums. Conceptually it looks simple.
