Trigamma from boost?


I just saw a note from @bbbales2 on a PR that the trigamma function has bad precision in stan-math. Looking up the function I see that we ship our own implementation, but boost::math has a trigamma function. As we apparently never used the boost math trigamma I wonder if we should switch to their implementation?


Tagging @sakrejda

This might be a good time for us to find a way to make these decisions more consistent. I’ll give it a shot this weekend and post here.


If you find an obvious way here that would be great - maybe a graphical approach? Please test the version without double to long double propagation (which deviates from boost’s default).

I was actually looking more for reasons people recall as to why we dumped boost trigamma. If those do not exist, then the only explanation for why we do not use boost is that it wasn’t available at the time the function was introduced.

I would consider it a good thing to use standard libraries for this type of thing so that we don’t need to worry about maintenance (unless what we get from boost isn’t usable, of course; but I doubt that).

I think boost trigamma was never actually used, so it’s not the same case as lgamma, which was replaced in


Trigamma arises in higher-order derivatives, which I don’t think we currently use anywhere. It was introduced into boost only in version 1.58.0 (released April 17, 2015), hence the very old function not utilizing it. I had to write my own in nomad, for example, which came out before 1.58.0.


Thanks for the info. There are a few digamma calls in our CDFs (for the negative binomial or the beta). So whenever users deal with censored data from these models, they will need the trigamma to get the gradient.

The boost trigamma looks a lot more involved to me, so we should probably just switch. Boost also has the polygamma function now.

I’m not saying we have to go into this level of detail here, but I’d like us to know, when we switch out math functions, why we do it. So here’s what I did, using trigamma as an example:

I wrote a quick test file (just appended below so the plot labels are obvious). It takes a really naive approach of doing an equally spaced sequence for both functions, which is silly but good enough for this machine and a first shot. The downside is that the output is 80 GB and I can only load sections into R for plotting at a time.

When looking at the output I think it’s good to focus on a few things:

  1. Settings: I haven’t looked up which version you wanted me to run (re: double-long-double conversion). I can do that in the next round.

  2. Having something to compare against, ideally an implementation whose error we know, etc., or just multiple implementations. I imagine boost’s implementation is better, but I would like to grab reference values or values from another implementation. I know a C library that’s widely respected for special functions was mentioned; I found it once but don’t have the link right now (@bgoodri, if you know what I’m talking about, I’d include it in the comparison).

  3. Finding the regions with problems: for this run it looks like even the absolute error is approaching 2e-9 as x approaches zero, so maybe the branch where we accumulate 1/x^2 is a problem.

  4. Our implementation looks like it might have some easy problems to fix: we accumulate 1/x^2 starting from the smallest x value, which might be a problem because we start with the largest 1/x^2 values. So maybe just fixing that would be good. Here’s what we do:

  z = x;
  value = 0.0;
  while (z < large) {
    value += 1.0 / (z * z);
    z += 1.0;
  }

For example in R if you run that section of code for x = 0.001 you get this answer:

> sprintf("%10.20f", f(0.001))
[1] "1000001.64243319106753915548"

If you construct a vector of the ‘value’ terms and then sum using R’s sum function you get something different:

> sprintf("%10.20f", g(0.001))
[1] "1000001.64243319083470851183"

And if you sum from smallest to largest values you get what R gets:

> sprintf("%10.20f", h(0.001))
[1] "1000001.64243319083470851183"

Or we could use a Kahan sum for things like this in general. It’s easy to implement and there are some modern versions.

  1. I should add timings to that C++ code so we know what the performance implications of making a switch are.

  2. This should be easier: I was going to factor that test code out into a C++ template function so it’s easy to check any given function and produce the required output. Most of that test code is boilerplate.

Here’s the C++ test code:

#include <iomanip>
#include <iostream>
#include <limits>
#include <stan/math/prim/scal.hpp>
#include <stan/math/prim/scal/fun/trigamma.hpp>
#include <boost/math/special_functions/trigamma.hpp>

