Yeah, my bad, I left this part out.
I think what needs to happen is this code:
stacks_ = allocate_chainable_stacks(J);
for (int j = 0; j < J; j++) {
  stacks_[j].execute(std::bind(func, individual_params[j]));
}
needs to be changed to something like:
stacks_ = allocate_chainable_stacks(J);
// These new vars will act like the proxies
// We don't use the vars from the original stack cause that isn't threadsafe
std::vector<var> individual_params_allocated_on_new_stack(J);
for (int j = 0; j < J; j++) {
  individual_params_allocated_on_new_stack[j] = copy_to_new_stack(stacks_[j], individual_params[j]);
  stacks_[j].execute(std::bind(func, individual_params_allocated_on_new_stack[j]));
}
So we copy the input arguments into varis allocated on the new stack. Because of this, we'd need to be able to know which inputs are vars and which aren't (so we don't allocate varis for things that aren't vars).
There’s a variable for this in the actual adj_jac_apply implementation, but that abstraction is probably the wrong thing. I can rewrite this as a vari pretty easily if that would help (otherwise I’ll skip that just to avoid dumping more code here).
These copies shouldn’t be that big. It’s just the number of inputs + outputs.
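To make that concrete, here is a rough sketch of what copy_to_new_stack could look like. None of these names are real Stan math API: TaskStack (and its alloc_vari method) is a placeholder for whatever allocate_chainable_stacks hands back, and vari/var are cut-down stand-ins just so the sketch is self-contained. Overload resolution is one way to handle the "which inputs are vars" question:

#include <vector>

// Cut-down stand-ins for stan::math::vari / stan::math::var, only so this
// sketch compiles on its own.
struct vari {
  double val_;
  double adj_;
  explicit vari(double val) : val_(val), adj_(0.0) {}
};

struct var {
  vari* vi_;
};

// Placeholder for one of the per-task stacks in stacks_. In the real thing
// the vari would go on the task's arena, not the plain heap.
struct TaskStack {
  std::vector<vari*> varis_;
  vari* alloc_vari(double val) {
    varis_.push_back(new vari(val));
    return varis_.back();
  }
};

// var input: allocate a fresh vari on the task's stack, so the task never
// touches varis owned by the main stack.
inline var copy_to_new_stack(TaskStack& stack, const var& x) {
  return var{stack.alloc_vari(x.vi_->val_)};
}

// non-var input: nothing to allocate, just pass the value through.
inline double copy_to_new_stack(TaskStack& /* stack */, double x) {
  return x;
}

With overloads like these, the "is this a var or not" question gets answered at compile time rather than through the adj_jac_apply bookkeeping.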
Good call. But a simple solution to this is to just keep track of all the stacks we allocate?
That could be embedded in:
stacks_ = allocate_chainable_stacks(J);
And when we zeroed the main stack, we could just loop through and zero each of those.
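Roughly what I mean, as a sketch only (the registry, TaskStack, and the function names here are all made up; the shared_ptrs are just to keep the registry entries valid for as long as the stacks live):

#include <memory>
#include <vector>

// Stand-in for one per-task autodiff stack.
struct TaskStack {
  void set_zero_adjoints() {
    // ... zero the adjoint of every vari allocated on this stack ...
  }
};

// Hypothetical global registry of every task stack we've handed out.
static std::vector<std::shared_ptr<TaskStack>> all_task_stacks;

// allocate_chainable_stacks both creates the J stacks and records them, so
// callers never have to remember to register anything themselves.
std::vector<std::shared_ptr<TaskStack>> allocate_chainable_stacks(int J) {
  std::vector<std::shared_ptr<TaskStack>> stacks;
  stacks.reserve(J);
  for (int j = 0; j < J; ++j) {
    stacks.push_back(std::make_shared<TaskStack>());
    all_task_stacks.push_back(stacks.back());
  }
  return stacks;
}

// Wherever we currently zero the main stack's adjoints, also sweep the
// registry and zero every task stack.
void set_zero_all_adjoints_including_task_stacks() {
  // ... zero the main stack as before ...
  for (auto& stack : all_task_stacks)
    stack->set_zero_adjoints();
}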
Yeah, my reasoning for this was that at first we'll only have a couple of signatures (map_rect + map_reduce or whatever goes in with this). Eventually everything will presumably get passed to us as some sort of stan_language_closure_functor object, at which point every function will have the same signature.
Yeah but this goes in that direction.
Like, map_rect is awkward because beforehand you chunk up your J job items into G groups.
But really we want to leave all J things to be parallelized however they may be, right? This is compatible with that idea.
So we’d want the allocation of these autodiff stacks to be fast enough that we could allocate one for every task we did.
I’ll assert without evidence that I think this can be fast enough. Though presumably if we allocate each thread a stack at the beginning and it just re-uses it for different tasks it would be faster.
I’ll argue that isn’t necessary :D. But that’s more cause I want a simple MPI implementation.
I think the ticket to making nested parallelism work would be making sure:
stacks_ = allocate_chainable_stacks(J);
is threadsafe.
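If the registry sketch from above is roughly how the bookkeeping ends up working, then threadsafety is mostly about guarding the shared registry. A minimal sketch, same made-up names as before:

#include <memory>
#include <mutex>
#include <vector>

struct TaskStack { /* per-task autodiff stack, as in the earlier sketch */ };

static std::vector<std::shared_ptr<TaskStack>> all_task_stacks;
static std::mutex all_task_stacks_mutex;  // protects the registry only

std::vector<std::shared_ptr<TaskStack>> allocate_chainable_stacks(int J) {
  std::vector<std::shared_ptr<TaskStack>> stacks;
  stacks.reserve(J);
  for (int j = 0; j < J; ++j)
    stacks.push_back(std::make_shared<TaskStack>());

  // Only the shared bookkeeping needs the lock; each stack is owned by a
  // single task, so nested calls from worker threads can allocate
  // concurrently without stepping on each other.
  std::lock_guard<std::mutex> guard(all_task_stacks_mutex);
  for (auto& stack : stacks)
    all_task_stacks.push_back(stack);
  return stacks;
}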
I think this should be equally okay in either implementation? The difference here is that there'd be more ScopedChainableAutodiffs and no copying of autodiff stacks back after a task finished. The operation of the ScopedChainableAutodiffs themselves would otherwise stay the same.
We’d have to wait on tasks to finish before we copied autodiff stacks back to the main stack too, so I think we’re nervous about these either way.
Nah nah, good critique.
It’s also compatible with a possible future MPI implementation (where copying things back would not be an option).