The main problem right now is that the var_context concept is extraordinarily monolithic and wraps all of the variable transformations. To clean that up, we’d need to pull all of the transforms exposed in https://github.com/stan-dev/stan/blob/develop/src/stan/io/reader.hpp and https://github.com/stan-dev/stan/blob/develop/src/stan/io/writer.hpp into their own separate classes, each of which handles the constraining/unconstraining for one transform. For example:
/**
 * Element-wise lower-bound transform for a vector of length N.
 *
 * @tparam T    scalar type of the (unconstrained) parameter values
 * @tparam T_LB type of the lower bound
 *
 * NOTE(review): constrain() reads entries [0, N) of its input, while
 * unconstrain() writes through a shared, advancing iterator. Several
 * parameters packed in one state vector would need an offset/iterator on
 * the constrain side too. TODO decide.
 */
template <typename T, typename T_LB>
class vector_lower_bound: public vector_transform<T> {
private:
  T_LB lower_bound_;
public:
  /**
   * @param lower_bound bound applied to every element
   * @param N           number of elements
   */
  vector_lower_bound(T_LB lower_bound, size_t N):
    // Bases are always initialised before members; list them in that
    // order so the initialiser list matches (and -Wreorder stays quiet).
    vector_transform<T>(N), lower_bound_(lower_bound) {}

  /** One unconstrained value per element, so the dimension is N. */
  size_t unconstrained_dim() const {
    // N_ lives in a dependent base class: unqualified lookup inside a class
    // template does not search it, hence this->.
    return this->N_;
  }

  /**
   * Check constr_val against the lower bound by delegating to
   * stan::math::check_greater_or_equal.
   */
  void validate(const vector_t& constr_val) const {
    stan::math::check_greater_or_equal("stan::io::vector_lower_bound",
                                       "Lower bound constraint", constr_val,
                                       lower_bound_);
  }

  /**
   * Apply stan::prob::lb_constrain element-wise to entries [0, N) of
   * unconstr_val.
   *
   * @tparam Jacobian when true, each element's log-Jacobian term is pushed
   *                  onto acc
   * @param unconstr_val unconstrained values
   * @param acc          accumulator for the log-Jacobian terms; it uses the
   *                     parameter scalar type T, matching log_prob's
   *                     lp_accumulator<T>
   * @return the N constrained values
   */
  template <bool Jacobian>
  vector_t constrain(const vector_t& unconstr_val,
                     lp_accumulator<T>& acc) const {
    vector_t constr_val(this->N_);
    for (idx_t n = 0; n < this->N_; ++n) {
      if (Jacobian) {
        // The 3-argument lb_constrain is documented to increment lp, so it
        // must start at zero. Its type follows the unconstrained value (T).
        T lp(0);
        constr_val(n) = stan::prob::lb_constrain(unconstr_val[n],
                                                 lower_bound_, lp);
        acc.push_back(lp);
      } else {
        constr_val(n) = stan::prob::lb_constrain(unconstr_val[n],
                                                 lower_bound_);
      }
    }
    return constr_val;
  }

  /**
   * Append the N constrained values of unconstr_val to constr_val as
   * doubles. The output is taken by reference so the caller sees it.
   */
  void constrain(const vector_t& unconstr_val,
                 std::vector<double>& constr_val) const {
    for (idx_t n = 0; n < this->N_; ++n)
      constr_val.push_back(stan::prob::lb_constrain(unconstr_val[n],
                                                    lower_bound_));
  }

  /**
   * Write the N unconstrained (lb_free) values of constr_val through it,
   * advancing it so the next transform continues where this one stopped.
   * NOTE(review): assumes the iterator type supports `*(it++) = value`;
   * Eigen's InnerIterator may not. Confirm.
   */
  void unconstrain(const vector_t& constr_val,
                   vector_t::InnerIterator& it) const {
    for (idx_t n = 0; n < this->N_; ++n)
      *(it++) = stan::prob::lb_free(constr_val[n], lower_bound_);
  }
};
I already did a bunch of these in the old branch, https://github.com/stan-dev/stan/tree/feature/refactor_io_reader_writer/src/stan/io/transforms.
Once we have all of the transforms defined in their own classes, we can clean up the var_context itself. First, a rename: what has been called a “variable context” is more properly a “data access layer” (DAL). The new DAL would then have a single constrain/unconstrain method that takes in the transform classes, either through a common base class or as a template parameter. For example:
// base_dal
/**
 * Read the variable `name` from the data source as a vector of length N.
 *
 * @param name variable name
 * @param N    required number of elements
 * @return the N values as a vector_t
 * @throws std::runtime_error if the variable is missing, is not
 *         one-dimensional, is declared with a size other than N, or if
 *         vals_r supplies fewer than N values
 */
