Algebra solver has a side effect on log probability


#1

Hi,

I’m trying to estimate a structural model with a simple Bernoulli likelihood whose probability depends on the solution of an implicit function with no closed-form solution. Therefore, I’m trying algebra_solver() for the first time.

While I learn how doable this is, I’m starting with some very simple fake data, and I simplified the implicit function to something I can actually solve without algebra_solver() (I dropped the nonlinear component). In other words, I can fit the model fine without algebra_solver(), so I tried fitting it both with and without the call. In fact, when I did call algebra_solver() I ignored its result and used the closed-form solution I can calculate. I expected both runs to produce the same outcome (perhaps a lot slower with algebra_solver()). What is surprising is that when I comment out the line where algebra_solver() is called, the model is estimated nicely (no divergent transitions, R hat < 1.1, and parameters close to the true values), but when I leave the call in (even though I am not actually using its output) the fit has problems (30% divergent transitions and 25% of parameters with R hat > 1.1, although the parameters are still well estimated).

Perhaps I don’t understand how algebra_solver() works, but is it normal for it to have a side effect on the MCMC estimation when I don’t even update the log probability with its output?

I’m attaching my .stan file. The statement at line 133 is the one that changes things.

structural_dist_model.stan (5.9 KB)

I appreciate any help with this!


#2

@karimn Thanks for the heads up!

I suspect that the algebra solver is failing to converge, which causes the MCMC iteration to fail. That failure gets registered as a divergence and prevents the model from exploring in a certain direction.

So it’s not that autodiff would be messed up, it’d be that the algebra solver (which is being asked to run even though its results aren’t used) is telling MCMC “nonono, don’t go in that direction, I’m having trouble finding solutions”.

To verify this, could you make the algebra solver solve something that you know will always succeed? Like maybe just “solve 0 == 0”? Then we’d know whether this is the cause.

edit: Enhanced the anthropomorphism


#3

Thanks for the quick response! I tried to create a function

  vector test_implicit(vector v_cutoff, vector param, real[] x_r, int[] x_i) {
    return rep_vector(0, 1);
  }

But then I get a lot of these

Rejecting initial value:
  Gradient evaluated at the initial value is not finite.
  Stan can't start sampling from this initial value.

Initialization between (-2, 2) failed after 100 attempts. 

Is this what you meant by a 0 == 0 test?


#4

Something like that, haha. But then I guess the gradients of the solutions with respect to the parameters don’t exist. How about something like:

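vector test_implicit(vector v_cutoff, vector param, real[] x_r, int[] x_i) {
  // Root is v_cutoff == param, so the solver has a trivial job
  // and the solution has a well-defined gradient w.r.t. param.
  vector[2] z;
  z[1] = v_cutoff[1] - param[1];
  z[2] = v_cutoff[2] - param[2];
  return z;
}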

Or something similar? We just want a set of eqs. we know has a solution that’ll be super easy on the solver.
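
To be a bit more concrete, here is a minimal sketch of how that test could be wired into a full program, with the unknowns called y and the parameters theta. The names y_guess, y_sol, and the normal(0, 1) prior are made up for illustration and are not from your model. The Jacobian of this system with respect to the unknowns is the identity, so the solver should converge right away and the gradients with respect to theta should be well defined, unlike the constant-zero system.

```stan
functions {
  // Hypothetical test system: the root is y == theta.
  vector test_implicit(vector y, vector theta, real[] x_r, int[] x_i) {
    vector[2] z;
    z[1] = y[1] - theta[1];
    z[2] = y[2] - theta[2];
    return z;
  }
}
transformed data {
  vector[2] y_guess = rep_vector(0, 2);  // starting point for the solver
  real x_r[0];                           // no real-valued data needed
  int x_i[0];                            // no integer data needed
}
parameters {
  vector[2] theta;
}
transformed parameters {
  // Should reproduce theta up to the solver tolerance.
  vector[2] y_sol = algebra_solver(test_implicit, y_guess, theta, x_r, x_i);
}
model {
  theta ~ normal(0, 1);  // placeholder prior, only so the program samples
}
```

If this samples cleanly but the original system does not, that points at the solver failing on the original system rather than at anything in the likelihood.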


#5

Ok, I see what you mean. I actually went back to my original implicit function which was simplified so a solution would be trivial to find. What I was doing wrong before was passing in some parameters that were no longer needed so that probably drove MCMC nuts (bad crazy nuts, not good sampler nuts).

I reran without the unused parameters and I’m down to one divergent transition, verifying what you said.

My mistake was thinking of algebra_solver() like any other math operation, which should have no side effect on the sampler. I neglected to consider what happens when it fails to converge to a solution.
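
In case it helps anyone else who lands here: as far as I understand, algebra_solver() also has a longer signature that takes three control arguments after x_i (relative tolerance, function tolerance, and maximum number of steps). Below is a rough sketch of passing them explicitly. The system toy_system and all variable names are hypothetical and are not my actual model, and the values shown are what I believe the documented defaults to be, apart from a larger step limit.

```stan
functions {
  // Hypothetical one-dimensional system: y + exp(y) - theta == 0.
  // The left-hand side is strictly increasing in y, so the root is unique.
  vector toy_system(vector y, vector theta, real[] x_r, int[] x_i) {
    vector[1] z;
    z[1] = y[1] + exp(y[1]) - theta[1];
    return z;
  }
}
transformed data {
  vector[1] y_guess = rep_vector(0, 1);  // starting point for the solver
  real x_r[0];
  int x_i[0];
}
parameters {
  vector[1] theta;
}
transformed parameters {
  vector[1] y_sol = algebra_solver(toy_system, y_guess, theta, x_r, x_i,
                                   1e-10,   // rel_tol
                                   1e-6,    // f_tol
                                   2000);   // max_steps
}
model {
  theta ~ normal(0, 1);  // placeholder prior for illustration
}
```

I have not checked whether changing these would have helped in my case. The point is only that the convergence criteria are explicit inputs, so a failed solve is something the sampler can react to.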

Thanks Ben for helping out with this. I’ll continue taking baby steps towards the nonlinear form of the implicit function I want solved.


#6

Cool beans! That is an interesting find! Thanks for reporting back.

@charlesm93 Do you know of a way this could be diagnosed at runtime?