int main(int argc, char** argv) {
  double start = 0;
  double stop = 1e6;  // or std::numeric_limits<double>::max() for the full range
  int n_steps = 1e9;
  double step = (stop - start) / double(n_steps);
  double x = start;
  double stan_trg;
  double boost_trg;
  double stan_boost_trg;
  std::cout << "x, stan::math::trigamma(x), boost::math::trigamma(x), difference::abs";
  std::cout << ", difference::rel \n";
  while (x >= start && x < stop) {
    try {
      stan_trg = stan::math::trigamma_impl(x);
    } catch (const std::exception&) {
      stan_trg = std::numeric_limits<double>::quiet_NaN();
    }
    try {
      boost_trg = boost::math::trigamma(x);
    } catch (const std::exception&) {
      boost_trg = std::numeric_limits<double>::quiet_NaN();
    }
    try {
      stan_boost_trg = stan::math::trigamma_impl(x) - boost::math::trigamma(x);
    } catch (const std::exception&) {
      stan_boost_trg = std::numeric_limits<double>::quiet_NaN();
    }
    std::cout << std::setprecision(std::numeric_limits<double>::digits10) << x << ", ";
    std::cout << std::setprecision(std::numeric_limits<double>::digits10) << stan_trg << ", ";
    std::cout << std::setprecision(std::numeric_limits<double>::digits10) << boost_trg << ", ";
    std::cout << std::setprecision(std::numeric_limits<double>::digits10) << stan_boost_trg << ", ";
    std::cout << std::setprecision(std::numeric_limits<double>::digits10) << stan_boost_trg / boost_trg << "\n";
    x = x + step;
  }
  return 0;
}
Compiles on ‘develop’ with:

g++ -I . -I ./lib/eigen_3.3.3 -I ./lib/boost_1.69.0 -o switch-test ./switch-test.cpp

The .csv (as written; easy to modify in the C++ code) is 80 GB of text, and you can load sections in R (‘skip’ controls how many rows are skipped, ‘nrows’ controls how many rows to load) with:

d = read.csv('test.csv', header = TRUE, colClasses = 'numeric', check.names = FALSE, nrows = 10, skip = 10)

I’m going to be out for a while but I’ll see if I can’t write some more sensible code this afternoon for including timing.

In terms of what evidence I think we would want: a multi-function comparison that shows that whatever version we pick has low error and doesn’t tank performance.

@wds15 can you easily come up with a Stan line that kicks off this calculation in derivatives? I would throw together a test model that focuses on a trigamma region with problems to see if it matters.

@bbbales2: where in trigamma did you see problems?


Arb, which uses ball arithmetic.

Unless you say that boost is terrible, we should just switch. The Boost Math folks actively maintain this function, so it’s better to let them maintain and test it. If it’s bad, then we should file an issue with them. The future maintenance burden is with them, which is great for us. That’s why I would switch without thinking at this point (though one more look at performance would be good).

That’s the one, thanks!

So boost is definitely slower for small arguments (>2x slower), but Stan’s absolute error vs. boost steps up from tiny to >1.5e-8 at x <= 0.0001, and Stan’s relative error therefore blows up.

timing fig: boost-in-red.pdf (12.3 KB)

absolute error fig: stan-trigamma-error-step.pdf (4.3 KB)

Since <1e-4 is a region that’s probably going to get hit when this is used in a sampler, and the error really changes dramatically there (and it’s no question that Stan has the problem, not boost), it makes sense to switch over.

Did you avoid the propagation from boost which converts double arguments to long double internally?

A 2x slowdown is what I saw for lgamma as well when using the default promotion. You can refer to the lgamma branch I have to see how to override it.

(Thanks for all the effort)

EDIT: I just looked into the performance of this myself, and this is what I get for 10^5 calls:

BM_old_boost_median     198590 ns     198328 ns       3596 ## boost with defaults
BM_boost_median         109344 ns     109096 ns       6088 ## boost without overpropagation
BM_stan_median           62312 ns      62208 ns      11497 ## stan

So maybe we switch as you suggest, but file an issue with boost math where we point out that some argument ranges of the function can obviously be computed faster. Then they can consider using that for a sub-range (they do that sort of thing).

Yech, I had calculated the relative error backwards; I did it as (f(x) - g(x))/g(x) this time.

So the error for us spikes (to 1.5e-8) at 1e-4 and decays as you go to zero (that makes sense now). There are some spikes up to 6e-9 through x = 6, and I haven’t checked the later values yet.

I get slightly different values for the timings (this is boost/stan in black and boost (with the no-promote policy)/stan in red): no-promote-in-red

So I guess the mean values are not that meaningful. This is in the 0-10 range, I know I should run another test for higher values at some point.

The problem was that second and third derivatives of trigamma computed using finite differences were low precision vs. our autodiff.

Is it templated enough to autodiff through? It’s a one-argument function, so it should be.

This raises another point: a stan::math::var is not considered a Boost Real and thus does not work well with all Boost Math functions. I think it would need some more things defined.

Boost math includes the polygamma function, which is the nth derivative of the digamma (i.e., the (n+1)th derivative of the log-gamma). So we can just use that.

We should probably test the boost math trigamma with your new test framework to see if it is robust under this testing procedure, which showed non-ideal behaviour for ours. Let me see if I can figure that out.

We should make sure var and fvar&lt;T&gt; work with that concept if T works with that concept. But I think our value for fvar&lt;T&gt; is off, as it returns something of type T, not of double. I’m not sure we can make that work.

I don’t see anything else there that’s not implemented. Did you see something we’re missing, or do you know how to test whether we’ve got it all (or as much of it as we can)?