vector_t get_vector(const std::string& name, size_t N) {
  if (!contains_r(name)) {
    std::stringstream msg;
    msg << "Variable " << name
        << " not found in variable context" << '\n';
    throw std::runtime_error(msg.str());
  }
  std::vector<size_t> dims = dims_r(name);
  if (dims.size() != 1) {
    std::stringstream msg;
    msg << "Requested the vector " << name
        << ", but " << name
        << " is not defined as a vector in variable context"
        << '\n';
    throw std::runtime_error(msg.str());
  }
  if (dims[0] != N) {
    std::stringstream msg;
    msg << "Requested the vector " << name
        << " with size " << N << ", but " << name
        << " is defined with size " << dims[0] << " in variable context"
        << '\n';
    throw std::runtime_error(msg.str());
  }
  std::vector<double> vals = vals_r(name);
  // The declared dims and the actual values come from separate calls.
  // Guard against a data source whose values are shorter than declared;
  // otherwise the copy below reads past the end of vals.
  if (vals.size() < N) {
    std::stringstream msg;
    msg << "Requested the vector " << name
        << " with size " << N << ", but only " << vals.size()
        << " values are available in variable context" << '\n';
    throw std::runtime_error(msg.str());
  }
  vector_t output(N);
  // size_t matches N, avoiding a signed/unsigned comparison.
  for (size_t n = 0; n < N; ++n)
    output(n) = vals[n];
  return output;
}
// constr_dal
/**
 * Read `name` as a vector sized by the transform and validate it against
 * that transform.
 *
 * The former `template <typename T>` was dropped: T appeared in no
 * parameter, so it could not be deduced, and the generated-code call passes
 * no template arguments.
 *
 * @param name      variable name
 * @param transform transform whose size() gives the length and whose
 *                  validate() checks the values
 * @param param     output; taken by reference so the caller receives the
 *                  values (a by-value copy was silently discarded)
 * @throws whatever get_vector or transform.validate throws
 */
void get_constrained_vector(const std::string& name,
                            const stan::io::transform& transform,
                            vector_t& param) {
  // Assignment sizes param, so no separate resize() is needed.
  param = get_vector(name, transform.size());
  transform.validate(param);
}
// unconstr_dal
/**
 * Read `name` as a constrained vector and write its unconstrained values
 * through `it` using the transform.
 *
 * The former `template <typename T>` was dropped: T appeared in no
 * parameter, so it could not be deduced, and the generated init() call
 * passes no template arguments.
 *
 * @param name      variable name
 * @param transform transform whose size() gives the length and whose
 *                  unconstrain() performs the mapping
 * @param it        output iterator into the unconstrained state
 * @throws whatever get_vector or transform.unconstrain throws
 */
void get_unconstrained_vector(const std::string& name,
                              const stan::io::transform& transform,
                              vector_t::InnerIterator& it) {
  const vector_t constr_val = get_vector(name, transform.size());
  transform.unconstrain(constr_val, it);
}
Then we can drastically clean up (and simplify) the generated model code to look something like this:
// constructor
// Sketch of generated code: `...` marks elided statements. input_type,
// lb_type, lb_value and param_name1_size are placeholders for what the
// code generator emits for a lower-bounded vector.
model_name {
  ...
  // vector_lower_bound's constructor takes the bound and the vector size.
  vector_lower_bound<input_type, lb_type>
    param_name1_transform__(lb_value, param_name1_size);
  // Read "param_name1" through the DAL into param_name1__. The DAL
  // validates it against the transform.
  data_var_context.get_constrained_vector("param_name1",
                                          param_name1_transform__,
                                          param_name1__);
  ...
}
// init
// Fill unconstr_state from the constrained initial values, one transform
// per parameter, sharing a single advancing iterator.
void init(vector_t& unconstr_state) {
  vector_t::InnerIterator it(unconstr_state);
  ...
  vector_lower_bound<input_type, lb_type>
    param_name1_transform__(lb_value, param_name1_size);
  // NOTE(review): no DAL object is in scope here. The DAL that supplies
  // initial values still has to be passed in or stored on the model. TODO.
  get_unconstrained_vector("param_name1", param_name1_transform__, it);
  ...
}
// what write_array was
void constr_state(vector_t unconstr_state,
std::vector<double>& constr_state) {
...
vector_lb_transform<input_type, lb_type>
param_name1_transform__(lb_value);
param_name1_transform__.constrain(unconstr_state, constr_state);
...
}
// log_prob
template <typename T>
T log_prob(vector_t& unconstr_state) {
lp_accumulator<T> lp_acc__;
...
vector_lb_transform<input_type, lb_type>
param_name1_transform__(lb_value);
vector_t param_name1__ =
param_name1_transform__.constrain<Jacobian>(unconstr_state, lp_acc__);
...
}
At this point, DAL implementations are responsible solely for providing the actual data access; for example, a C++ DAL would be extremely easy to write.