What if, in the same way we’re talking about wrapping functions with performance counters, we provided a utility to wrap functions with NaN/inf checks?
Say we have some Stan code that looks like:
y ~ normal(X * b + intercept, sigma);
That roughly translates to C++ as something like:
target += normal_lpdf(y, X * b + intercept, sigma, msgs);
Couldn’t we write a C++ function that wraps this like:
target += search_for_nans_and_infs(normal_lpdf, y, X * b + intercept, sigma, msgs);
That would effectively just call normal_lpdf(y, X * b + intercept, sigma, msgs), but it could search all the inputs for NaNs and infs on the call, and then it could put a vari on the stack so that it can search for NaNs and infs in the gradients on the way back.
If it finds a NaN in the values or gradients, then it can print all the argument values and stuff out so that someone can figure out what’s going on where in their problem.
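To make the forward half of this concrete, here is a minimal sketch of what such a wrapper might look like. It is not Stan Math code: it only handles scalar arguments convertible to double, it takes an extra name argument so it can print the call, and toy_normal_lpdf is a hypothetical stand-in for the real normal_lpdf. The gradient check is left out here because it needs the autodiff stack.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical wrapper (not Stan Math API): calls f(args...), checks the
// inputs and the result for NaN/inf, and prints the whole call to msgs if
// anything non-finite is found.
template <typename F, typename... Args>
double search_for_nans_and_infs(const std::string& name, std::ostream& msgs,
                                F&& f, Args... args) {
  // Assumption: every argument is a scalar convertible to double.
  const std::vector<double> inputs{static_cast<double>(args)...};
  const double result = std::forward<F>(f)(args...);

  bool found = !std::isfinite(result);
  for (double v : inputs) {
    found = found || !std::isfinite(v);
  }
  if (found) {
    msgs << "Non-finite value in " << name << "(";
    for (std::size_t i = 0; i < inputs.size(); ++i) {
      msgs << (i == 0 ? "" : ", ") << inputs[i];
    }
    msgs << ") = " << result << "\n";
  }
  return result;
}

// Hypothetical scalar stand-in for normal_lpdf, for illustration only.
inline double toy_normal_lpdf(double y, double mu, double sigma) {
  const double half_log_two_pi = 0.9189385332046727;  // 0.5 * log(2 * pi)
  const double z = (y - mu) / sigma;
  return -0.5 * z * z - std::log(sigma) - half_log_two_pi;
}
```

A call like search_for_nans_and_infs("toy_normal_lpdf", std::cerr, toy_normal_lpdf, y, mu, sigma) returns the same value as the plain call and stays silent unless an input or the result is NaN or infinite.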
This thread had me thinking about this: “Initial value rejected, but the likelihood exists for that initial value”. In this example the gradients blew up and it took a long time to track down (“I spent weeks with this :/”).
The only error that Stan printed was:
Chain 1: Rejecting initial value:
Chain 1: Gradient evaluated at the initial value is not finite.
Chain 1: Stan can't start sampling from this initial value.
which is correct, but it doesn’t give any information on how to dig down further. The thread still has some unresolved difficulty about how to fix the underlying problem (which this wouldn’t address), but I think we could make it easier to find problems like this.
Maybe in the future we could do something like:
Chain 1: Gradient evaluated at the initial value is not finite.
Chain 1: Rerun model with --debug_infs to get more information
And then we could automatically instrument all the function calls with nan/inf check varis and print debugging information.
It could be like:
Adjoint of alpha is -inf in y = beta_lpdf(x = 0.7, alpha = 100, beta = 200) when
adjoint of y is 1.7.
Maybe we’d have a default argument that says stop after printing the first 5 errors to avoid blowing up everyone’s terminals with the output.
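As a rough illustration of the reverse-pass check and the error cap, here is a toy sketch. Everything in it is hypothetical: ToyTape is a tiny stand-in for an autodiff stack (not Stan Math's vari machinery), NanInfReporter is an invented helper that stops printing after max_errors messages, and checked_exp wraps exp only because it overflows easily. The one design point it is meant to show is ordering: the check is pushed onto the tape before the operation's own backward step, so in the reverse sweep it runs after the input adjoint has been updated.

```cpp
#include <cassert>
#include <cmath>
#include <functional>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Hypothetical reporter: prints at most max_errors messages, counts all.
struct NanInfReporter {
  std::ostream& out;
  int max_errors;
  int seen;
  NanInfReporter(std::ostream& o, int max) : out(o), max_errors(max), seen(0) {}
  void report(const std::string& msg) {
    if (seen < max_errors) out << msg << "\n";
    ++seen;
  }
};

// Toy reverse-mode tape (an assumption for illustration, not Stan's stack).
struct ToyTape {
  std::vector<double> val, adj;
  std::vector<std::function<void()>> steps;  // run in reverse in backward()
  int var(double v) {
    val.push_back(v);
    adj.push_back(0.0);
    return static_cast<int>(val.size()) - 1;
  }
  void backward(int output) {
    adj[output] = 1.0;
    for (auto it = steps.rbegin(); it != steps.rend(); ++it) (*it)();
  }
};

// y = exp(x), with a check on the adjoint of x during the reverse sweep.
int checked_exp(ToyTape& t, NanInfReporter& rep, int x) {
  const int y = t.var(std::exp(t.val[x]));
  // Pushed first, so it runs after the exp step below on the way back.
  t.steps.push_back([&t, &rep, x, y] {
    if (!std::isfinite(t.adj[x])) {
      std::ostringstream m;
      m << "Adjoint of x is " << t.adj[x] << " in y = exp(x = " << t.val[x]
        << ") when adjoint of y is " << t.adj[y];
      rep.report(m.str());
    }
  });
  // The operation's own backward step: d/dx exp(x) = exp(x) = val[y].
  t.steps.push_back([&t, x, y] { t.adj[x] += t.adj[y] * t.val[y]; });
  return y;
}
```

With x = 1000, exp overflows to inf, the adjoint of x becomes infinite, and the check reports it along with the input value and the adjoint of y. A real version would have to work with Stan's own stack and argument types, which this sketch does not attempt.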
How useful would something like this be? Numeric problems are hard to track down because checks aren’t comprehensive – sometimes an overflow is just fine. So maybe if we have a way to turn on/off some pedantic checking we could at least make it easy to find the problems.