I’ve run headfirst into the undocumented `require_any`/`all`/`not`s. I’m looking for a template parameter flag that says “nothing is a `var`”.
Context
I’m plugging along at https://github.com/stan-dev/math/issues/1787 and, in particular, trying to get a reverse-mode specialization for `add` (which is matrix addition).
My guiding light in this is the reverse-mode specialization for `multiply`, but it uses some templating tricks that are a touch beyond my understanding.
The relevant signature for the reverse-mode specialization of `multiply` is:
```cpp
template <typename Mat1, typename Mat2,
          require_all_eigen_t<Mat1, Mat2>* = nullptr,
          require_any_eigen_vt<is_var, Mat1, Mat2>* = nullptr,
          require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
inline auto multiply(const Mat1& A, const Mat2& B) { /* ... */ }
```
All well and good. (To be honest, I’m not sure why we’re checking whether something is a null pointer, but whatever, I’m happy to ignore that.) My one remaining question is that I have no idea what `require_not_eigen_row_and_col_t<Mat1, Mat2>` is checking.
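For reference, here is a minimal, self-contained sketch of the pattern with mock names throughout (the assumption being that Stan Math’s real aliases are built the same way). A `require_*_t` alias behaves like `std::enable_if_t<condition>`: it is `void` when the condition holds and ill-formed otherwise. So `require_..._t<...>* = nullptr` declares an unnamed non-type template parameter of type `void*` whose default is `nullptr`; nothing is compared against a null pointer at run time, and when the condition fails the overload simply drops out of overload resolution (SFINAE). The sketch also assumes that `require_not_eigen_row_and_col_t<Mat1, Mat2>` rules out a row vector times a column vector, which yields a scalar and would therefore get its own overload. `RowVec`, `ColVec`, `Mat`, `is_row_and_col`, and both `require_*` aliases are hypothetical.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical stand-ins for Eigen row vectors, column vectors, and matrices.
struct RowVec {};
struct ColVec {};
struct Mat {};

// Hypothetical trait: true only for a (row vector, column vector) pair.
template <typename A, typename B>
struct is_row_and_col
    : std::bool_constant<std::is_same_v<A, RowVec> &&
                         std::is_same_v<B, ColVec>> {};

// Mock "require" aliases: void when the condition holds, ill-formed otherwise.
template <typename A, typename B>
using require_row_and_col_t = std::enable_if_t<is_row_and_col<A, B>::value>;

template <typename A, typename B>
using require_not_row_and_col_t =
    std::enable_if_t<!is_row_and_col<A, B>::value>;

// `require_..._t<...>* = nullptr` declares an unnamed template parameter of
// type void*, defaulted to nullptr. Nothing is checked at run time: if the
// condition is false, substitution fails and this overload is discarded.
template <typename A, typename B, require_not_row_and_col_t<A, B>* = nullptr>
constexpr std::string_view multiply(const A&, const B&) {
  return "general product";
}

// Row vector times column vector is a scalar, so it gets its own overload.
template <typename A, typename B, require_row_and_col_t<A, B>* = nullptr>
constexpr std::string_view multiply(const A&, const B&) {
  return "dot product";
}

// Compile-time usage checks: each call matches exactly one overload.
static_assert(multiply(RowVec{}, ColVec{}) == "dot product");
static_assert(multiply(Mat{}, ColVec{}) == "general product");
static_assert(multiply(ColVec{}, RowVec{}) == "general product");
```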
Nevertheless, this works, and my reverse-mode tests all fly through.
The problem
The code for the case where nothing is a `var` is easy, but I’m struggling with the signature. Here the `multiply` version doesn’t help, because it’s specialized twice: once for doubles (to use OpenCL) and once for everything else. I would like to collapse this down to a single overload (there’s no point using OpenCL for a matrix add).
The two signatures in the `multiply` file are:
```cpp
template <typename Mat1, typename Mat2,
          require_all_eigen_t<Mat1, Mat2>* = nullptr,
          require_all_same_t<double, value_type_t<Mat1>,
                             value_type_t<Mat2>>* = nullptr,
          require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
inline auto multiply(const Mat1& m1, const Mat2& m2) { /* ... */ }
```
and
```cpp
template <typename Mat1, typename Mat2,
          require_all_eigen_vt<std::is_arithmetic, Mat1, Mat2>* = nullptr,
          require_any_not_same_t<double, value_type_t<Mat1>,
                                 value_type_t<Mat2>>* = nullptr,
          require_not_eigen_row_and_col_t<Mat1, Mat2>* = nullptr>
inline auto multiply(const Mat1& m1, const Mat2& m2) { /* ... */ }
```
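These two `multiply` overloads split the non-`var` case with a pair of complementary conditions, “every scalar type is `double`” and “at least one scalar type is not `double`”, so for any pair of arithmetic scalar types exactly one of them is viable. The mock below illustrates that partition; `MockMat`, `value_type_t`, the two `require_*` aliases, and `multiply_path` are hypothetical stand-ins, not Stan Math code.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical stand-in for an Eigen matrix with scalar type T.
template <typename T>
struct MockMat {
  using value_type = T;
};

template <typename M>
using value_type_t = typename M::value_type;

// Mock versions of the two constraints: "every type is T" and its
// complement "at least one type is not T".
template <typename T, typename... Ts>
using require_all_same_t =
    std::enable_if_t<std::conjunction_v<std::is_same<T, Ts>...>>;

template <typename T, typename... Ts>
using require_any_not_same_t =
    std::enable_if_t<!std::conjunction_v<std::is_same<T, Ts>...>>;

// Chosen when both scalar types are double (the OpenCL-capable path).
template <typename M1, typename M2,
          require_all_same_t<double, value_type_t<M1>,
                             value_type_t<M2>>* = nullptr>
constexpr std::string_view multiply_path(const M1&, const M2&) {
  return "all double";
}

// Chosen for every other arithmetic combination, e.g. int with double.
template <typename M1, typename M2,
          std::enable_if_t<std::is_arithmetic_v<value_type_t<M1>> &&
                           std::is_arithmetic_v<value_type_t<M2>>>* = nullptr,
          require_any_not_same_t<double, value_type_t<M1>,
                                 value_type_t<M2>>* = nullptr>
constexpr std::string_view multiply_path(const M1&, const M2&) {
  return "other arithmetic";
}

static_assert(multiply_path(MockMat<double>{}, MockMat<double>{}) ==
              "all double");
static_assert(multiply_path(MockMat<int>{}, MockMat<double>{}) ==
              "other arithmetic");
```

For `add` there is no OpenCL path to split off, so under this reading the double/non-double pair could be replaced by a single constraint.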
What is the correct `require` to say “anything but a `var`”?
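One way to express “anything but a `var`” is as the logical complement of the reverse-mode condition: instead of “any scalar type is `var`”, require that the negation holds for every scalar type. The sketch below does this with `std::conjunction` and `std::negation`. Here `var`, `MockMat`, `is_var`, `require_any_var_vt`, `require_all_not_var_vt`, and `add` are mock stand-ins; whether Stan Math ships a ready-made alias with this meaning, and what it is called, is an assumption worth checking against the library.

```cpp
#include <string_view>
#include <type_traits>

// Hypothetical stand-ins: a reverse-mode scalar and a matrix with scalar T.
struct var {};

template <typename T>
struct MockMat {
  using value_type = T;
};

template <typename M>
using value_type_t = typename M::value_type;

template <typename T>
struct is_var : std::is_same<T, var> {};

// "At least one scalar type is var" (like require_any_eigen_vt<is_var, ...>).
template <typename... Ms>
using require_any_var_vt =
    std::enable_if_t<std::disjunction_v<is_var<value_type_t<Ms>>...>>;

// The complement, "no scalar type is var": every negation must hold.
template <typename... Ms>
using require_all_not_var_vt = std::enable_if_t<
    std::conjunction_v<std::negation<is_var<value_type_t<Ms>>>...>>;

// Reverse-mode overload: in real code this would set up the autodiff node.
template <typename M1, typename M2, require_any_var_vt<M1, M2>* = nullptr>
constexpr std::string_view add(const M1&, const M2&) {
  return "reverse mode";
}

// Single overload for every other scalar type (e.g. double or int).
template <typename M1, typename M2,
          require_all_not_var_vt<M1, M2>* = nullptr>
constexpr std::string_view add(const M1&, const M2&) {
  return "no var";
}

static_assert(add(MockMat<var>{}, MockMat<double>{}) == "reverse mode");
static_assert(add(MockMat<double>{}, MockMat<int>{}) == "no var");
```

Because the two conditions are exact complements, every call matches exactly one `add` overload, so no separate `double` specialization is needed.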
Any help would be greatly appreciated!