Log_sum_exp: sequentially or "in bulk"?

In the course of preparing a pull request for brms on the Conway-Maxwell Poisson distribution (preliminary code is hosted here but is not ready for prime time yet), @GuidoAMoreira and I stumbled upon a question that I hope you stanimals can help me with: when computing log_sum_exp, should we do

for (k in 1:N) ans = log_sum_exp(ans, lterm(k));

or

for (k in 1:N) lterms[k + 1] = lterm(k);
ans = log_sum_exp(lterms);

?
I call these “sequential” and “bulk”, for lack of a better nomenclature.
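
To make the two options concrete outside of Stan, here is a minimal Python sketch of both. The helper names (`log_sum_exp_pair`, `log_sum_exp_bulk`, `log_com_poisson_term`) are made up for illustration, and both helpers use the usual "subtract the maximum" formulation, which is an assumption here and not a copy of the Stan Math source.

```python
import math

def log_com_poisson_term(k, log_mu, nu):
    # Log of the unnormalised COM-Poisson mass: k * log(mu) - nu * log(k!)
    return k * log_mu - nu * math.lgamma(k + 1)

def log_sum_exp_pair(a, b):
    # Two-argument version: one exp and one log1p per call
    hi, lo = max(a, b), min(a, b)
    return hi + math.log1p(math.exp(lo - hi))

def log_sum_exp_bulk(terms):
    # Vector version: find the maximum once, then take a single log
    shift = max(terms)
    return shift + math.log(sum(math.exp(t - shift) for t in terms))

mu, nu, n_terms = 5.0, 1.0, 1000
log_mu = math.log(mu)
terms = [log_com_poisson_term(k, log_mu, nu) for k in range(n_terms + 1)]

# "Sequential": fold the terms in one at a time
seq_answer = terms[0]
for t in terms[1:]:
    seq_answer = log_sum_exp_pair(seq_answer, t)

# "Bulk": store all terms, then reduce them in one call
bulk_answer = log_sum_exp_bulk(terms)

# For nu = 1 the normalising constant is exp(mu), so the log is mu
true_answer = mu
print(f"sequential error: {seq_answer - true_answer:.3e}")
print(f"bulk error:       {bulk_answer - true_answer:.3e}")
```

The sequential version rounds after every pairwise call, while the bulk version rounds inside one sum and one final log, so the two need not agree to the last digit.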

I wrote a little script to experiment with this, and the results do not seem to make a lot of sense.
Here’s my R code:

library(cmdstanr)

compiled <- cmdstanr::cmdstan_model("log_sum_exp_test.stan")
mu <- 5
nu <- 1
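if (nu == 2) {
  # For nu = 2 the normalising constant is I_0(2 * sqrt(mu))
  TV <- log(besselI(2 * sqrt(mu), nu = 0))
} else if (nu == 1) {
  # For nu = 1 the normalising constant is exp(mu), so its log is mu
  TV <- mu
} else {
  stop("No closed form implemented for this value of nu")
}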
test.data <- list(
  N = 1000,
  mu = mu,
  nu = nu,
  trueV = TV
)

raw <- compiled$sample(data = test.data, chains = 1, 
                       iter_warmup = 0, iter_sampling = 1,
                       fixed_param = TRUE, show_messages = TRUE)

ofInterest <- c("trueAnswer", "bulkAnswer", "seqAnswer",
                "diffSeq", "diffBulk")
print(raw, ofInterest, digits = 20)

and here’s the corresponding Stan program:

functions{
  real  signum(real x) {
    real ans;
    if(x < 0){
      ans = -1;
    }else{
      if(x == 0){
        ans = 0;
      }else{
        ans = 1;
      }
    }
    return ans;
  }
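  // Signed difference exp(x) - exp(y), computed via the log scale
  real robust_difference(real x, real y){
    real sgn = signum(x-y);
    real m = min({x, y});
    real M = max({x, y});
    return(sgn * exp(log_diff_exp(M, m)));
  }
  // Log of the unnormalised COM-Poisson mass at k: k * log(mu) - nu * log(k!)
  real log_COM_Poisson(int k, real log_mu, real nu){
    return k * log_mu - nu * lgamma(k + 1);
  }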
}
data{
  int<lower=0> N;
  real<lower=0> mu;
  real<lower=0> nu;
  real trueV;
}
transformed data{
  real lmu = log(mu);
}
generated quantities{
  real lterms[N + 1];
  real diffSeq;
  real diffBulk;
  real bulkAnswer;
  real seqAnswer = log_COM_Poisson(0, lmu, nu);
  real trueAnswer = trueV;
  lterms[1] = seqAnswer;
  for (k in 1:N){
    lterms[k + 1] = log_COM_Poisson(k, lmu, nu);
    seqAnswer = log_sum_exp(seqAnswer, lterms[k + 1]);
  }
  bulkAnswer = log_sum_exp(lterms);
  diffSeq = robust_difference(seqAnswer, trueAnswer);
  diffBulk = robust_difference(bulkAnswer, trueAnswer);
}

Results are something like this:

#### nu = 1
## mu = 5, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 6, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 11, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12.5, nu = 1 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 13, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 17, nu = 1 and N = 1000 gives |diffBulk| == |diffSeq|
## mu = 50, nu = 1 and N = 1000 gives |diffBulk| == |diffSeq|
#### nu = 2
## mu = 5, nu = 2 and N = 1000 gives |diffBulk| == |diffSeq|
## mu = 6, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 11, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12.5, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 13, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 17, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 50, nu = 2 and N = 1000 gives |diffBulk| == |diffSeq|

I think I don’t understand these results because I don’t really know how floating point works. So I’m calling on @bbbales2, @nhuurre, @martinmodrak, @wds15 and @Bob_Carpenter to please educate me on what is going on.

> sessionInfo()
R version 4.0.4 (2021-02-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.2 LTS

Matrix products: default
BLAS:   /usr/lib/x86_64-linux-gnu/atlas/libblas.so.3.10.3
LAPACK: /usr/lib/x86_64-linux-gnu/atlas/liblapack.so.3.10.3

locale:
 [1] LC_CTYPE=pt_BR.UTF-8       LC_NUMERIC=C               LC_TIME=pt_BR.UTF-8       
 [4] LC_COLLATE=en_US.UTF-8     LC_MONETARY=pt_BR.UTF-8    LC_MESSAGES=en_US.UTF-8   
 [7] LC_PAPER=pt_BR.UTF-8       LC_NAME=C                  LC_ADDRESS=C              
[10] LC_TELEPHONE=C             LC_MEASUREMENT=pt_BR.UTF-8 LC_IDENTIFICATION=C       

attached base packages:
[1] stats     graphics  grDevices utils     datasets  methods   base     

other attached packages:
[1] cmdstanr_0.3.0

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.6         pillar_1.5.1       compiler_4.0.4     prettyunits_1.1.1 
 [5] tools_4.0.4        pkgbuild_1.2.0     jsonlite_1.7.2     lifecycle_1.0.0   
 [9] tibble_3.1.0       checkmate_2.0.0    gtable_0.3.0       pkgconfig_2.0.3   
[13] rlang_0.4.10       DBI_1.1.1          cli_2.3.1          parallel_4.0.4    
[17] curl_4.3           xfun_0.22          loo_2.4.1          gridExtra_2.3     
[21] dplyr_1.0.5        knitr_1.31         generics_0.1.0     vctrs_0.3.6       
[25] tidyselect_1.1.0   stats4_4.0.4       grid_4.0.4         inline_0.3.17     
[29] glue_1.4.2         data.table_1.14.0  R6_2.5.0           processx_3.5.0    
[33] fansi_0.4.2        rstan_2.26.1       purrr_0.3.4        ggplot2_3.3.3     
[37] callr_3.6.0        posterior_0.1.3    magrittr_2.0.1     codetools_0.2-18  
[41] matrixStats_0.58.0 scales_1.1.1       backports_1.2.1    ps_1.6.0          
[45] ellipsis_0.3.1     StanHeaders_2.26.1 assertthat_0.2.1   abind_1.4-5       
[49] colorspace_2.0-0   V8_3.4.0           utf8_1.2.1         munsell_0.5.0     
[53] RcppParallel_5.0.3 crayon_1.4.1      

Do “bulk”. What’s confusing you?

Why they are different is what’s confusing me. It’s not so much that I expect them to be the same, but I want to understand what’s going on so I can use that knowledge in the future. Also, bulk presupposes a fixed-size object, whilst sequential doesn’t.
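
On the fixed-size point, one possible middle ground is to keep a running maximum and a running rescaled sum, so that no array of terms is stored and only one log is taken at the end. The following Python sketch illustrates that idea; the class `StreamingLogSumExp` is hypothetical and is not something Stan or brms provides.

```python
import math

class StreamingLogSumExp:
    """Accumulate log(sum(exp(term))) one term at a time (illustrative only)."""

    def __init__(self):
        self.running_max = -math.inf
        self.scaled_sum = 0.0  # sum of exp(term - running_max) seen so far

    def add(self, term):
        if term == -math.inf:
            return  # exp(-inf) = 0 contributes nothing
        if term <= self.running_max:
            self.scaled_sum += math.exp(term - self.running_max)
        else:
            # New maximum: rescale the old sum, then count the new term as 1
            self.scaled_sum = self.scaled_sum * math.exp(self.running_max - term) + 1.0
            self.running_max = term

    def value(self):
        if self.scaled_sum == 0.0:
            return -math.inf  # nothing has been added yet
        return self.running_max + math.log(self.scaled_sum)

acc = StreamingLogSumExp()
for term in [math.log(1.0), math.log(2.0), math.log(3.0)]:
    acc.add(term)
print(acc.value(), math.log(6.0))
```

This keeps the open-ended loop of the sequential version while avoiding a log at every step; whether it is more accurate in a given case would still need checking.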

Hm, depending on the log_sum_exp implementation, the two versions may do “very” different things. Search for, e.g., “Kahan summation”.
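
For readers unfamiliar with the term, here is a minimal Python sketch of Kahan (compensated) summation next to a plain running sum. It only illustrates how the order and style of accumulation change the rounding; it does not claim that any particular log_sum_exp implementation uses this technique.

```python
import math

def kahan_sum(values):
    total = 0.0
    compensation = 0.0  # running estimate of the low-order bits lost so far
    for value in values:
        adjusted = value - compensation
        new_total = total + adjusted
        # (new_total - total) is what was actually added; the rest was lost
        compensation = (new_total - total) - adjusted
        total = new_total
    return total

# One large term followed by many terms too small to register on their own
values = [1.0] + [1e-16] * 10_000

naive_total = 0.0
for v in values:
    naive_total += v

kahan_total = kahan_sum(values)
reference = math.fsum(values)  # correctly rounded sum from the standard library

print(f"naive error: {naive_total - reference:.3e}")
print(f"Kahan error: {kahan_total - reference:.3e}")
```

The plain loop drops every small term, while the compensated loop carries them along until they add up to something representable.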

Can you repeat the experiments, but report the relative difference instead of the sign? I’d expect the relative difference to be quite small, on the order of 1e-16.

The reason for the difference is probably just that log(exp(x)) != x due to rounding errors. You shouldn’t check for exact equality with floating-point numbers.
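
A quick way to see this is to round-trip a grid of values through exp and log, then compare with a relative tolerance instead of `==`. The Python sketch below does that; the grid and the tolerance of 1e-12 are arbitrary choices for illustration.

```python
import math

def relative_difference(approx, exact):
    # Assumes exact != 0
    return abs(approx - exact) / abs(exact)

grid = [0.5 + 0.25 * i for i in range(80)]
round_trips = [math.log(math.exp(x)) for x in grid]

# How many values fail to come back bit-for-bit identical?
n_inexact = sum(1 for x, y in zip(grid, round_trips) if x != y)
worst = max(relative_difference(y, x) for x, y in zip(grid, round_trips))

print(f"{n_inexact} of {len(grid)} round trips are not exact")
print(f"worst relative difference: {worst:.3e}")

# Compare with a tolerance rather than exact equality
all_close = all(math.isclose(x, y, rel_tol=1e-12) for x, y in zip(grid, round_trips))
print("all within tolerance:", all_close)
```

Reporting the relative difference in this way makes it easier to tell rounding noise from a genuine discrepancy between the two methods.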

Do bulk, and have a look at the definition of log_sum_exp in Stan Math… then you will quickly see why.
