[ This was originally a comment on another thread on consolidated output. ]
I’m very much behind consolidating argument groups into structures (aka “parameter objects”).
Existing service method
Here’s the existing service method for our default sampler (copied from the Wiki on consolidated output):
```cpp
template <class Model>
int hmc_nuts_diag_e_adapt(Model& model, stan::io::var_context& init,
                          stan::io::var_context& init_inv_metric,
                          unsigned int random_seed, unsigned int chain,
                          double init_radius, int num_warmup,
                          int num_samples, int num_thin, bool save_warmup,
                          int refresh, double stepsize,
                          double stepsize_jitter, int max_depth,
                          double delta, double gamma, double kappa,
                          double t0, unsigned int init_buffer,
                          unsigned int term_buffer, unsigned int window,
                          callbacks::interrupt& interrupt,
                          callbacks::logger& logger,
                          callbacks::writer& init_writer,
                          callbacks::writer& sample_writer,
                          callbacks::writer& diagnostic_writer);
```
Parameter objects
There are obvious groupings of arguments, and what this thread is discussing is the last bundle. I think it’d help overall if we refactored these service functions into bundles of arguments. The above would look like this:
```cpp
int
hmc_nuts_diag_e_adapt(
    const base_model& model,
    const inits& inits,
    const seed& seed,
    const iterations& its,
    const nuts_config& nuts,
    const step_adapt& sac,
    const metric_adapt& mac,
    outputs& out);
```
The bundles, with the algorithms that use each one:
- model (all): base_model& model
- inits (HMC and NUTS): double init_radius, stan::io::var_context& init, stan::io::var_context& init_inv_metric, double stepsize
- seed (all): unsigned int random_seed, unsigned int chain
- iterations (HMC and NUTS): int num_warmup, int num_samples, int num_thin, bool save_warmup, int refresh
- nuts_config (NUTS): int max_depth
- step_adapt (HMC and NUTS with step size adaptation): double delta, double gamma, double kappa, double t0
- metric_adapt (HMC and NUTS with inverse metric adaptation): unsigned int init_buffer, unsigned int term_buffer, unsigned int window
- outputs (HMC and NUTS): callbacks::interrupt& interrupt, callbacks::logger& logger, callbacks::writer& init_writer, callbacks::writer& sample_writer, callbacks::writer& diagnostic_writer
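As a rough sketch of what the value-type bundles might look like, here are plain structs with default member initializers, so a default-constructed bundle is usable as-is. The struct layouts and the stub function hmc_nuts_diag_e_adapt_stub are hypothetical, and the default values are assumed to match CmdStan's documented defaults; the model and the callback bundle (outputs) are left out to keep the example self-contained.

```cpp
#include <cassert>

// Hypothetical parameter objects mirroring the bundles listed above.
// Defaults are assumed to match CmdStan's documented defaults.
struct iterations {
  int num_warmup = 1000;
  int num_samples = 1000;
  int num_thin = 1;
  bool save_warmup = false;
  int refresh = 100;
};

struct nuts_config {
  int max_depth = 10;
};

struct step_adapt {
  double delta = 0.8;  // target acceptance statistic, must lie in (0, 1)
  double gamma = 0.05;
  double kappa = 0.75;
  double t0 = 10;
};

struct metric_adapt {
  unsigned int init_buffer = 75;
  unsigned int term_buffer = 50;
  unsigned int window = 25;
};

// Hypothetical stand-in for the real service function: it only validates
// the bundles and returns 0 on success or 1 on a bad argument, mimicking
// an int return code. No sampling happens here.
int hmc_nuts_diag_e_adapt_stub(const iterations& its,
                               const nuts_config& nuts,
                               const step_adapt& sac,
                               const metric_adapt& mac) {
  if (its.num_warmup < 0 || its.num_samples < 0) return 1;
  if (its.num_thin < 1 || its.refresh < 0) return 1;
  if (nuts.max_depth < 1) return 1;
  if (!(sac.delta > 0 && sac.delta < 1)) return 1;
  if (mac.window == 0) return 1;
  return 0;
}
```

Call sites then only spell out the fields they change, e.g. iterations its; its.num_samples = 2000; followed by hmc_nuts_diag_e_adapt_stub(its, nuts_config{}, step_adapt{}, metric_adapt{}).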
I’m pretty sure these are all safe in the sense that we’ll have the same clusters of arguments for many different samplers and optimizers. That is, each service function can be built using a few common grouped argument structures. This is less ambitious than having a top-level structure, but will support many of the same goals as the original design. Some issues with these structures:
- The templated model can be replaced with a base class so that all of this can be precompiled. As @betanalpha has pointed out, we’ll need to test carefully to make sure this doesn’t cost any measurable efficiency. It should be a big win for compile time.
- There’s an ongoing discussion about refactoring the elements of the output object.
- The inverse metric adaptation (which is what we’re really doing) should also expose the regularization parameters used in our estimator.
- The var_context lets us neatly finesse the actual types of some arguments.
- I dropped stepsize_jitter, which I think is OK at this point. Alternatively, we could include it in a group of HMC config arguments (it shouldn’t go with max_depth, which is NUTS-only, unless nuts_config inherits from hmc_config).
- The numbers of warmup and sampling iterations may be better off on their own; the other parameters in that bundle control output (dropping warmup draws, refresh frequency, and thinning).
- Putting the chain ID and random seed together seems awkward, but the chain ID is used primarily to identify the chain in the output and to advance the RNG at the start so that the chains use separate slices of the same RNG stream.
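Two of the points above can be sketched together: keeping stepsize_jitter in an HMC-wide config that the NUTS config inherits from, and using the chain ID only to pick an offset into one shared RNG stream. Everything here is hypothetical: hmc_config, nuts_config, seed, effective_jitter, make_chain_rng, and the discard stride are illustrative names and values, and std::mt19937_64 stands in for whatever generator the services actually use.

```cpp
#include <cassert>
#include <random>

// Hypothetical: stepsize_jitter lives in an HMC-wide config, and NUTS
// extends it, so max_depth never appears in plain HMC signatures.
struct hmc_config {
  double stepsize_jitter = 0.0;  // shared by all HMC variants
};

struct nuts_config : hmc_config {
  int max_depth = 10;  // NUTS only
};

// Code that needs only HMC-wide settings takes the base type by reference;
// a nuts_config binds to it directly.
double effective_jitter(const hmc_config& c) { return c.stepsize_jitter; }

// Hypothetical seed bundle: one seed shared by all chains plus a chain ID.
struct seed {
  unsigned int random_seed = 0;
  unsigned int chain = 1;
};

// Assumed stride for illustration only; not Stan's actual constant.
constexpr unsigned long long kDiscardStride = 1ULL << 16;

// Every chain seeds the same generator, then skips ahead by a
// chain-dependent amount so chains draw from separate slices of one stream.
std::mt19937_64 make_chain_rng(const seed& s) {
  std::mt19937_64 rng(s.random_seed);
  rng.discard(kDiscardStride * s.chain);
  return rng;
}
```

With this layout, chain 2's generator starts exactly where chain 1's would be after kDiscardStride draws, so chains overlap only if a chain consumes more than the stride.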
Builders for “named” args and defaults
Of course, we then need easy builders for the chunks to deal with defaults. Pretty much all of these can get defaults, and a full default construction would just be:
```cpp
step_adapt sa;  // all default values
```
Then I’m imagining a standard builder that converts implicitly to the target type (via a conversion operator, since copy-initialization never calls operator=) rather than a “build” method, e.g.,
```cpp
step_adapt sa = step_adapt::build().delta(0.95).t0(0.01);
```
That’s like named arguments in that they can come in any order and they have defaults if they aren’t used at all.
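A builder along these lines can be sketched with setters that each override one field and return a reference to the builder, plus an implicit conversion operator to step_adapt, which is what lets step_adapt sa = step_adapt::build().delta(...) compile without a trailing build call. The struct, its defaults (assumed to be CmdStan's), and the delta range check are illustrative assumptions, not existing Stan code.

```cpp
#include <cassert>
#include <stdexcept>

// Hypothetical step_adapt parameter object; defaults assumed to match
// CmdStan's step size adaptation defaults.
struct step_adapt {
  double delta = 0.8;
  double gamma = 0.05;
  double kappa = 0.75;
  double t0 = 10;

  class builder;
  static builder build();
};

// Each setter overrides one field and returns *this, so calls chain in any
// order; fields that are never set keep their defaults.
class step_adapt::builder {
 public:
  builder& delta(double d) {
    if (!(d > 0 && d < 1))  // also rejects NaN
      throw std::domain_error("delta must be in (0, 1)");
    sa_.delta = d;
    return *this;
  }
  builder& gamma(double g) { sa_.gamma = g; return *this; }
  builder& kappa(double k) { sa_.kappa = k; return *this; }
  builder& t0(double t) { sa_.t0 = t; return *this; }

  // Implicit conversion used by `step_adapt sa = step_adapt::build()...;`
  operator step_adapt() const { return sa_; }

 private:
  step_adapt sa_;  // starts from the defaults above
};

inline step_adapt::builder step_adapt::build() { return builder(); }
```

Because the conversion is implicit, it also fires when a builder chain is passed straight to a function taking const step_adapt&. The trade-off is that auto sa = step_adapt::build().delta(0.95); deduces a builder copy rather than a step_adapt, so call sites should name the type.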
For all of this to be super flexible, it’d also be nice to have a var_context_builder, something like:
```cpp
var_context vc = var_context::build().matrix("a", m).real("f", 2.3);
```
Not quite sure how to implement that, though.
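One possible approach is to accumulate named values and dimensions in the builder and convert to a context at the end, the same way as the step_adapt builder. The simple_context class below is a hypothetical stand-in and does not implement Stan's actual stan::io::var_context interface; it stores each variable as a flat vector of doubles in column-major order together with its dimensions, and takes matrices as nested std::vector rows so it needs no Eigen.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Hypothetical in-memory context: each variable is a flat vector of values
// (column-major for matrices) plus its dimensions.
class simple_context {
 public:
  class builder;
  static builder build();

  bool contains(const std::string& name) const { return vars_.count(name) > 0; }
  // Both accessors throw std::out_of_range for unknown names (map::at).
  const std::vector<double>& vals(const std::string& name) const {
    return vars_.at(name).first;
  }
  const std::vector<std::size_t>& dims(const std::string& name) const {
    return vars_.at(name).second;
  }

 private:
  std::map<std::string,
           std::pair<std::vector<double>, std::vector<std::size_t>>>
      vars_;
};

class simple_context::builder {
 public:
  // Scalar: a single value with no dimensions.
  builder& real(const std::string& name, double x) {
    ctx_.vars_[name] =
        std::make_pair(std::vector<double>{x}, std::vector<std::size_t>{});
    return *this;
  }
  // Matrix given as rows; stored column-major with dims {rows, cols}.
  builder& matrix(const std::string& name,
                  const std::vector<std::vector<double>>& m) {
    std::size_t rows = m.size();
    std::size_t cols = rows ? m[0].size() : 0;
    for (const auto& row : m)
      if (row.size() != cols) throw std::invalid_argument("ragged matrix");
    std::vector<double> flat;
    flat.reserve(rows * cols);
    for (std::size_t j = 0; j < cols; ++j)
      for (std::size_t i = 0; i < rows; ++i) flat.push_back(m[i][j]);
    ctx_.vars_[name] =
        std::make_pair(flat, std::vector<std::size_t>{rows, cols});
    return *this;
  }
  // Implicit conversion, as with the step_adapt builder.
  operator simple_context() const { return ctx_; }

 private:
  simple_context ctx_;
};

inline simple_context::builder simple_context::build() { return builder(); }
```

A real version would presumably derive from stan::io::var_context and implement its virtual accessors (including the integer-valued ones), which is where most of the remaining work would be.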