As mentioned in "Parallel dynamic HMC merits", we can improve adaptation by collecting adaptation information from multiple chains during warmup. The following concept is an MPI framework that does this through a master-slave paradigm (code at https://github.com/yizhang-yiz/stan/tree/mpi_warmup_framework).
We use n+1 processes to run n chains, with process 0 being the master and the rest being slaves. Each slave runs a chain that communicates with the master every m iterations during warmup. That is, at iteration j*m, j=1, 2, \dots, the master and slave processes do the following:
- Slave i generates adaptation information and sends it to the master. The information to be sent is defined by functor F_s:
struct Fs {
  template<typename Sampler, typename Model>
  void operator()(Sampler& sampler, Model& model,
                  Eigen::MatrixXd& adapt_info, int slave_id) {...}
};
- The master collects adaptation information from each slave and uses it to generate ensemble adaptation information through functor E_m:
struct Em {
  template<typename Sampler, typename Model>
  void operator()(Sampler& sampler, Model& model,
                  Eigen::MatrixXd& ensemble_adaptation_info) {...}
};
- The master distributes the ensemble adaptation information to the slaves.
- Each slave receives the new adaptation information and combines it with its current adaptation through functor G_s:
struct Gs {
  template<typename Sampler, typename Model>
  void operator()(Sampler& sampler, Model& model,
                  Eigen::MatrixXd& adapt_info) {...}
};
- The slaves run another m iterations before the next round of ensemble adaptation.
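To make the round above concrete, here is a minimal single-process sketch of one gather/reduce/broadcast cycle. It is not the code in the linked branch: MPI is replaced by plain loops, `Eigen::MatrixXd` by a single `double`, and `ChainState`, `ReportStepsize`, `AverageReports`, `BlendWithEnsemble`, and `ensemble_round` are hypothetical names chosen for illustration. Averaging step sizes is only a placeholder choice of algorithm, not a recommendation.

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for one chain's adaptation state (not a Stan type).
struct ChainState {
  double stepsize;
};

// F_s analogue: a slave reports its current adaptation information.
struct ReportStepsize {
  double operator()(const ChainState& chain) const { return chain.stepsize; }
};

// E_m analogue: the master reduces all reports to one ensemble value.
struct AverageReports {
  double operator()(const std::vector<double>& reports) const {
    assert(!reports.empty());
    double sum = 0.0;
    for (double r : reports) sum += r;
    return sum / static_cast<double>(reports.size());
  }
};

// G_s analogue: a slave blends the ensemble value into its own state.
struct BlendWithEnsemble {
  double weight;  // 0 keeps the local value, 1 adopts the ensemble value
  void operator()(ChainState& chain, double ensemble) const {
    chain.stepsize = (1.0 - weight) * chain.stepsize + weight * ensemble;
  }
};

// One ensemble round, run serially in place of the MPI send/receive calls.
template <typename F, typename E, typename G>
double ensemble_round(std::vector<ChainState>& chains, F f, E e, G g) {
  std::vector<double> reports;
  for (const ChainState& c : chains) reports.push_back(f(c));  // slaves -> master
  const double ensemble = e(reports);                          // master combines
  for (ChainState& c : chains) g(c, ensemble);                 // master -> slaves
  return ensemble;
}
```

Calling `ensemble_round(chains, ReportStepsize{}, AverageReports{}, BlendWithEnsemble{0.5})` moves every chain halfway toward the mean step size. Swapping in different functors changes the algorithm without touching the round itself, which is the separation the framework is after.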
The interval m and the functors F_s, E_m, and G_s define the ensemble adaptation algorithm, but the asynchronous parallel communication framework remains the same (which is the point of the design).
With an algorithm defined, run_adaptive_sampler would look something like this:
for (int i = 0; i < num_intervals; ++i) {
  if (is_master_proc) {
    // construct master
    warmup_dynamic_loader_master master(warmup_comm, interval);
    // master receives/sends adapt info
    master(sampler, model, f, g, h);
  } else {
    // construct slave
    warmup_dynamic_loader_slave slave(warmup_comm, interval);
    // slave moves forward by one interval
    util::generate_transitions(sampler, interval, nbegin, num_warmup + num_samples,
                               num_thin, refresh, save_warmup, true, writer, s,
                               model, rng, interrupt, logger);
    nbegin += interval;
    // slave sends/receives adapt info
    slave(sampler, model, f, g);
  }
}
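As a small illustration of the schedule this loop implies, the sketch below lists the iteration counts at which a slave would exchange adaptation information. It assumes, hypothetically, that `num_intervals` is `num_warmup / interval` (integer division), so any leftover warmup iterations run without an exchange; the post does not say how `num_intervals` is computed. `exchange_points` is an illustrative helper, not part of Stan.

```cpp
#include <cassert>
#include <vector>

// Iteration counts at which a slave exchanges adaptation information,
// assuming exchanges happen only at whole multiples of `interval`
// that fall within warmup (an assumption, see the lead-in).
std::vector<int> exchange_points(int num_warmup, int interval) {
  assert(interval > 0);
  std::vector<int> points;
  int nbegin = 0;
  const int num_intervals = num_warmup / interval;  // assumed definition
  for (int i = 0; i < num_intervals; ++i) {
    nbegin += interval;  // mirrors `nbegin += interval` in the loop above
    points.push_back(nbegin);
  }
  return points;
}
```

For example, with 1000 warmup iterations and an interval of 300, exchanges would happen after iterations 300, 600, and 900, and the final 100 warmup iterations would use whatever adaptation the last exchange produced.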
Tagging @betanalpha and @Bob_Carpenter for feedback on the soundness of the design and on possible algorithms (the functors mentioned above) to try.