This post was motivated by comments from @tadej and @bbbales2 on my first complex PR (the details of that aren’t relevant here). They were pushing back on my use of scalar to mean the value type of a complex number instead of the complex number itself. They’re right on the math, and even though I question the utility of the resulting metaprogram, I think working through it is a good way to understand Stan’s type system in particular and type systems in general. I have the metaprograms to implement this on a branch with tests that I’ll turn into a PR soon.
Scalar types
First, we have to be clear about the types we’re talking about. So let’s lay down some standard type-theoretic definitions.
The set of autodiff types is defined to be the smallest such that
- var is an autodiff type,
- fvar<double> is an autodiff type, and
- fvar<T> is an autodiff type if T is an autodiff type.
For example, this includes all of our favorites: var, fvar<double>, fvar<fvar<double>>, fvar<var>, and fvar<fvar<var>>, as well as all the continuing higher-order fvar nestings.
The set of scalar types is defined to be the smallest such that
- double and long double are scalar types,
- T is a scalar type if T is an autodiff type,
- complex<double> is a scalar type, and
- complex<T> is a scalar type if T is an autodiff type.
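The two inductive definitions above can be sketched directly as type traits. This is only an illustration: var and fvar below are empty stand-ins for Stan's types, and the trait names is_autodiff and is_scalar are hypothetical, not Stan's actual metaprograms.

```cpp
#include <complex>
#include <type_traits>

// Hypothetical stand-ins for Stan's autodiff types; only their names and
// template shape matter for this sketch.
struct var {};
template <typename T>
struct fvar {};

// Autodiff types: var, fvar<double>, and fvar<T> for any autodiff type T.
template <typename T>
struct is_autodiff : std::false_type {};
template <>
struct is_autodiff<var> : std::true_type {};
template <>
struct is_autodiff<fvar<double>> : std::true_type {};
template <typename T>
struct is_autodiff<fvar<T>> : is_autodiff<T> {};

// Scalar types: double, long double, every autodiff type, complex<double>,
// and complex<T> for any autodiff type T.
template <typename T>
struct is_scalar : is_autodiff<T> {};
template <>
struct is_scalar<double> : std::true_type {};
template <>
struct is_scalar<long double> : std::true_type {};
template <>
struct is_scalar<std::complex<double>> : std::true_type {};
template <typename T>
struct is_scalar<std::complex<T>> : is_autodiff<T> {};

static_assert(is_autodiff<fvar<fvar<var>>>::value, "nested fvar is autodiff");
static_assert(!is_autodiff<fvar<int>>::value, "fvar<int> is not in the set");
static_assert(is_scalar<std::complex<fvar<double>>>::value, "complex autodiff");
static_assert(!is_scalar<std::complex<long double>>::value, "not listed above");
```

Taking the smallest set satisfying the rules corresponds to the primary templates defaulting to false: a type is only in the set if one of the listed rules puts it there.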
Assignability
We need to know when it’s possible to assign to a variable of a given type. We want to do this so as to preserve covariance, which in programming language parlance means that a container inherits assignability from the type it contains. C++ standard template library collections are not covariant because you can assign int to double but you can’t assign std::vector<int> to std::vector<double>. For assignability among scalar types, we want to maintain covariance through std::complex so that std::complex<double> is assignable to std::complex<var>.
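The standard-library half of that claim can be checked at compile time with std::is_assignable. This sketch uses only standard types, so it says nothing about var, and it uses std::complex<float> as a stand-in to show that std::complex already has a converting assignment between value types.

```cpp
#include <complex>
#include <type_traits>
#include <vector>

// Scalars: an int can be assigned to a double.
static_assert(std::is_assignable<double&, int>::value,
              "int is assignable to double");

// Containers: std::vector does not lift that assignability to its elements.
static_assert(
    !std::is_assignable<std::vector<double>&, std::vector<int>>::value,
    "vector<int> is not assignable to vector<double>");

// std::complex does provide a converting assignment between value types.
static_assert(
    std::is_assignable<std::complex<double>&, std::complex<float>>::value,
    "complex<float> is assignable to complex<double>");
```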
The assignability relation among scalar types is defined to be the smallest such that
- double is assignable to any type T,
- T1 is assignable to std::complex<T2> if T1 is assignable to T2, and
- std::complex<T1> is assignable to std::complex<T2> if T1 is assignable to T2.
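A minimal sketch of this relation as a trait follows. The name scalar_assignable and the empty var stand-in are hypothetical. The sketch also assumes reflexivity (every type is assignable to itself), which the three rules above do not state but which is needed for var to be assignable to std::complex<var>.

```cpp
#include <complex>
#include <type_traits>

struct var {};  // hypothetical stand-in for Stan's reverse-mode type

// Base case: double is assignable to any type. Reflexivity (T to T) is an
// added assumption, not one of the three rules in the text.
template <typename From, typename To>
struct scalar_assignable
    : std::integral_constant<bool, std::is_same<From, double>::value
                                       || std::is_same<From, To>::value> {};

// T1 is assignable to std::complex<T2> if T1 is assignable to T2.
template <typename T1, typename T2>
struct scalar_assignable<T1, std::complex<T2>> : scalar_assignable<T1, T2> {};

// std::complex<T1> is assignable to std::complex<T2> if T1 is to T2.
// This is more specialized than the rule above, so it wins when both match.
template <typename T1, typename T2>
struct scalar_assignable<std::complex<T1>, std::complex<T2>>
    : scalar_assignable<T1, T2> {};

static_assert(scalar_assignable<double, var>::value, "double to var");
static_assert(!scalar_assignable<var, double>::value, "no var to double");
static_assert(scalar_assignable<var, std::complex<var>>::value, "rule two");
static_assert(
    scalar_assignable<std::complex<double>, std::complex<var>>::value,
    "covariance through std::complex");
static_assert(
    !scalar_assignable<std::complex<var>, std::complex<double>>::value,
    "no demotion inside std::complex");
```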
At this point, we’d usually prove a theorem that the assignability relation organizes types into a join semi-lattice (meaning least upper bounds exist in the ordering). But it’s pretty straightforward if you draw the types out, so I’ll leave it as an exercise for the reader.
Monads and the return type metaprogram
We want a metaprogram that’ll calculate the return type of an arbitrary sequence of types, where the return type is the least upper bound of the types. This turns out to be a natural monad structure and that’s how we’re going to implement it as a metaprogram. The monadic style is going to eliminate some shared code in our existing implementations and increase the generality of boundary conditions. It’s always good to think about boundary conditions when dealing with types.
A monad, as far as we’re going to be concerned, is a very simple algebraic structure consisting of a unit and a binary operation (strictly speaking, a set with an associative binary operation and a unit is called a monoid). We’ll exploit this simple structure to let the unit define the result of the operation on no inputs, then we can chain inputs in by successively applying the binary operation. The specific monad we care about consists of the set of scalar types, for which double is the unit and for which the binary operation is least upper bound in the lattice of scalar types.
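As a value-level analogy for the same pattern, here is a sketch that folds a list of ints with max as the binary operation and the smallest int as the unit. The names max2 and fold_max are made up for illustration; the type-level version in this post replaces max with least upper bound and the smallest int with double.

```cpp
#include <limits>

// Binary operation: the larger of two ints, standing in for least upper bound.
constexpr int max2(int a, int b) { return a > b ? a : b; }

// Unit: the result of folding no inputs at all.
constexpr int fold_max() { return std::numeric_limits<int>::min(); }

// Inductive case: combine the head with the fold of the remaining inputs.
template <typename... Ts>
constexpr int fold_max(int head, Ts... tail) {
  return max2(head, fold_max(tail...));
}

static_assert(fold_max() == std::numeric_limits<int>::min(), "empty input");
static_assert(fold_max(4) == 4, "the unit leaves a single input unchanged");
static_assert(fold_max(3, 7, 5) == 7, "upper bound of three inputs");
```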
The least-upper-bound metaprogram is straightforward based on the definition of assignability. It leans on the promote_args_t wrapper for the Boost promote_args metaprogram. The promotion metaprogram returns double if both arguments are arithmetic (it’s actually a bit more complicated to account for long double), and the class type if either is a class type (and if they’re both class types, they have to match or it’s undefined). This just forms the base case.
// Base case: neither argument is complex, so defer to argument promotion.
template <typename T1, typename T2>
struct scalar_lub {
  using type = promote_args_t<T1, T2>;
};
It’s then partially specialized for complex numbers to support covariance.
// Complex on the left only.
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, T2> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
// Complex on the right only.
template <typename T1, typename T2>
struct scalar_lub<T1, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
// Complex on both sides; the most specialized match, so it is unambiguous.
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
And we define the usual convenience typedef.
template <typename T1, typename T2>
using scalar_lub_t = typename scalar_lub<T1, T2>::type;
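To see the least-upper-bound metaprogram at work without pulling in Stan or Boost, here is a self-contained sketch. The promote_args_t below is a deliberately simplified stand-in written for this example (it ignores long double and mismatched class types), and var is an empty placeholder, so treat the results as illustrating the structure rather than Stan's exact behavior.

```cpp
#include <complex>
#include <type_traits>

struct var {};  // hypothetical stand-in for Stan's reverse-mode type

// Simplified stand-in for promote_args_t, NOT the Boost implementation:
// double if both arguments are arithmetic, otherwise the class type.
template <typename T1, typename T2>
using promote_args_t = typename std::conditional<
    std::is_arithmetic<T1>::value,
    typename std::conditional<std::is_arithmetic<T2>::value, double,
                              T2>::type,
    T1>::type;

// The metaprogram from the post, unchanged apart from layout.
template <typename T1, typename T2>
struct scalar_lub {
  using type = promote_args_t<T1, T2>;
};
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, T2> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
struct scalar_lub<T1, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
using scalar_lub_t = typename scalar_lub<T1, T2>::type;

static_assert(std::is_same<scalar_lub_t<int, double>, double>::value,
              "arithmetic arguments promote to double");
static_assert(std::is_same<scalar_lub_t<double, var>, var>::value,
              "the class type wins over double");
static_assert(std::is_same<scalar_lub_t<std::complex<double>, var>,
                           std::complex<var>>::value,
              "complex on one side lifts the result into complex");
static_assert(std::is_same<scalar_lub_t<std::complex<var>,
                                        std::complex<double>>,
                           std::complex<var>>::value,
              "complex on both sides promotes the value types");
```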
That’s the binary operator of our monad and double is the unit. The metaprogram to compute return types is defined in terms of these, using the unit for the base case and the least-upper-bound operation for the inductive case.
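// Base case: the unit, matched when the sequence of types is empty.
template <typename... Ts>
struct return_type {
  using type = double;
};
// Inductive case: least upper bound of the head's scalar type and the
// return type of the remaining types.
template <typename T, typename... Ts>
struct return_type<T, Ts...> {
  using type = scalar_lub_t<scalar_type_t<T>,
                            typename return_type<Ts...>::type>;
};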
The parameter packs let us define a base template that returns double as the result. This will match empty sequences of types, which is the generalization beyond the current definition. The inductive case is coded as a specialization for at least one type T. It’s implemented by recursively applying the metaprogram to the remaining types Ts and then taking the least upper bound with the scalar type pulled out of T using scalar_type_t. That latter operation also strips out qualifiers and references, which lets this apply beyond scalar types to containers like matrices and standard vectors: a container’s underlying scalar type is extracted before the least upper bound is applied. This is another metaprogram that is just an enumeration.
And of course, the utility typedef.
template <typename... Args>
using return_type_t = typename return_type<Args...>::type;
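Putting the pieces together, the following self-contained sketch exercises return_type_t, including the empty case. As assumptions for illustration only: var is an empty placeholder, promote_args_t is a simplified stand-in rather than the Boost metaprogram, and scalar_type_t is a reduced version that only strips qualifiers and references and unwraps std::vector.

```cpp
#include <complex>
#include <type_traits>
#include <vector>

struct var {};  // hypothetical stand-in for Stan's reverse-mode type

// Simplified stand-in for promote_args_t (ignores long double).
template <typename T1, typename T2>
using promote_args_t = typename std::conditional<
    std::is_arithmetic<T1>::value,
    typename std::conditional<std::is_arithmetic<T2>::value, double,
                              T2>::type,
    T1>::type;

// Reduced stand-in for scalar_type_t: decay the type, then unwrap vectors.
template <typename T>
struct scalar_type_impl { using type = T; };
template <typename T>
struct scalar_type_impl<std::vector<T>> : scalar_type_impl<T> {};
template <typename T>
using scalar_type_t =
    typename scalar_type_impl<typename std::decay<T>::type>::type;

// Least upper bound, as defined in the post.
template <typename T1, typename T2>
struct scalar_lub { using type = promote_args_t<T1, T2>; };
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, T2> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
struct scalar_lub<T1, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
struct scalar_lub<std::complex<T1>, std::complex<T2>> {
  using type = std::complex<promote_args_t<T1, T2>>;
};
template <typename T1, typename T2>
using scalar_lub_t = typename scalar_lub<T1, T2>::type;

// Return type: the unit for no types, otherwise fold in one type at a time.
template <typename... Ts>
struct return_type { using type = double; };
template <typename T, typename... Ts>
struct return_type<T, Ts...> {
  using type = scalar_lub_t<scalar_type_t<T>,
                            typename return_type<Ts...>::type>;
};
template <typename... Args>
using return_type_t = typename return_type<Args...>::type;

static_assert(std::is_same<return_type_t<>, double>::value,
              "empty sequence yields the unit");
static_assert(std::is_same<return_type_t<int>, double>::value,
              "arithmetic types promote to double");
static_assert(
    std::is_same<return_type_t<double, const std::vector<var>&, int>,
                 var>::value,
    "containers contribute their scalar type");
static_assert(std::is_same<return_type_t<std::vector<var>,
                                         std::complex<double>>,
                           std::complex<var>>::value,
              "complex and autodiff combine to complex autodiff");
```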
That’s it. The monadic style is just the unit in the base case plus a recursion that folds in one element at a time with the binary operation. This is all done without side effects or assignment by recursing on the tail, that is, calling the metaprogram return_type on the tail types of the input.