I’ve been updating ctsem (hierarchical state space stuff) to use pre-computed Jacobians rather than finite differences via a loop in Stan, but along the way I seem to have broken the gradient calculations. Certain model structures (like a state-dependent variance term) now give NaNs for some gradient elements, while finite-difference gradients seem to work. The code is, I think, too much to expect anyone else to dig through, but I’m wondering if there are known problematic operations or coding errors that might cause this combination: a working log prob and working finite-difference gradients, but broken autodiff gradients?
This combination would imply that the gradients were implemented incorrectly, or perhaps that the gradient implementation isn’t as numerically stable as the function implementation itself. This can sometimes happen with iterative algorithms, such as those defined by power series, when one naively differentiates through the power series instead of working out an accurate power series for the derivative directly.
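As a small, hand-checkable illustration of the numerical-stability case (a different mechanism from the power-series one), the C++ sketch below differentiates `sqrt(x * x)` using a hypothetical minimal forward-mode dual number; `Dual`, `dual_mul`, `dual_sqrt`, `f` and `f_fd` are names made up for this example. At `x = 0` the value is 0 and a central finite difference gives 0, but the chain rule evaluates `0 / 0`, so the autodiff derivative is NaN. The function has a kink at 0, so neither number is "the" derivative there; the point is only the pattern of finite value, finite finite-difference estimate, and NaN autodiff gradient.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical minimal forward-mode dual number: value v and derivative d.
struct Dual {
  double v, d;
};

// Product rule.
Dual dual_mul(Dual a, Dual b) { return {a.v * b.v, a.d * b.v + a.v * b.d}; }

// d/dx sqrt(u) = u' / (2 sqrt(u)), which becomes 0 / 0 when u = u' = 0.
Dual dual_sqrt(Dual a) {
  double s = std::sqrt(a.v);
  return {s, a.d / (2.0 * s)};
}

// |x| written naively as sqrt(x * x), differentiated with dual numbers.
Dual f(Dual x) { return dual_sqrt(dual_mul(x, x)); }

// Central finite difference applied to the plain double version of the same function.
double f_fd(double x, double h) {
  auto g = [](double z) { return std::sqrt(z * z); };
  return (g(x + h) - g(x - h)) / (2.0 * h);
}
```

Seeding with `f({0.0, 1.0})` gives a value of 0 and a NaN derivative, while `f_fd(0.0, 1e-6)` returns 0; away from the kink, e.g. at `x = 3`, both agree.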
Thanks, though just to be clear: the Jacobians in question are used in a regular Stan program (not custom C++), i.e. for the log_prob. The Jacobian of the log prob with respect to the parameters is still all autodiff. At present the problem is ‘solved’ by filling certain matrices at a point where, so far as I can see, they shouldn’t need to be filled (and clearly don’t need to be for the log_prob). Maybe I’ve stumbled into a Stan bug? Am investigating…
Ok I found the problem. I’m not sure where the blame lies, probably with me, but it also seems like somewhat unnecessary / problematic behaviour from Stan. Basically:
for (i in 1:rows) {
  A = -0.3 * state;  // when i == 1, state is still NaN, so A is NaN after this line
  if (i == 1) state = ...;  // initialise state
  if (i > 1) { ... }  // do stuff with A and state
  target += f(state - somedata);
}
In this form, the likelihood is properly evaluated and some gradients are correct, while some are NaN.
Removing the first calculation, which is just a pointless computation involving some NaNs that nothing depended on, ensures that all gradients are computed properly.
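To see how an unused computation can still break a gradient, here is a toy tape-based reverse mode in C++. It is hypothetical and much simpler than Stan Math's real implementation; `Tape`, `grad_theta` and the parameter `theta` are illustrative names. The sketch assumes, unlike the literal `-0.3 * state` above, that the coefficient multiplying `state` depends on a parameter `theta`: in this toy, a purely constant coefficient would give a zero rather than NaN contribution. The reverse sweep runs every recorded step, used or not, so the NaN value of `state` becomes a NaN partial that is multiplied by the zero adjoint of the unused node.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <limits>
#include <vector>

// Hypothetical toy reverse-mode tape, for illustration only (NOT Stan Math's API).
struct Tape {
  std::vector<double> val, adj;
  std::vector<std::function<void()>> chain;  // one backward step per recorded operation

  int var(double v) {
    val.push_back(v);
    adj.push_back(0.0);
    return static_cast<int>(val.size()) - 1;
  }
  int mul(int a, int b) {
    int c = var(val[a] * val[b]);
    // The partial with respect to each operand is the other operand's value.
    chain.push_back([this, a, b, c] {
      adj[a] += val[b] * adj[c];
      adj[b] += val[a] * adj[c];
    });
    return c;
  }
  int sub(int a, int b) {
    int c = var(val[a] - val[b]);
    chain.push_back([this, a, b, c] {
      adj[a] += adj[c];
      adj[b] -= adj[c];
    });
    return c;
  }
  // Seed the output adjoint with 1, then run every recorded step backwards,
  // including steps whose results were never used.
  void grad(int out) {
    adj[out] = 1.0;
    for (auto it = chain.rbegin(); it != chain.rend(); ++it) (*it)();
  }
};

// d(target)/d(theta) for target = (state - y)^2, with state initialised to theta.
// If `extraneous` is true, first compute A = theta * state while state is still NaN.
double grad_theta(bool extraneous) {
  Tape t;
  int theta = t.var(0.5);  // parameter
  int y = t.var(2.0);      // "data"; its adjoint is ignored
  int state = t.var(std::numeric_limits<double>::quiet_NaN());  // not yet initialised
  if (extraneous) t.mul(theta, state);  // A: never used, so its adjoint stays 0
  state = theta;                        // "initialise state" (stand-in for the real code)
  int u = t.sub(state, y);
  int target = t.mul(u, u);
  t.grad(target);
  return t.adj[theta];
}
```

The forward value of `target` is 2.25 in both runs; only the gradient differs: `grad_theta(false)` returns -3, while `grad_theta(true)` returns NaN, which mirrors the symptom in the post.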
I thought Stan generally did all the hard work in the background to figure out which calculations determine the gradient, but here it seems to be including spurious calculations and breaking because of them. Obviously it’s silly behaviour in my code too, but this was pretty hard to track down, mostly because I didn’t think this could be a problem.
I do a lot of re-using matrices in this way, is it somehow cleaner to create new local instances when the previous calculations for a particular matrix become irrelevant?
The problem is that these extraneous nodes contribute nothing only because the partial derivatives should be zero, so that we get an adjoint update of the form
$$a_{n} = a_{n} + \text{partial} \cdot a_{n+1} = a_{n} + 0 \cdot a_{n+1} = a_{n}.$$
But if a NaN enters that product, this behavior is no longer guaranteed, because 0 * NaN is not zero. Under IEEE 754 floating-point arithmetic, which virtually all hardware implements, 0 * NaN evaluates to NaN, and that NaN then propagates through the rest of the reverse-mode sweep.
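The arithmetic in that update can be checked directly. This minimal C++ sketch (the name `adjoint_update` is hypothetical) performs a single update $a_{n} + \text{partial} \cdot a_{n+1}$ and shows that one NaN factor is enough to poison the result, even when the other factor is exactly 0.

```cpp
#include <cassert>
#include <cmath>

// One reverse-mode adjoint update: a_n + partial * a_{n+1}.
// With IEEE 754 doubles, 0 * NaN and NaN * 0 are both NaN, so the "no-op" is lost.
double adjoint_update(double a_n, double partial, double a_next) {
  return a_n + partial * a_next;
}
```

With finite inputs and a zero partial the adjoint is unchanged (`adjoint_update(1.5, 0.0, 2.0)` is 1.5), but with a NaN in either factor the result is NaN. This assumes the code is compiled without flags such as `-ffast-math`, which let the compiler assume NaNs never occur.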
Yes.
Ok. I guess I imagined the extraneous nodes would somehow be pruned, since nothing depended on them before the element was next set.
So, just to understand – if the computation of A is costly, is it more, less, or similarly efficient to declare A within the loop instead of before?
If A is constant then you definitely want to compute it outside of the loop, per the usual computational advice.
What may change in the autodiff context is when you have a big hunk of memory that you try to reuse by filling it with new values at each iteration of the loop. There you don’t actually save that much memory (at every iteration you’ll create new implicit autodiff variables anyway), and you might run into unused NaNs contaminating the gradients.
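The memory point can be illustrated with another toy tape (hypothetical, not Stan Math; `Tape` and `nodes_after_loop` are made-up names). In a tape-based reverse mode, assigning a new expression to an existing variable only rebinds the name to a new node; the node it previously referred to stays on the tape until the sweep is done, so reusing the variable does not shrink the tape.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical minimal tape that only records node values, enough to count nodes.
struct Tape {
  std::vector<double> val;
  std::size_t var(double v) {
    val.push_back(v);
    return val.size() - 1;
  }
  std::size_t mul(std::size_t a, std::size_t b) { return var(val[a] * val[b]); }
};

// Reassigns the same variable `x` on every iteration, as when a buffer is reused,
// and returns how many nodes the tape holds afterwards.
std::size_t nodes_after_loop(int iters) {
  Tape t;
  std::size_t rho = t.var(0.9);  // parameter
  std::size_t x = t.var(1.0);    // the reused "buffer"
  for (int i = 0; i < iters; ++i) x = t.mul(rho, x);  // the old node is not freed
  return t.val.size();
}
```

For example, `nodes_after_loop(10)` returns 12: the two starting nodes plus one new node per iteration, regardless of the fact that `x` was reused.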
If I understand your problem, that is.
Cheers. Yes, A definitely has to be computed within the loop; it’s just a question of where it should be declared. If the main benefit of declaring it within the loop is avoiding this NaN issue, I’ll stick with declaring it outside and just be aware that this is one reason gradients might go missing. Declaring it outside is much simpler when only a few elements of A change at each step.
In general, it’s much much easier to reason about immutable structures that get initialized in a usable state and never change. As soon as you start modifying structures, you have to be much more careful about leaving it in a usable state (such as not having NaN values in it).
One thing you can do is put checks for NaN before the calls to diagnose problems, then later remove them for more efficiency.
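A debug-only NaN check of that kind might look like the C++ sketch below; `first_nan` and `check_no_nan` are hypothetical helpers, and the label string is just a way to report where the first NaN appeared. Checking values in the forward pass would catch the situation in this thread, where A was already NaN right after its first computation. In a Stan program itself, a comparable check could be written with the built-in `is_nan` and `reject` functions.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <vector>

// Index of the first NaN in x, or -1 if there is none.
long first_nan(const std::vector<double>& x) {
  for (std::size_t i = 0; i < x.size(); ++i)
    if (std::isnan(x[i])) return static_cast<long>(i);
  return -1;
}

// Debug-only guard: throw with a label so the first place a NaN appears can be located.
// Remove (or compile out) once the model is known to be clean.
void check_no_nan(const std::vector<double>& x, const std::string& label) {
  long i = first_nan(x);
  if (i >= 0)
    throw std::domain_error(label + ": element " + std::to_string(i) + " is NaN");
}
```

For instance, `check_no_nan(A_values, "A after drift step")` placed right after a computation would throw as soon as that step produces a NaN; `A_values` and the label are, again, illustrative.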