Help multithreading HMC with reduce_sum

Hello! I’m using Stan’s C++ template library directly to run NUTS (Hamiltonian Monte Carlo) sampling, but I’m bottlenecked by a slow log-likelihood + gradient evaluation. One tactic I’m trying is reduce_sum, to speed up both building the autodiff stack (the forward pass) and computing the gradients (the reverse pass). However, I believe I’ve only multithreaded building the autodiff stack, not computing the gradients. Since I believe the gradient computation is the slow step, I’d appreciate help or insight on multithreading it (if that’s even possible).

I’m using Stan v2_35_0a and Stan Math v4_9_0a.

My (log) likelihood query is the sum over ~10 sub-samples, each of which involves a nontrivial likelihood sub-query. To start, I tried using reduce_sum to split the sub-queries across two threads. Here’s some pseudocode:

```cpp
#include <iostream>
#include <vector>
#include <stan/math.hpp>

struct LogLikelihoodEvaluator {
  // Partial-sum functor passed to reduce_sum.
  stan::math::var operator()(
    const std::vector<stan::math::var>& dummy, int start, int end, std::ostream* messages) const {
      // dummy is unused; it only tells reduce_sum how many sub-samples there are.
      // Iterate over sub-samples with index between start and end,
      // then return the sum of log likelihood over these sub-samples.
    }
};

stan::math::var GetLogLikelihood(/*...*/) {
  // ... Then to perform the master log likelihood query with two threads...
  // dummy is a std::vector<stan::math::var> with one (empty) entry per sub-sample,
  //   which tells reduce_sum how many sub-samples I have. The grainsize of
  //   subSamples.size()/2 is meant to split them across two threads.
  stan::math::var ret = stan::math::reduce_sum<LogLikelihoodEvaluator>(
      dummy, subSamples.size() / 2, &std::cout);
  return ret;  // Then pass the total log likelihood to stan::model::model_base_crtp::log_prob().
}
```
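To make the intended two-thread split concrete, here is a standard-library-only sketch of the forward (value) part, which is the part I can see running in parallel. It is not Stan Math code: `SubSampleLogLik` and `PartialLogLik` are hypothetical stand-ins for my real sub-query and functor, the half-open `[start, end)` range is my own convention (Stan Math’s start/end convention may differ), and chunking statically by `grainsize` is an assumption, since as I understand it reduce_sum leaves the actual split to TBB.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <thread>
#include <vector>

// Hypothetical stand-in for one sub-sample's log-likelihood term.
double SubSampleLogLik(double x) { return -0.5 * x * x; }

// Partial-sum functor in the spirit of LogLikelihoodEvaluator:
// sums the terms for sub-sample indices in [start, end).
struct PartialLogLik {
  const std::vector<double>& data;
  double operator()(std::size_t start, std::size_t end) const {
    double sum = 0.0;
    for (std::size_t i = start; i < end; ++i) sum += SubSampleLogLik(data[i]);
    return sum;
  }
};

// Split [0, n) into chunks of at most `grainsize`, evaluate each chunk on its
// own thread, then add the partial sums on the calling thread.
double ParallelLogLik(const std::vector<double>& data, std::size_t grainsize) {
  if (grainsize == 0) grainsize = 1;  // guard against division by zero
  const std::size_t n = data.size();
  const std::size_t num_chunks = (n + grainsize - 1) / grainsize;
  std::vector<double> partial(num_chunks, 0.0);  // one slot per chunk, no sharing
  std::vector<std::thread> workers;
  workers.reserve(num_chunks);
  for (std::size_t c = 0; c < num_chunks; ++c) {
    const std::size_t start = c * grainsize;
    const std::size_t end = std::min(n, start + grainsize);
    workers.emplace_back([&data, &partial, c, start, end] {
      partial[c] = PartialLogLik{data}(start, end);
    });
  }
  for (std::thread& w : workers) w.join();
  return std::accumulate(partial.begin(), partial.end(), 0.0);
}
```

With 10 sub-samples and `grainsize = 5` (i.e., size/2), this starts two threads of five sub-samples each, and the result matches the serial `PartialLogLik{data}(0, data.size())` up to floating-point rounding.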

The reduce_sum version was no faster than the version without it. Printing thread IDs shows that building the autodiff stack did run on multiple threads, but I believe the gradient calculation does not: my understanding is that reduce_sum merges the sub-samples’ per-thread autodiff stacks (the thread-local singletons) back into a single stack, so the gradient is still computed afterward, on one thread, once log_prob() returns.
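The reason I hoped the gradient could be split the same way: write the log likelihood as a sum over the $K \approx 10$ sub-samples and partition the sub-sample indices into chunks $S_1, \dots, S_T$, one per thread (this notation is mine). Because the gradient is linear, it decomposes chunk by chunk:

$$
\ell(\theta) = \sum_{k=1}^{K} \ell_k(\theta),
\qquad
\nabla_\theta \ell(\theta) = \sum_{k=1}^{K} \nabla_\theta \ell_k(\theta) = \sum_{t=1}^{T} \Big( \sum_{k \in S_t} \nabla_\theta \ell_k(\theta) \Big).
$$

So, in principle, each thread could run the reverse pass for its own chunk, and only the $T$ partial gradients would need to be added at the end.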

I’m not an expert, so my questions for experts are:

  1. Is my reasoning correct?
  2. Is there a clean way to also have the threads inside reduce_sum compute the gradients, or is there a different method I could use instead?
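For concreteness, here is the pattern question 2 is asking about, again sketched with the standard library only (not Stan Math code): each thread evaluates both the value and the gradient of its own chunk of sub-samples, and the calling thread only adds the partial results. To keep it self-contained, the per-sub-sample term is a hypothetical normal log density in θ = (μ, log σ) with a hand-derived gradient, standing in for my real sub-query and its autodiff; with reverse-mode autodiff, each thread would presumably need its own tape for its chunk.

```cpp
#include <algorithm>
#include <array>
#include <cmath>
#include <cstddef>
#include <thread>
#include <vector>

// Value and gradient of a log likelihood w.r.t. theta = (mu, log_sigma).
struct ValueGrad {
  double value = 0.0;
  std::array<double, 2> grad = {0.0, 0.0};
};

// Hypothetical stand-in for one sub-sample's likelihood sub-query: a normal
// log density with a hand-derived gradient (no autodiff library involved).
ValueGrad SubSampleValueGrad(double x, double mu, double log_sigma) {
  const double sigma = std::exp(log_sigma);
  const double z = (x - mu) / sigma;
  const double kLogSqrtTwoPi = 0.918938533204672742;  // 0.5 * log(2 * pi)
  ValueGrad out;
  out.value = -log_sigma - kLogSqrtTwoPi - 0.5 * z * z;
  out.grad[0] = z / sigma;    // d/d mu
  out.grad[1] = z * z - 1.0;  // d/d log_sigma
  return out;
}

// Each worker does both the value and the gradient work for its chunk; the
// calling thread adds the results, since the gradient of a sum is the sum of
// the gradients.
ValueGrad ParallelValueGrad(const std::vector<double>& data, double mu,
                            double log_sigma, std::size_t num_threads) {
  num_threads = std::max<std::size_t>(1, std::min(num_threads, data.size()));
  const std::size_t n = data.size();
  const std::size_t chunk = (n + num_threads - 1) / num_threads;
  std::vector<ValueGrad> partial(num_threads);  // one slot per thread
  std::vector<std::thread> workers;
  workers.reserve(num_threads);
  for (std::size_t t = 0; t < num_threads; ++t) {
    const std::size_t start = std::min(n, t * chunk);
    const std::size_t end = std::min(n, start + chunk);
    workers.emplace_back([&, t, start, end] {
      for (std::size_t i = start; i < end; ++i) {
        const ValueGrad term = SubSampleValueGrad(data[i], mu, log_sigma);
        partial[t].value += term.value;
        partial[t].grad[0] += term.grad[0];
        partial[t].grad[1] += term.grad[1];
      }
    });
  }
  for (std::thread& w : workers) w.join();
  ValueGrad total;
  for (const ValueGrad& p : partial) {
    total.value += p.value;
    total.grad[0] += p.grad[0];
    total.grad[1] += p.grad[1];
  }
  return total;
}
```

Calling `ParallelValueGrad(data, mu, log_sigma, 2)` should give the same value and gradient as the one-thread call up to rounding, and a central finite difference on `value` is a quick way to check the hand-derived gradient.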

I appreciate any help or insights, and I can provide more details if it would help. Thank you!