In the course of preparing a pull request for brms on the Conway-Maxwell Poisson (preliminary code hosted here but not ready for prime time yet), @GuidoAMoreira and I stumbled upon a question that I hope you Stanimals can help me with: when doing `log_sum_exp`, should we do

```stan
for (k in 1:N) ans = log_sum_exp(ans, lterm(k));
```

or

```stan
for (k in 1:N) lterms[k + 1] = lterm(k);
ans = log_sum_exp(lterms);
```

?

I call these “sequential” and “bulk”, for lack of better names.
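To look at the two strategies outside Stan, here is a minimal R sketch. The helpers `lse2`, `lse_vec`, `log_com_term` and `compare_strategies` are hypothetical: they use the standard max-shifted log-sum-exp formulas, which I am assuming (not verifying) are close to what Stan Math does internally.

```r
# Hypothetical helpers, not Stan's implementation: standard max-shifted log-sum-exp.

# Two-argument form, as used by the "sequential" strategy.
lse2 <- function(a, b) {
  if (a == -Inf) return(b)
  if (b == -Inf) return(a)
  max(a, b) + log1p(exp(-abs(a - b)))
}

# Vector form, as used by the "bulk" strategy: shift by the maximum once.
lse_vec <- function(x) {
  m <- max(x)
  if (m == -Inf) return(-Inf)
  m + log(sum(exp(x - m)))
}

# Log of the unnormalized COM-Poisson term: k * log(mu) - nu * log(k!).
log_com_term <- function(k, log_mu, nu) k * log_mu - nu * lgamma(k + 1)

compare_strategies <- function(mu, nu, N = 1000) {
  lterms <- log_com_term(0:N, log(mu), nu)
  # Sequential: fold the terms in one at a time.
  seq_ans <- lterms[1]
  for (k in 2:(N + 1)) seq_ans <- lse2(seq_ans, lterms[k])
  # Bulk: a single call on the whole vector.
  bulk_ans <- lse_vec(lterms)
  c(seq = seq_ans, bulk = bulk_ans)
}

res <- compare_strategies(mu = 5, nu = 1)
# Errors relative to the closed form log Z(mu, 1) = mu
print(res - 5, digits = 17)
```

Both answers should land within a few units in the last place of the exact value; which one ends up closer depends on how the individual rounding errors happen to accumulate, so neither ordering is guaranteed to win for every `mu` and `nu`.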
I wrote a little script to experiment with this, and the results do not seem to make a lot of sense.
Here’s my R code:
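```r
library(cmdstanr)

compiled <- cmdstanr::cmdstan_model("log_sum_exp_test.stan")

mu <- 5
nu <- 1

# Closed-form log normalizing constant log Z(mu, nu), known for nu = 1 and nu = 2
if (nu == 2) {
  TV <- log(besselI(2 * sqrt(mu), nu = 0))
} else if (nu == 1) {
  TV <- mu
} else {
  stop("No closed form for log Z(mu, nu) is coded for this nu")
}

test.data <- list(
  N = 1000,
  mu = mu,
  nu = nu,
  trueV = TV
)

# One draw with fixed_param = TRUE: all the work happens in generated quantities
raw <- compiled$sample(data = test.data, chains = 1,
                       iter_warmup = 0, iter_sampling = 1,
                       fixed_param = TRUE, show_messages = TRUE)

ofInterest <- c("trueAnswer", "bulkAnswer", "seqAnswer",
                "diffSeq", "diffBulk")
print(raw, ofInterest, digits = 20)
```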
and here’s the corresponding Stan program:
```stan
functions {
  // Sign of x: -1, 0 or 1
  real signum(real x) {
    real ans;
    if (x < 0) {
      ans = -1;
    } else if (x == 0) {
      ans = 0;
    } else {
      ans = 1;
    }
    return ans;
  }
  // Returns exp(x) - exp(y), i.e. the difference on the natural scale
  // (not x - y), computed as sign * exp(log_diff_exp(max, min))
  real robust_difference(real x, real y) {
    real sgn = signum(x - y);
    real m = min({x, y});
    real M = max({x, y});
    return sgn * exp(log_diff_exp(M, m));
  }
  // Log of the unnormalized COM-Poisson term: k * log(mu) - nu * log(k!)
  real log_COM_Poisson(int k, real log_mu, real nu) {
    return k * log_mu - nu * lgamma(k + 1);
  }
}
data {
  int<lower=0> N;
  real<lower=0> mu;
  real<lower=0> nu;
  real trueV;
}
transformed data {
  real lmu = log(mu);
}
generated quantities {
  real lterms[N + 1];
  real diffSeq;
  real diffBulk;
  real bulkAnswer;
  real seqAnswer = log_COM_Poisson(0, lmu, nu);
  real trueAnswer = trueV;
  lterms[1] = seqAnswer;
  for (k in 1:N) {
    lterms[k + 1] = log_COM_Poisson(k, lmu, nu);
    // "sequential": accumulate one term at a time
    seqAnswer = log_sum_exp(seqAnswer, lterms[k + 1]);
  }
  // "bulk": one call on all N + 1 terms
  bulkAnswer = log_sum_exp(lterms);
  diffSeq = robust_difference(seqAnswer, trueAnswer);
  diffBulk = robust_difference(bulkAnswer, trueAnswer);
}
```
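For reference, `log_COM_Poisson(k, log(mu), nu)` is the log of the k-th term of the COM-Poisson normalizing constant, so both loops approximate log Z (truncated at `N`). The closed forms the R script passes in as `trueV` follow from:

$$
Z(\mu, \nu) = \sum_{k=0}^{\infty} \frac{\mu^k}{(k!)^{\nu}},
\qquad
\log Z(\mu, \nu) = \log \sum_{k=0}^{\infty} \exp\big(k \log \mu - \nu \log k!\big)
$$

$$
\nu = 1: \quad Z(\mu, 1) = \sum_{k=0}^{\infty} \frac{\mu^k}{k!} = e^{\mu} \;\Rightarrow\; \log Z(\mu, 1) = \mu
$$

$$
\nu = 2: \quad I_0(x) = \sum_{k=0}^{\infty} \frac{(x/2)^{2k}}{(k!)^2} \;\Rightarrow\; Z(\mu, 2) = I_0\big(2\sqrt{\mu}\big)
$$

Here `lgamma(k + 1)` plays the role of log k!, and R's `besselI(x, nu = 0)` evaluates I_0; note that the `nu` argument of `besselI` is the Bessel order, unrelated to the COM-Poisson `nu`.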
Results are something like this:

```
#### nu = 1
## mu = 5, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 6, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 11, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12.5, nu = 1 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 13, nu = 1 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 17, nu = 1 and N = 1000 gives |diffBulk| == |diffSeq|
## mu = 50, nu = 1 and N = 1000 gives |diffBulk| == |diffSeq|
#### nu = 2
## mu = 5, nu = 2 and N = 1000 gives |diffBulk| == |diffSeq|
## mu = 6, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 11, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12, nu = 2 and N = 1000 gives |diffBulk| < |diffSeq|
## mu = 12.5, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 13, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 17, nu = 2 and N = 1000 gives |diffBulk| > |diffSeq|
## mu = 50, nu = 2 and N = 1000 gives |diffBulk| == |diffSeq|
```
I suspect I don’t understand these results because I don’t really know how floating point works, so I’m calling on @bbbales2, @nhuurre, @martinmodrak, @wds15 and @Bob_Carpenter to educate me on what is going on.
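A minimal R illustration of floating-point effects that are plausibly at play (a general sketch, not a diagnosis of these specific runs): addition is not associative, so summing the same terms in a different order can round differently, and the gap between adjacent doubles grows with magnitude. Since `robust_difference` returns `exp(x) - exp(y)`, a one-ulp discrepancy on the log scale becomes a large absolute difference when `trueV` is large. The `ulp()` helper is a hypothetical approximation, not a base R function.

```r
# 1. Floating-point addition is not associative: grouping changes the rounding.
a <- (0.1 + 0.2) + 0.3
b <- 0.1 + (0.2 + 0.3)
print(c(a = a, b = b), digits = 17)
print(a == b)

# 2. Approximate spacing between adjacent doubles near x ("unit in the last place").
#    Hypothetical helper; log2() rounding can make it off by 2x right at powers of two.
ulp <- function(x) 2^(floor(log2(abs(x))) - 52)
print(ulp(c(5, 50)))

# 3. A one-ulp change in log Z = 50 (the mu = 50, nu = 1 case) becomes a
#    large absolute change once exponentiated, as robust_difference does.
log_Z <- 50
gap_natural <- exp(log_Z + ulp(log_Z)) - exp(log_Z)
print(gap_natural)
```

Because `diffSeq` and `diffBulk` are compared on this natural scale, differences of one or two rounding steps are enough to flip `<` into `>`, and a tie (`==`) most likely means both strategies returned the same double.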
```
> sessionInfo()
R version 4.0.4 (2021-02-15)
Platform: x86_64-pc-linux-gnu (64-bit)
Running under: Ubuntu 20.04.2 LTS

Matrix products: default
BLAS: /usr/lib/x86_64-linux-gnu/atlas/libblas.so.3.10.3
LAPACK: /usr/lib/x86_64-linux-gnu/atlas/liblapack.so.3.10.3

locale:
 [1] LC_CTYPE=pt_BR.UTF-8 LC_NUMERIC=C LC_TIME=pt_BR.UTF-8
 [4] LC_COLLATE=en_US.UTF-8 LC_MONETARY=pt_BR.UTF-8 LC_MESSAGES=en_US.UTF-8
 [7] LC_PAPER=pt_BR.UTF-8 LC_NAME=C LC_ADDRESS=C
[10] LC_TELEPHONE=C LC_MEASUREMENT=pt_BR.UTF-8 LC_IDENTIFICATION=C

attached base packages:
[1] stats graphics grDevices utils datasets methods base

other attached packages:
[1] cmdstanr_0.3.0

loaded via a namespace (and not attached):
 [1] Rcpp_1.0.6 pillar_1.5.1 compiler_4.0.4 prettyunits_1.1.1
 [5] tools_4.0.4 pkgbuild_1.2.0 jsonlite_1.7.2 lifecycle_1.0.0
 [9] tibble_3.1.0 checkmate_2.0.0 gtable_0.3.0 pkgconfig_2.0.3
[13] rlang_0.4.10 DBI_1.1.1 cli_2.3.1 parallel_4.0.4
[17] curl_4.3 xfun_0.22 loo_2.4.1 gridExtra_2.3
[21] dplyr_1.0.5 knitr_1.31 generics_0.1.0 vctrs_0.3.6
[25] tidyselect_1.1.0 stats4_4.0.4 grid_4.0.4 inline_0.3.17
[29] glue_1.4.2 data.table_1.14.0 R6_2.5.0 processx_3.5.0
[33] fansi_0.4.2 rstan_2.26.1 purrr_0.3.4 ggplot2_3.3.3
[37] callr_3.6.0 posterior_0.1.3 magrittr_2.0.1 codetools_0.2-18
[41] matrixStats_0.58.0 scales_1.1.1 backports_1.2.1 ps_1.6.0
[45] ellipsis_0.3.1 StanHeaders_2.26.1 assertthat_0.2.1 abind_1.4-5
[49] colorspace_2.0-0 V8_3.4.0 utf8_1.2.1 munsell_0.5.0
[53] RcppParallel_5.0.3 crayon_1.4.1
```