How to use stan::math::ChainableStack?

Hi,

I’m wondering how to use stan::math::ChainableStack. For context, I’m using RcppParallel and calling autodiff code from within it, which uses stan::math::var objects.

AFAIK, stan::math::recover_memory() clears the global AD stack, so it’s not thread-safe.

I have the appropriate STAN_THREADS flags set (i.e., -DSTAN_THREADS -pthread).

For example, can I do something like:

/// local AD block
{
  thread_local stan::math::ChainableStack ad_tape;

  /// do stuff with stan::math::var objects

  // Do I need something similar to recover_memory() here, but a thread-safe version? If so, what is it?

} // end of local AD block

Thanks.

recover_memory() is thread-safe if STAN_THREADS is used, and it is the right thing to call.

There are also some RAII options like nested_rev_autodiff, which is what stan::math::gradient uses.
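A rough way to picture why this is safe: with STAN_THREADS, the AD stack that recover_memory() clears belongs to the calling thread rather than to the whole process. The sketch below is a toy model in plain standard C++, not Stan's implementation; toy_tape(), toy_recover_memory() and survivors_after_worker_clear() are hypothetical names, and a std::vector<double> stands in for the real tape.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical stand-in for the AD tape: one independent copy per thread,
// the way STAN_THREADS makes the real stack thread-local.
std::vector<double>& toy_tape() {
  thread_local std::vector<double> tape;
  return tape;
}

// Toy analogue of recover_memory(): clears only the calling thread's tape.
void toy_recover_memory() { toy_tape().clear(); }

// Records `mine` nodes on the calling thread, lets a worker thread record and
// clear its own tape, and returns how many nodes the calling thread still has.
std::size_t survivors_after_worker_clear(int mine) {
  for (int i = 0; i < mine; ++i) toy_tape().push_back(i);

  std::size_t worker_before_clear = 0;
  std::thread worker([&worker_before_clear] {
    toy_tape().push_back(1.0);  // goes onto the worker's own tape
    toy_tape().push_back(2.0);
    worker_before_clear = toy_tape().size();
    toy_recover_memory();  // clears the worker's tape only
  });
  worker.join();

  assert(worker_before_clear == 2);  // the worker never saw our nodes
  std::size_t remaining = toy_tape().size();
  toy_recover_memory();  // tidy up the calling thread's tape
  return remaining;
}
```

Calling survivors_after_worker_clear(3) returns 3: the worker's clear never touches the calling thread's nodes.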

Thanks! And do I need to use thread_local, or is it redundant?

I believe something running on the thread does need to create a thread_local instance, yes. You can make it static if you’re worried about it being created multiple times.
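On the static part: for a variable declared at block scope, thread_local already implies static, so static thread_local and plain thread_local mean the same thing there. Either way the object is constructed once per thread, the first time control reaches its declaration. The toy below checks that with a counter; ToyStack, work_item() and constructions_for() are hypothetical stand-ins for ChainableStack and your per-item work, not Stan code.

```cpp
#include <atomic>
#include <cassert>
#include <thread>
#include <vector>

// Global counter of how many ToyStack objects have been constructed.
std::atomic<int>& construction_count() {
  static std::atomic<int> count{0};
  return count;
}

// Hypothetical stand-in for an object like ChainableStack.
struct ToyStack {
  ToyStack() { ++construction_count(); }
};

// Called many times per thread. At block scope, `thread_local` implies
// `static`, so the object is built only on the first call in each thread.
void work_item() {
  thread_local ToyStack stack;
  (void)stack;  // the toy does no real work with it
}

// Runs `calls_per_thread` work items on each of `n_threads` new threads and
// returns how many ToyStack objects were constructed in total.
int constructions_for(int n_threads, int calls_per_thread) {
  assert(n_threads > 0 && calls_per_thread > 0);
  construction_count() = 0;
  std::vector<std::thread> pool;
  for (int t = 0; t < n_threads; ++t) {
    pool.emplace_back([calls_per_thread] {
      for (int i = 0; i < calls_per_thread; ++i) work_item();
    });
  }
  for (auto& th : pool) th.join();
  return construction_count().load();
}
```

constructions_for(4, 3) returns 4, one construction per thread, even though work_item() runs 12 times.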

Can I use nested_rev_autodiff in combination with stan::math::ChainableStack?

Yes. Creating a thread-local ChainableStack prepares the thread for autodiff, while the nested_rev_autodiff class is helpful for managing a specific scope.
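A toy model of how the two pieces divide the work, in plain C++ rather than Stan: the thread_local tape plays the role of the per-thread ChainableStack, and a small scope guard plays the role of nested_rev_autodiff around each evaluation. ToyTape, ToyNested, evaluate(), Result and run_workers() are hypothetical names for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <thread>
#include <vector>

// Hypothetical per-thread tape (stands in for a thread-local ChainableStack).
struct ToyTape {
  std::vector<double> nodes;        // "recorded operations"
  std::vector<std::size_t> marks;   // where each nested region began
};
ToyTape& toy_tape() {
  thread_local ToyTape tape;
  return tape;
}

// Hypothetical scope guard (stands in for nested_rev_autodiff).
struct ToyNested {
  ToyNested() { toy_tape().marks.push_back(toy_tape().nodes.size()); }
  ~ToyNested() {
    assert(!toy_tape().marks.empty());
    toy_tape().nodes.resize(toy_tape().marks.back());  // roll back this scope
    toy_tape().marks.pop_back();
  }
};

// One "AD evaluation": records k nodes on this thread's tape and sums them.
double evaluate(int k) {
  ToyNested nested;  // per-call scope
  double total = 0.0;
  for (int i = 1; i <= k; ++i) {
    toy_tape().nodes.push_back(i);  // recorded on this thread's tape only
    total += toy_tape().nodes.back();
  }
  return total;  // the guard rolls the tape back after this
}

struct Result {
  double sum = 0.0;
  std::size_t leftover = 0;  // tape size left after all evaluations
};

// Each worker runs several evaluations, then reports its leftover tape size.
std::vector<Result> run_workers(int n_threads, int items_per_thread) {
  std::vector<Result> results(n_threads);
  std::vector<std::thread> pool;
  for (int t = 0; t < n_threads; ++t) {
    pool.emplace_back([&results, t, items_per_thread] {
      for (int i = 0; i < items_per_thread; ++i) results[t].sum += evaluate(4);
      results[t].leftover = toy_tape().nodes.size();
    });
  }
  for (auto& th : pool) th.join();
  return results;
}
```

run_workers(4, 5) gives each worker a sum of 50 (five evaluations of 1 + 2 + 3 + 4) and a leftover tape size of 0, because every evaluation rolls back what it recorded.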

Thanks!

So, for example, would this be right (using RcppParallel)?


  // RcppParallel parallel operator
  void operator()(std::size_t begin, std::size_t end) {

    for (std::size_t i = begin; i < end; ++i) {  // iterating over COLUMNS (each col is a chain)

      {  ////////// local block which calls a Stan Math AD function

        stan::math::ChainableStack ad_tape;  // each thread will get its own AD stack??

        Eigen::Matrix<double, -1, 1> lp_grad_vec = my_AD_fn(theta, ...);

      }  // end of local block

    }

  }

And then, within my_AD_fn, I call stan::math::recover_memory() at the end. Or should I call start_nested() at the beginning of the function and recover_memory_nested() at the end? Or does it not really matter either way?

Warning: I am not an Rcpp/RcppParallel expert by any means. But if you have some function that will be called from many threads, an “easy” recipe is to start the function with:

static thread_local stan::math::ChainableStack ad_tape;
stan::math::nested_rev_autodiff nested;

This guarantees both that the autodiff tape is initialized for the thread and that memory is cleaned up when the function ends, even if, e.g., an exception gets thrown (this is the main advantage over manually calling one of the recover_* methods).
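The exception point can be seen with a toy model in plain C++ (not Stan code): toy_tape(), toy_recover_memory(), ToyNestedGuard, manual_style(), raii_style() and leftover_after() are hypothetical stand-ins for the AD tape, recover_memory() and nested_rev_autodiff.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical per-thread tape and cleanup call (toy stand-ins).
std::vector<double>& toy_tape() {
  thread_local std::vector<double> tape;
  return tape;
}
void toy_recover_memory() { toy_tape().clear(); }

// RAII guard: remembers the tape size and restores it in the destructor.
struct ToyNestedGuard {
  std::size_t mark = toy_tape().size();
  ~ToyNestedGuard() {
    assert(toy_tape().size() >= mark);
    toy_tape().resize(mark);
  }
};

// Records two nodes, then fails part-way through.
void record_then_throw() {
  toy_tape().push_back(1.0);
  toy_tape().push_back(2.0);
  throw std::runtime_error("log density evaluation failed");
}

// Manual style: the cleanup call at the end is skipped when an exception escapes.
void manual_style() {
  record_then_throw();
  toy_recover_memory();  // never reached if record_then_throw() throws
}

// RAII style: the guard's destructor runs during stack unwinding.
void raii_style() {
  ToyNestedGuard guard;
  record_then_throw();
}

// Calls f on an empty tape, swallows the exception, returns the leftover size.
template <typename F>
std::size_t leftover_after(F f) {
  toy_tape().clear();
  try {
    f();
  } catch (const std::runtime_error&) {
  }
  return toy_tape().size();
}
```

leftover_after(manual_style) is 2 because the cleanup line is skipped, while leftover_after(raii_style) is 0 because the guard's destructor runs during stack unwinding.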

Oh, I thought nested AD was only for when you have “AD within AD”? I didn’t know you could use it more generally.

That was its original purpose, but it’s safe to use generally. The stan::math::gradient functor uses it for exactly this.
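Why it is safe without an outer AD computation, in toy form: a nested region only removes what was recorded after it started, so it behaves the same whether the tape underneath is empty or already holds outer nodes. This is a plain C++ model, not Stan's implementation; ToyTape, ToyNested and gradient_like() are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical per-thread tape with a stack of nested-region start marks.
struct ToyTape {
  std::vector<double> nodes;
  std::vector<std::size_t> marks;
};
ToyTape& toy_tape() {
  thread_local ToyTape tape;
  return tape;
}

// Toy nested guard: only ever removes what was recorded after it started.
struct ToyNested {
  ToyNested() { toy_tape().marks.push_back(toy_tape().nodes.size()); }
  ~ToyNested() {
    ToyTape& t = toy_tape();
    assert(!t.marks.empty());
    t.nodes.resize(t.marks.back());
    t.marks.pop_back();
  }
};

// A gradient-style helper: opens its own nested region, records n nodes, and
// returns how many nodes this call recorded (computed before the guard's
// destructor rolls them back).
std::size_t gradient_like(int n) {
  ToyNested nested;
  for (int i = 0; i < n; ++i) toy_tape().nodes.push_back(i);
  return toy_tape().nodes.size() - toy_tape().marks.back();
}
```

With an empty tape, gradient_like(5) returns 5 and leaves the tape empty; with one outer node already present, gradient_like(3) returns 3 and the outer node survives.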
