# Precomputed Jacobians or vector-Jacobian products?

I had previously discovered `precomputed_gradients` for a function such as `real foo(vector theta);`, and now I am looking to generalize to more complex arguments such as `vector theta, real a, matrix b`. If I understand the corresponding test correctly, `precomputed_gradients(val, _, _, ops, grads)` can be used for that. However, what should I do if the output is not scalar, e.g. `vector foo(vector theta)` or `matrix bar(vector theta)`? For instance, the inner time-stepping loop of an SDE model calls a function such as

``````
vector diff_coupling(matrix weights, vector x) {
  vector[rows(x)] out = rep_vector(0, rows(x));
  for (i in 1:rows(x))
    out[i] = weights[, i]' * (x - x[i]);
  return out;
}
``````

which I guess generates at least `rows(x)` AD graph nodes, while the vector-Jacobian product for the `vector x` operand is just a vector-matrix multiply. From reading the Stan Math paper, it seems the only choice is then to implement a `vari.chain()` with the vector-Jacobian product, and then instantiate it in a function in the model header. Are there any shortcuts like `precomputed_gradients` for this, or is that how it works?

A related, less concrete question is whether it is also worth hand-coding these partials for time-stepping models, to improve performance (especially when there are many time steps). Bugs aside, I would assume so, but then I went looking at the source for `gaussian_dlm_obs_log` and saw it's using AD as well. My guess is this is OK because the time spent in Eigen vector and matrix arithmetic dominates the pointer-chasing time in AD (or the C++ compiler optimizes all that away?), but it would be nice to know if this guess is off, and how to tell when it's worth reimplementing in C++ with AD versus reimplementing in C++ with the partials done by hand. Thanks in advance for any advice.


There are better alternatives for that. Look for the reverse pass callback facility in Stan Math. The `softmax` function uses it, so head there. This is very new stuff, and others may give you more help getting started.


I agree with @wds15. Use `reverse_pass_callback`. There is a blog post by @stevebronder about this: https://blog.mc-stan.org/2020/11/23/thinking-about-automatic-differentiation-in-fun-new-ways/

Take a look at `softmax` or `inverse` for functions with one argument; in your case you will have more than one argument, so maybe take a look at `multiply`.

For Stan Math functions with more arguments we need to cover all cases [(arg1 data, arg2 var), (arg1 var, arg2 data), (both var)]. If you are writing a custom C++ function, you have a pretty good idea which arguments are going to be var, so there is no need to cover the other cases as well.


@maedoc yeah, the doc isn't ready for this stuff yet. You'll want to use the develop version of Math if you're going to work on this stuff, cause it changes rapidly (for the better).

The API doc is here and it's updated to develop I believe: https://mc-stan.org/math/. So any time you're curious about a weird-looking function, check there. That'll get you the doxygen in the code.

Hereâs something like your function. I didnât check it or finish it, but you can start here:

``````
template <typename T1, typename T2,
          require_matrix_t<T1>* = nullptr,
          require_col_vector_t<T2>* = nullptr,
          require_any_st_var<T1, T2>* = nullptr>
auto myfunc(const T1& m, const T2& v) {
  // pick the right autodiff matrix return type from the input types
  using ret_type = return_var_matrix_t<Eigen::VectorXd, T1, T2>;
  if (!is_constant<T1>::value && !is_constant<T2>::value) {
    // both arguments are autodiff types
    arena_t<promote_scalar_t<var, T1>> arena_m = m;
    arena_t<promote_scalar_t<var, T2>> arena_v = v;
    arena_t<ret_type> ret = arena_m.val() * arena_v.val()
        - (arena_v.val().array() * arena_m.val().rowwise().sum().array()).matrix();
    reverse_pass_callback([ret, arena_m, arena_v]() mutable {
      // TODO: increment arena_m.adj() and arena_v.adj() from ret.adj() here
    });
    return ret_type(ret);
  } else if (!is_constant<T1>::value) {
    // only the matrix is an autodiff type
    arena_t<promote_scalar_t<var, T1>> arena_m = m;
    arena_t<promote_scalar_t<double, T2>> arena_v = v;
    arena_t<ret_type> ret = arena_m.val() * arena_v
        - (arena_v.array() * arena_m.val().rowwise().sum().array()).matrix();
    reverse_pass_callback([ret, arena_m, arena_v]() mutable {
      // TODO: increment arena_m.adj() from ret.adj() here
    });
    return ret_type(ret);
  } else {
    // only the vector is an autodiff type (both-double is excluded by the requires)
    arena_t<promote_scalar_t<double, T1>> arena_m = m;
    arena_t<promote_scalar_t<var, T2>> arena_v = v;
    arena_t<ret_type> ret = arena_m * arena_v.val()
        - (arena_v.val().array() * arena_m.rowwise().sum().array()).matrix();
    reverse_pass_callback([ret, arena_m, arena_v]() mutable {
      // TODO: increment arena_v.adj() from ret.adj() here
    });
    return ret_type(ret);
  }
}
``````
1. The first new thing is the requires. Instead of typing types explicitly we use C++ template SFINAE-sorta stuff to accept arguments. This is because for any argument there are a ton of different types that work. The first here says, "Make sure the first argument is a matrix". The second says, "Make sure the second is a column vector". The last says, "Make sure at least one is an autodiff type (has scalar type var)". Search for these in the upper right of https://mc-stan.org/math/

