Wrap boost::math::chebyshev_transform into stan::math::var?

Edit: I clearly need more sleep. Obviously, zstr.adj() is not going to give me what I want: calling zstr.grad() seeds the output’s own adjoint with 1. I’m supposed to read r.adj() instead. Sorry, everyone!

Hi! I am approximating a univariate function (not available in closed form) with Boost’s chebyshev_transform. The class’s operator() returns the approximate value of the function for a given input, and its prime method returns the first derivative of the approximation. I’d like to use that approximation with the Stan math library via Rcpp (StanHeaders). It ought to be trivial to write a function that returns a var (for reverse-mode AD) whose value comes from operator() and whose gradient comes from prime. However, the documentation on creating custom functions is a little beyond me.
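For readers without Boost at hand, here is a minimal standard-library sketch of the idea behind an approximation object that exposes both a value and a first derivative. The class name ChebSketch, its default of 24 nodes, and its internals are assumptions made for illustration; it is not Boost's implementation, and only the operator() / prime shape mirrors the class described above.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for an approximation class with operator() and prime().
class ChebSketch {
 public:
  ChebSketch(const std::function<double(double)>& f, double a, double b,
             std::size_t n = 24)
      : a_(a), b_(b), c_(n, 0.0) {
    assert(b > a && n > 0);
    const double pi = std::acos(-1.0);
    // Sample f at the n Chebyshev nodes, mapped from [-1, 1] to [a, b].
    std::vector<double> fx(n);
    for (std::size_t k = 0; k < n; ++k) {
      const double x = std::cos(pi * (k + 0.5) / n);
      fx[k] = f(0.5 * (b - a) * x + 0.5 * (a + b));
    }
    // Discrete cosine sums give the Chebyshev coefficients c_j.
    for (std::size_t j = 0; j < n; ++j) {
      double sum = 0.0;
      for (std::size_t k = 0; k < n; ++k) {
        sum += fx[k] * std::cos(pi * j * (k + 0.5) / n);
      }
      c_[j] = 2.0 * sum / n;
    }
  }

  // Value: c_0 / 2 + sum_{j >= 1} c_j T_j(x), using the T_j recurrence.
  double operator()(double t) const {
    const double x = to_unit(t);
    double t_prev = 1.0, t_curr = x;
    double sum = 0.5 * c_[0];
    for (std::size_t j = 1; j < c_.size(); ++j) {
      sum += c_[j] * t_curr;
      const double t_next = 2.0 * x * t_curr - t_prev;
      t_prev = t_curr;
      t_curr = t_next;
    }
    return sum;
  }

  // First derivative: T_j'(x) = j * U_{j-1}(x), scaled by dx/dt = 2 / (b - a).
  double prime(double t) const {
    const double x = to_unit(t);
    double u_prev = 0.0, u_curr = 1.0;  // U_{-1} and U_0
    double sum = 0.0;
    for (std::size_t j = 1; j < c_.size(); ++j) {
      sum += c_[j] * static_cast<double>(j) * u_curr;
      const double u_next = 2.0 * x * u_curr - u_prev;
      u_prev = u_curr;
      u_curr = u_next;
    }
    return sum * 2.0 / (b_ - a_);
  }

 private:
  double to_unit(double t) const { return (2.0 * t - a_ - b_) / (b_ - a_); }
  double a_, b_;
  std::vector<double> c_;
};
```

Constructing ChebSketch from std::sin on [-1.05, 1.05] and calling it at 0.5 should reproduce sin(0.5), with prime(0.5) close to cos(0.5).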

I know that this is a lot of code, but below is what I’ve tried:

#include <cmath>   // std::pow, std::tanh
#include <cstdint> // std::uintmax_t
#include <limits>
#include <tuple>
#include <vector>
#include <boost/math/tools/roots.hpp>
#include <boost/math/special_functions/chebyshev_transform.hpp>
#include <stan/math.hpp>
#include <RcppEigen.h>

// The function I want to approximate is dblZstar, which returns the root of dblFd1.
// Sorry, lots of code to set up this function and approximation.
namespace zstar
{
  using boost::math::chebyshev_transform;

  inline double dblFd1(double rho, double x)
  {
    using std::pow;
    using std::tanh;
    using stan::math::pi;

    double result;
    result = -1.0/2.0*pi()*rho;
    result -= 1.0/4.0*pow(pi(), 2)*x*pow(tanh((1.0/2.0)*pi()*x), 2);
    result += (1.0/4.0)*pow(pi(), 2)*x;
    result += (1.0/2.0)*pi()*tanh((1.0/2.0)*pi()*x);

    return result;
  }

  inline double dblFd2(double x) // 1st derivative of dblFd1
  {
    using std::pow;
    using std::tanh;
    using stan::math::pi;

    double result;
    result = (1.0/4.0)*pow(pi(), 3)*x*pow(tanh((1.0/2.0)*pi()*x), 3);
    result -= 1.0/4.0*pow(pi(), 3)*x*tanh((1.0/2.0)*pi()*x);
    result -= 1.0/2.0*pow(pi(), 2)*pow(tanh((1.0/2.0)*pi()*x), 2);
    result += (1.0/2.0)*pow(pi(), 2);

    return result;
  }

  inline double dblFd3(double x) // 2nd derivative of dblFd1
  {
    using std::pow;
    using std::tanh;
    using stan::math::pi;

    double result;
    result = -3.0/8.0*pow(pi(), 4)*x*pow(tanh((1.0/2.0)*pi()*x), 4);
    result += (1.0/2.0)*pow(pi(), 4)*x*pow(tanh((1.0/2.0)*pi()*x), 2);
    result -= 1.0/8.0*pow(pi(), 4)*x;
    result += (3.0/4.0)*pow(pi(), 3)*pow(tanh((1.0/2.0)*pi()*x), 3);
    result -= 3.0/4.0*pow(pi(), 3)*tanh((1.0/2.0)*pi()*x);
    return result;
  }

  double dblZstar(double rho) // Function that we will approximate with Chebyshev
  {
    using std::tuple;
    using std::make_tuple;
    using boost::math::tools::schroder_iterate;

    double guess = 0.0;
    double min = -25.0;
    double max = 25.0;

    const int digits = std::numeric_limits<double>::digits;
    int get_digits = static_cast<int>(digits * 0.4);

    std::uintmax_t maxit = 20;

    double result = schroder_iterate(
      [rho](const double &x) -> tuple<double, double, double> {
        double d1 = dblFd1(rho, x);
        double d2 = dblFd2(x);
        double d3 = dblFd3(x);
        return make_tuple(d1, d2, d3);
      }, guess, min, max, get_digits, maxit);

    return result;
  }

  // Make the approximation
  const chebyshev_transform<double> CHEB(dblZstar, -1.05, 1.05);
}

// I tried different things, but the last thing I tried was precomputed_gradients
namespace stan
{
  namespace math
  {
    double genZstar(double rho)
    {
      return zstar::CHEB(rho);
    }

    var genZstar(const var &rho)
    {
      using std::vector;

      double r = rho.val();
      double fr(zstar::CHEB(r));
      double df_dr = zstar::CHEB.prime(r);

      vector<var> vars(1);
      vars[0] = rho;

      vector<double> gradients(1);
      gradients[0] = df_dr;

      return precomputed_gradients(fr, vars, gradients);
    }
  }
}

// [[Rcpp::export]]
Rcpp::List lstZstar(double rho)
{
  using Rcpp::List;
  using Rcpp::Named;
  using stan::math::var;

  var r(rho);
  var zstr = stan::math::genZstar(r);
  zstr.grad();

  const double zstr_val = zstr.val();
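  // Note: zstr.adj() is the adjoint of the output itself, which grad() seeds
  // with 1. The derivative with respect to rho accumulates in r.adj().
  const double zstr_adj = zstr.adj();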

  stan::math::recover_memory();

  return List::create(
    Named("zstar") = zstr_val,
    Named("dzstar") = zstr_adj,
    Named("dzstar_") = zstar::CHEB.prime(rho));
}

In R, executing lstZstar(0.5) returns the correct value for zstar. However, irrespective of the value of rho, the derivative reported as dzstar always equals 1.
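The behaviour described here is what any reverse-mode sweep produces when the adjoint is read from the output instead of the input. A minimal sketch, using a hypothetical toy tape rather than Stan's actual implementation (Tape, add_unary, and toy_gradient are names invented for this illustration):

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical toy tape; it only mimics the shape of reverse-mode AD.
struct Tape {
  std::vector<double> adj;                       // one adjoint per node
  std::vector<std::function<void(Tape&)>> back;  // one backward step per node

  // Adds an independent variable (no backward work to do).
  std::size_t add_leaf() {
    adj.push_back(0.0);
    back.push_back([](Tape&) {});
    return adj.size() - 1;
  }

  // Adds y = f(x), where the caller supplies dy/dx as a plain double.
  std::size_t add_unary(std::size_t x, double dydx) {
    const std::size_t y = adj.size();
    adj.push_back(0.0);
    back.push_back([x, y, dydx](Tape& t) { t.adj[x] += t.adj[y] * dydx; });
    return y;
  }

  // Reverse sweep: seed the output with dy/dy = 1, then walk backwards.
  void grad(std::size_t y) {
    assert(y < adj.size());
    adj[y] = 1.0;
    for (std::size_t i = y + 1; i-- > 0;) back[i](*this);
  }
};

struct ToyResult {
  double output_adj;  // always 1 after grad()
  double input_adj;   // holds dy/dx
};

ToyResult toy_gradient(double dydx) {
  Tape tape;
  const std::size_t x = tape.add_leaf();
  const std::size_t y = tape.add_unary(x, dydx);
  tape.grad(y);
  return {tape.adj[y], tape.adj[x]};
}
```

Whatever slope is passed to toy_gradient, output_adj stays at 1 because it is only the seed of the sweep; the slope itself ends up in input_adj.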

Any help whatsoever would be much appreciated.

Adding a new function is beyond most of us. The main difficulty is making sure that every function called by the function you are adding has appropriate derivatives. When they do, you can usually just autodiff through the Boost implementation, which is the simplest route.

If you want to add a new function yourself and you have a double-based implementation of the partial derivatives of the outputs w.r.t. the inputs, you need to get it into the math library and implement both the function and its gradients. Usually it’s more efficient to implement an adjoint-Jacobian product version so that you don’t need to explicitly calculate a Jacobian and multiply it.
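To make the adjoint-Jacobian point concrete, here is a small standard-library sketch, independent of Stan. It uses a cumulative sum, y_i = x_0 + ... + x_i, as an assumed example function; the function names are invented for illustration. The first routine builds the full Jacobian and multiplies, the second computes the same product directly. For a univariate function such as the one in this thread, the Jacobian is 1 x 1 and the product collapses to the output adjoint times the derivative.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Explicit route: build the n x n Jacobian of y_i = x_0 + ... + x_i,
// then form x_adj = J^T * y_adj. Costs O(n^2) time and memory.
std::vector<double> adjoint_via_jacobian(const std::vector<double>& y_adj) {
  const std::size_t n = y_adj.size();
  std::vector<std::vector<double>> jac(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j <= i; ++j) jac[i][j] = 1.0;  // dy_i / dx_j
  }
  std::vector<double> x_adj(n, 0.0);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) x_adj[j] += y_adj[i] * jac[i][j];
  }
  return x_adj;
}

// Direct route: x_adj_j = sum over i >= j of y_adj_i, i.e. a reverse
// cumulative sum. Costs O(n) and never materialises the Jacobian.
std::vector<double> adjoint_direct(const std::vector<double>& y_adj) {
  std::vector<double> x_adj(y_adj.size(), 0.0);
  double running = 0.0;
  for (std::size_t j = y_adj.size(); j-- > 0;) {
    running += y_adj[j];
    x_adj[j] = running;
  }
  return x_adj;
}

// Convenience check that both routes give the same input adjoints.
void check_routes_agree(const std::vector<double>& y_adj) {
  assert(adjoint_via_jacobian(y_adj) == adjoint_direct(y_adj));
}
```

Both routines return the same input adjoints; the direct version simply skips storing and multiplying a mostly redundant matrix.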

There are lots of examples in the math lib. Then you need to add the signature to the language either in the OCaml parser or by declaring the signature in the functions block of Stan.

@stevebronder may know if there’s a working example of this somewhere. That’d probably be easiest. If not, Steve’s probably the best person to help with this.

Thank you, @Bob_Carpenter. I feel much better knowing that this is beyond most people.

In the end, I did the following:

// Approximate zstar
const chebyshev_transform<double> approxZstar(dblZstar, -1.05, 1.05);

inline double genApproxZstar(double rho)
{
  return approxZstar(rho);
}

// Get Stan to use the chebyshev_transform class's prime method to
// calculate the derivative.
inline stan::math::var genApproxZstar(const stan::math::var &rho)
{
  return stan::math::make_callback_var(
    approxZstar(rho.val()), [rho](const auto &vi) mutable {
      rho.adj() += vi.adj_ * approxZstar.prime(rho.val());
    });
}

My R package uses the Stan Math C++ library directly via Rcpp, so I don’t think I need to incorporate this into Stan itself. The above code seems to work for my purposes. I am a little concerned about the potential for memory leaks since I didn’t bother with arena memory. I haven’t, however, picked up any problems yet.
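One inexpensive way to back up "seems to work" is to compare the analytic derivative against a central finite difference of the value function. The sketch below is standard-library only; central_difference and derivative_matches are hypothetical helper names, and the step size and tolerance are assumptions that may need tuning for other functions.

```cpp
#include <cassert>
#include <cmath>
#include <functional>

// Central finite difference of f at x with step h (assumed default 1e-5).
double central_difference(const std::function<double(double)>& f, double x,
                          double h = 1e-5) {
  assert(h > 0.0);
  return (f(x + h) - f(x - h)) / (2.0 * h);
}

// True when the claimed derivative df agrees with the finite difference of f
// at x, up to a tolerance scaled by the size of the derivative.
bool derivative_matches(const std::function<double(double)>& f,
                        const std::function<double(double)>& df, double x,
                        double tol = 1e-6) {
  const double numeric = central_difference(f, x);
  const double analytic = df(x);
  return std::fabs(numeric - analytic) <= tol * (1.0 + std::fabs(analytic));
}
```

In this thread's setting, f would wrap the double overload of genApproxZstar and df would return the adjoint read back from the input var after grad(). A derivative that is stuck at 1, as in the original question, would fail such a check for almost every input.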