I am very new to Stan, and the documentation seems a bit overwhelming at first. I am interested in the NUTS algorithm in stan/src/stan/mcmc/hmc/nuts in the stan-dev/stan repo. I want to know how to compile the NUTS algorithm and how to use it. My end goal is to profile the NUTS algorithm for performance bottlenecks.
If anyone can explain how this works and/or point to specific documentation, that would be great. Thanks!
There is no documentation on profiling NUTS, but you can compile a model and then use perf to view the time spent in particular functions (or sets of instructions). For memory allocations, you can use something like heaptrack.
When we have done this in the past, we found that most of the time is spent inside the gradient evaluations for NUTS.
As @stevebronder says, for anything but the simplest target density functions, the cost of running the dynamic Hamiltonian Monte Carlo sampler in Stan is dominated by the autodiff evaluations of that target density function (to get the density and the gradient).
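To make that concrete: every leapfrog step the sampler takes needs a fresh gradient of the log density, so the number of gradient evaluations grows with the trajectory length. Below is a minimal sketch of that structure, not Stan's actual code. The names `GradFn` and `leapfrog` are hypothetical, and a unit mass matrix is assumed.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical callback type: fills `grad` with the gradient of the
// log density evaluated at position `q`.
using GradFn = std::function<void(const std::vector<double>& q,
                                  std::vector<double>& grad)>;

// Runs `n_steps` leapfrog steps of size `eps` (unit mass matrix assumed),
// updating position `q` and momentum `p` in place.
// Returns the number of gradient evaluations that were required.
std::size_t leapfrog(const GradFn& grad_log_prob, std::vector<double>& q,
                     std::vector<double>& p, double eps,
                     std::size_t n_steps) {
  assert(q.size() == p.size());
  std::vector<double> grad(q.size());
  std::size_t n_grad_evals = 0;

  // One gradient at the starting point.
  grad_log_prob(q, grad);
  ++n_grad_evals;

  for (std::size_t s = 0; s < n_steps; ++s) {
    // Half step in momentum using the current gradient.
    for (std::size_t i = 0; i < q.size(); ++i) p[i] += 0.5 * eps * grad[i];
    // Full step in position.
    for (std::size_t i = 0; i < q.size(); ++i) q[i] += eps * p[i];
    // New gradient at the new position: one evaluation per step.
    grad_log_prob(q, grad);
    ++n_grad_evals;
    // Second half step in momentum.
    for (std::size_t i = 0; i < q.size(); ++i) p[i] += 0.5 * eps * grad[i];
  }
  return n_grad_evals;
}
```

Everything other than the `grad_log_prob` call is a handful of vector updates, which is why a profile of a non-trivial model tends to show the gradient evaluation at the top.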
If you wanted to dig into it deeper, the key is to construct a C++ class that follows the policy of Stan’s log_prob object (really all you need is the log_prob and grad_log_prob calls, but there are some subtleties there). Then you can template either the sampler directly or through the service routes on that class and compile everything independently of the rest of Stan (language, math library, etc).
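As a rough illustration of the idea, here is a hand-written class exposing the two calls mentioned above for a standard normal target, plus a function templated on the model type. This is only a sketch: `StdNormalModel`, `max_gradient_error`, and the exact signatures are hypothetical, and Stan's real model interface has more to it (the subtleties mentioned above), so check the Stan source before wiring a class like this into the sampler.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical model class: independent standard normal in `dim` dimensions.
// The method names follow the post; the signatures are assumptions.
struct StdNormalModel {
  std::size_t dim;

  // Log density up to an additive constant.
  double log_prob(const std::vector<double>& q) const {
    assert(q.size() == dim);
    double lp = 0.0;
    for (double x : q) lp -= 0.5 * x * x;
    return lp;
  }

  // Writes the gradient into `grad` and returns the log density.
  double grad_log_prob(const std::vector<double>& q,
                       std::vector<double>& grad) const {
    assert(q.size() == dim);
    grad.resize(dim);
    for (std::size_t i = 0; i < dim; ++i) grad[i] = -q[i];
    return log_prob(q);
  }
};

// Generic code templated on the model type, in the same spirit as
// templating a sampler on it. Compares the hand-written gradient with
// central finite differences and returns the largest absolute error.
template <class Model>
double max_gradient_error(const Model& model, std::vector<double> q,
                          double h = 1e-6) {
  std::vector<double> grad;
  model.grad_log_prob(q, grad);
  double worst = 0.0;
  for (std::size_t i = 0; i < q.size(); ++i) {
    const double qi = q[i];
    q[i] = qi + h;
    const double up = model.log_prob(q);
    q[i] = qi - h;
    const double down = model.log_prob(q);
    q[i] = qi;  // restore the coordinate before moving on
    const double err = std::fabs((up - down) / (2.0 * h) - grad[i]);
    if (err > worst) worst = err;
  }
  return worst;
}
```

A hand-coded gradient like this takes autodiff out of the picture, so a profile of the sampler run against such a class shows the cost of the NUTS machinery itself.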