2. `return_var_matrix_t`: there are actually two types of autodiff matrices in Stan now (design doc for the second one here), and this picks the right return type given the types of the input autodiff variables.

3. There are a lot of `arena_t<T>` expressions in the code. `arena_t<T>` says "give me a variable equivalent to the type `T` that is stored in the autodiff arena". This is used to save variables in the forward pass that are needed in the reverse pass.

4. Because your function has two arguments and requires at least one argument to be an autodiff type, you need to handle the three combinations of double/autodiff types. That's what the `if/else` stuff with `is_constant` is handling (`is_constant<T1>::value == true` means the first argument is not an autodiff type).

5. `reverse_pass_callback` takes a lambda argument. This function will be called in the reverse pass. It's the equivalent of `chain`, but thanks to lambda captures it's way easier to write. Capture everything by copying; `arena_t` types are cheap to copy by design. The code inside this is responsible for incrementing the adjoints of the input variables (which you saved copies of in `arena_m` and `arena_v`). The `mutable` is required on the lambda so that you can write the adjoints (otherwise you'll get an error about the adjoints being `const`).

6. There are a lot of accessors `.val()`, `.val_op()`, `.adj()`, and `.adj_op()` that are Stan extensions to Eigen types (so you won't find them in the Eigen docs). `val()` and `adj()` give you the values and adjoints of a Stan matrix type. Sometimes with multiplications you'll get compile errors, and in that case use `.val_op()` and `.adj_op()` (I'm not sure of the difference, and hopefully this gets cleaned up cuz it's confusing).

7. How we test these things is the `expect_ad` function in test/unit/math/mix. The assumption of `expect_ad` is that you've written a version of your function that works with doubles that you trust. It compares finite-difference gradients to the ones computed with autodiff, up through third derivatives I think. The tests for `elt_multiply` (elementwise multiplication) are here; look at them as an example. `expect_ad_matvar` tests the new matrix autodiff type.

8. Hereâs a recently completed new function as an example of all this

I would start with a one-argument version of your function and work from there. Just set the vector or the matrix equal to a constant inside the function. The one-argument versions of this are much simpler. The multiple-argument thing adds some annoying complexity (that isn't so bad once you know what it is, but can be pretty annoying to work through).


Thanks for the replies! Marking @bbbales2's answer as the solution, as it anticipates several other questions I had in mind after having grepped around in the develop branch a bit. The blog post makes it look quite easy, but I see why the arena is important for container types. I will follow up with this function, work through some multi-argument cases, and see if additional confusions come up.


A couple years ago I did some investigation (no actual coding/benchmarking) into whether it could be worthwhile to hand-code the partials of state space models, and drew the conclusion that for time-invariant matrices there exist tricks that could potentially result in better performance (see e.g. https://doi.org/10.1007/s00180-012-0352-y). But for time-varying matrices (which is what I needed), AD is probably just as performant as hand-coded partials.


Oh yeah, be careful not to return an `arena_t` (this can lead to silent errors). We should add tests to `test_ad` to make sure that doesn't happen, but for now you need to check that manually.


It took me a while to try this out, but for a nonlinear state space model (174 ODEs with a Heun scheme), a C++ version of the ODEs + Heun scheme is 30-50% faster than plain Stan. In my case, though, I'm not building the whole matrices at every time step but doing VJPs through the ODEs themselves, which may be more compact memory-wise ¯\\_(ツ)_/¯
