Hello! I’m using Stan’s C++ template library to run NUTS (Hamiltonian MCMC), but I’m bottlenecked by a slow likelihood-plus-gradient evaluation. One tactic I’m trying is reduce_sum to speed up building the autodiff stack and computing the gradients, but I believe I’ve only multithreaded building the autodiff stack, not the gradient computation itself. Since I believe computing the gradient is the slow step, I’d appreciate some help/insight into multithreading the gradient calculation (if that’s even possible).
I’m using stan v2_35_0a and stan math v4_9_0a.
My log likelihood is the sum over ~10 sub-samples, each of which involves a nontrivial likelihood sub-query. To start, I tried using reduce_sum to split the sub-queries across two threads. Here’s some pseudocode:
struct LogLikelihoodEvaluator {
  // reduce_sum calls this once per slice; start/end index the sub-samples.
  stan::math::var operator()(const std::vector<stan::math::var>& dummy,
                             int start, int end,
                             std::ostream* messages) const {
    // dummy is unused.
    // Iterate over sub-samples with index between start and end,
    // then return the sum of the log likelihood over those sub-samples.
  }
};

stan::math::var GetLogLikelihood(/*...*/) {
  // ... Then, to perform the master log likelihood query with two threads,
  // set the grainsize to half the number of sub-samples:
  stan::math::var ret = stan::math::reduce_sum<LogLikelihoodEvaluator>(
      dummy, subSamples.size() / 2, &std::cout);
  // dummy is a vector of default-constructed stan::math::var with one
  // element per sub-sample; its length tells reduce_sum how many terms
  // to slice.
  return ret;  // Then pass the total log likelihood to
               // stan::model::model_base_crtp::log_prob().
}
The result wasn’t faster than without reduce_sum. Although this did successfully multithread building the autodiff stack (I can tell by printing thread IDs), I believe it doesn’t actually multithread the gradient calculation, because reduce_sum collapses the per-sub-sample autodiff stacks into a single one, and the gradient calculation happens after log_prob() anyway.
I’m not an expert, so my questions for experts are:
- Is my reasoning correct?
- Is there a clean way to also assign the gradient computation to the threads within reduce_sum, or is there a different method I could use?
I appreciate any help or insights, and I can provide more details if it would help. Thank you!