The closures design doc was accepted in December, but there’s been no activity since then. I think it’s time to do this.
Tagging everyone who participated in the design-doc discussion:
@Bob_Carpenter, @seantalts, @wds15, @Matthijs, @bbbales2
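Unfortunately, the design doc has no plan for how to implement it. The implementation section is a single sentence about C++ lambdas, and that approach will not work: lambdas cannot autodiff correctly in higher-order functions, because the higher-order function cannot see the vars a lambda captures and so cannot zero, copy, or collect gradients for them the way it does for its explicit arguments.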
Probably the easiest C++ implementation is to convert captured variables into extra parameters (“lambda-lifting”) and piggyback on the variadic arguments proposal. I think I could make this work by the time Stan math supports variadic arguments.
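A minimal sketch of lambda-lifting, assuming plain `double` in place of Stan's `var` so it compiles without Stan math (C++14 or later): the body of the closure `foo` from the example further down becomes a function whose captured variables `c`, `m`, `s` are trailing parameters, and a variadic higher-order function forwards them without inspecting them. `foo_lifted`, `foo_lifted_functor`, and `hof` are hypothetical names.

```cpp
#include <cassert>

// Closure body foo(x) = -(x - m)^2 / s + c after lambda-lifting:
// the captured variables c, m, s are now ordinary trailing parameters.
template <typename T, typename C, typename M, typename S>
auto foo_lifted(const T& x, const C& c, const M& m, const S& s) {
  return -(x - m) * (x - m) / s + c;
}

// Functor wrapper so the function template can be passed as a value,
// in the same style as the stanc-generated functor structs.
struct foo_lifted_functor {
  template <typename... Ts>
  auto operator()(const Ts&... ts) const { return foo_lifted(ts...); }
};

// Hypothetical variadic higher-order function: it never looks inside
// args..., it only forwards them to f after its own argument x.
template <typename F, typename... Args>
double hof(const F& f, double x, const Args&... args) {
  return f(x, args...);
}
```

Once the captures are ordinary arguments, the higher-order function receives them in the same parameter pack as any other variadic argument, which is what makes piggybacking on that proposal attractive.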
But let’s consider supporting proper closures directly in the math library.
Currently the higher-order functions take a “functor” struct.
```cpp
struct foo_functor__ {
  template <typename T0__>
  typename boost::math::tools::promote_args<T0__>::type
  operator()(const T0__& x, std::ostream* pstream__) const {
    return foo(x, pstream__);
  }
};
```
The functor struct contains no data and is default-constructible. ODE solvers store the struct and call it with either var or double arguments depending on where they need autodiff.
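The property that matters for closures is that the higher-order function can build the functor itself and call it with whichever scalar type it needs. A minimal sketch, assuming a toy forward-mode `Dual` type as a stand-in for `var` (Stan math uses reverse mode; the swap only keeps the example self-contained); `Dual`, `square_plus_one_functor`, `value_and_slope`, and `eval_with_slope` are hypothetical names.

```cpp
#include <cassert>

// Stand-in for an autodiff scalar: forward-mode dual number holding a
// value and its derivative. Not Stan's var; used only for illustration.
struct Dual {
  double val;
  double der;
};
inline Dual operator*(Dual a, Dual b) {
  return {a.val * b.val, a.der * b.val + a.val * b.der};  // product rule
}
inline Dual operator+(Dual a, double b) { return {a.val + b, a.der}; }

// Stateless, default-constructible functor in the style of foo_functor__.
struct square_plus_one_functor {
  template <typename T>
  T operator()(const T& x) const { return x * x + 1.0; }
};

struct value_and_slope {
  double value;
  double slope;
};

// Hypothetical higher-order function: it constructs F itself (possible only
// because F holds no data) and evaluates it with double where no derivative
// is needed and with Dual where one is.
template <typename F>
value_and_slope eval_with_slope(double x) {
  F f{};                         // default-constructed, no captured state
  double value = f(x);           // double pass, no autodiff overhead
  Dual d = f(Dual{x, 1.0});      // autodiff pass, seeded with dx/dx = 1
  return {value, d.der};
}
```

The `F f{}` line is exactly the step a closure breaks: a closure carries captured data, so the higher-order function has to be handed an instance instead of creating one.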
The proposed variadic arguments are implemented by passing along an additional opaque tuple.
The parameter pack tuple is accessed with helper functions:

- `count_vars(...)` – total number of autodiffable vars
- `value_of(...)` – create an identical double-only tuple (same data, no autodiff)
- `accumulate_adjoints(double*, ...)` – add the adjoints into a gradient array
- `save_varis(vari**, ...)` – copy the vari pointers into an array
- `deep_copy_vars(...)` – allocate new independent varis for nested autodiff
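The helpers above can follow a simple pattern: one overload per argument type, folded over the tuple with `std::apply`. This is a sketch of that pattern, not the Stan math implementation; it assumes C++17, uses hypothetical `ToyVari`/`ToyVar` stand-ins for `vari`/`var`, and shows only three of the five helpers.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <vector>

// Toy stand-ins for stan::math::vari and var (value + adjoint behind a pointer).
struct ToyVari { double val; double adj; };
struct ToyVar { ToyVari* vi; };

// Per-argument rules. Plain doubles are data and hold no vars.
inline std::size_t count_one(double) { return 0; }
inline std::size_t count_one(const ToyVar&) { return 1; }
inline std::size_t count_one(const std::vector<ToyVar>& v) { return v.size(); }

inline double value_one(double x) { return x; }
inline double value_one(const ToyVar& x) { return x.vi->val; }
inline std::vector<double> value_one(const std::vector<ToyVar>& v) {
  std::vector<double> out;
  out.reserve(v.size());
  for (const ToyVar& x : v) out.push_back(x.vi->val);
  return out;
}

// Each overload adds its adjoints starting at g and returns the next free slot.
inline double* accumulate_one(double* g, double) { return g; }
inline double* accumulate_one(double* g, const ToyVar& x) {
  *g += x.vi->adj;
  return g + 1;
}
inline double* accumulate_one(double* g, const std::vector<ToyVar>& v) {
  for (const ToyVar& x : v) *g++ += x.vi->adj;
  return g;
}

// Tuple-level helpers: apply the per-argument rule to every element in order.
template <typename... Ts>
std::size_t count_vars(const std::tuple<Ts...>& args) {
  return std::apply(
      [](const auto&... a) { return (std::size_t{0} + ... + count_one(a)); },
      args);
}

template <typename... Ts>
auto value_of(const std::tuple<Ts...>& args) {
  return std::apply(
      [](const auto&... a) { return std::make_tuple(value_one(a)...); }, args);
}

template <typename... Ts>
void accumulate_adjoints(double* g, const std::tuple<Ts...>& args) {
  // comma fold evaluates left to right, so slots follow tuple order
  std::apply([&g](const auto&... a) { ((g = accumulate_one(g, a)), ...); },
             args);
}
```

`save_varis` would follow the same shape as `accumulate_adjoints`, writing pointers instead of adding adjoints, while `deep_copy_vars` additionally needs an allocator for the new varis.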
A “closure” object would combine the callable functor and the parameter pack.
The closure then exposes an API that should be equivalent to the above.
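One generic way to picture such an object: a class that stores a stateless functor next to a tuple of captured values, appends the captures to every call (the lambda-lifting convention), and answers API queries by folding over the tuple. This is an illustration under assumptions, not a proposed implementation: it assumes C++17, `closure`, `foo_body`, `ToyVari`, and `ToyVar` are hypothetical, only `count_vars` is shown, and the body works on values only because the toy var type has no arithmetic.

```cpp
#include <cassert>
#include <cstddef>
#include <tuple>
#include <utility>

// Toy stand-ins for stan::math::vari and var.
struct ToyVari { double val; double adj; };
struct ToyVar { ToyVari* vi; };

inline std::size_t count_one(double) { return 0; }
inline std::size_t count_one(const ToyVar&) { return 1; }

inline double val(double x) { return x; }
inline double val(const ToyVar& x) { return x.vi->val; }

// Lambda-lifted body of foo(x) = -(x - m)^2 / s + c (values only).
struct foo_body {
  template <typename C, typename M, typename S>
  double operator()(double x, const C& c, const M& m, const S& s) const {
    double d = x - val(m);
    return -d * d / val(s) + val(c);
  }
};

// Hypothetical closure: functor F plus captured values stored in a tuple.
template <typename F, typename... Captured>
class closure {
  F f_;
  std::tuple<Captured...> captured_;

 public:
  explicit closure(F f, Captured... captured)
      : f_(f), captured_(std::move(captured)...) {}

  // The higher-order function's own arguments come first, captures after.
  template <typename... Args>
  auto operator()(const Args&... args) const {
    return std::apply(
        [&](const auto&... cap) { return f_(args..., cap...); }, captured_);
  }

  // Closure API query: number of autodiff variables among the captures.
  std::size_t count_vars() const {
    return std::apply(
        [](const auto&... cap) { return (std::size_t{0} + ... + count_one(cap)); },
        captured_);
  }
};
```

The appeal of this shape is that the compiler would only need to emit the body functor and the list of captures; the rest of the API could be written once for all closures.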
Here’s what I think it would look like.
The user defines a closure and passes it to a higher-order function `hof`:

```stan
data {
  real c;
  ...
}
parameters {
  real m;
  real s;
  ...
}
model {
  real foo(real x) {
    return -(x - m)^2 / s + c;
  }
  ... = hof(foo, ...);
}
```
The compiler would translate it to C++:

```cpp
foo_fn__ foo(c, m, s);
... = hof(foo, ...);
```
where the class is:

```cpp
class foo_fn__ {
  double& c;
  var& m;
  var& s;

 public:
  int num_vars__;  // number of captured vars (m and s; c is data)

  foo_fn__(double& c_, var& m_, var& s_)
      : c(c_), m(m_), s(s_), num_vars__(2) { }

  // Methods are const so they can be called through the const F& that the
  // higher-order functions take; constness does not propagate through the
  // reference members.
  template<typename T>
  var operator()(const T& x) const {
    return -square(x-m)/s + c;
  }

  // double interface
  class foo_fn_dbl__ {
    double& c;
    double m;
    double s;

   public:
    foo_fn_dbl__(double& c_, double m_, double s_)
        : c(c_), m(m_), s(s_) { }

    template<typename T>
    double operator()(const T& x) const {
      return -square(x-m)/s + c;
    }
  };

  foo_fn_dbl__ value_of_() const {
    return foo_fn_dbl__(c, value_of(m), value_of(s));
  }

  // vari interface
  void set_zero_adjoints() const {
    m.vi_->set_zero_adjoint();
    s.vi_->set_zero_adjoint();
  }

  void accumulate_adjoints(double* gradients) const {
    gradients[0] += m.adj();
    gradients[1] += s.adj();
  }

  void save_varis(vari** varis) const {
    varis[0] = m.vi_;
    varis[1] = s.vi_;
  }

  foo_fn__ deep_copy_vars() const {
    // The captured vars are references, so the copies can't be stored in
    // the struct itself; place them somewhere that outlives it, e.g. on
    // the nested AD stack.
    var* m_ = ChainableStack::instance_->memalloc_.alloc_array<var>(1);
    *m_ = var(new vari(value_of(m), false));
    var* s_ = ChainableStack::instance_->memalloc_.alloc_array<var>(1);
    *s_ = var(new vari(value_of(s), false));
    return foo_fn__(c, *m_, *s_);
  }
};
```
Additionally, MPI `map_rect` requires some way to serialize the object, but other than that, is this API sufficient for all of our higher-order functions?
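As a sketch of what serialization might involve, assuming the double-only side is what gets shipped: the closure appends its captured values to a flat `std::vector<double>` in a fixed order, and a worker rebuilds an equivalent closure from the same buffer. `foo_fn_dbl`, `serialize`, and `deserialize` are hypothetical names; the sketch ignores integer data and the var side, which would presumably send values out and gather gradients back.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Double-only closure for foo(x) = -(x - m)^2 / s + c. Unlike foo_fn_dbl__,
// it holds c by value, since a reference cannot cross a process boundary.
class foo_fn_dbl {
  double c_, m_, s_;

 public:
  foo_fn_dbl(double c, double m, double s) : c_(c), m_(m), s_(s) {}

  double operator()(double x) const {
    double d = x - m_;
    return -d * d / s_ + c_;
  }

  // Append captured values in a fixed order (data first, then parameters).
  void serialize(std::vector<double>& buf) const {
    buf.push_back(c_);
    buf.push_back(m_);
    buf.push_back(s_);
  }

  // Rebuild from buf starting at pos; advances pos past the values read.
  static foo_fn_dbl deserialize(const std::vector<double>& buf,
                                std::size_t& pos) {
    assert(pos + 3 <= buf.size());  // buffer must hold c, m, s
    foo_fn_dbl f(buf[pos], buf[pos + 1], buf[pos + 2]);
    pos += 3;
    return f;
  }
};
```

Returning the read position lets several closures, or a closure plus other `map_rect` arguments, share one buffer.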